├── README.md
├── create_pretrain_data.sh
├── create_pretraining_data.py
├── modeling.py
├── optimization.py
├── optimization_finetuning.py
├── resources
│   ├── RoBERTa_zh_Large_Learning_Curve.png
│   └── vocab.txt
├── run_classifier.py
├── run_pretraining.py
└── tokenization.py
/README.md:
--------------------------------------------------------------------------------
1 | RoBERTa for Chinese, TensorFlow & PyTorch
2 |
3 | Chinese Pre-trained RoBERTa Model
4 | -------------------------------------------------
5 | RoBERTa is an improved version of BERT. It reaches state-of-the-art results by improving the training objective and data generation, training longer, using larger batches, and using more data; the checkpoints can be loaded directly with existing BERT code.
6 |
7 | This project implements RoBERTa pre-training on large-scale Chinese text with TensorFlow, and also provides PyTorch pre-trained models and loading instructions.
8 |
9 | *** 2019-10-12: added a comparison of different models on reading comprehension ***
10 |
11 | *** 2019-09-08: added download links hosted in China, the PyTorch version, and a preliminary comparison with bert-wwm, xlnet and other models ***
12 |
13 |
14 | NLP auto-labeling tool (up to 100X efficiency gain) - sign up
15 |
16 | A pre-trained Chinese ALBERT model is also available now.
17 |
18 | Chinese Pre-trained RoBERTa Model - Download
19 | -------------------------------------------------
20 | *** 6-layer RoBERTa preview version ***
21 | RoBERTa-zh-Layer6: Google Drive or Baidu Netdisk, TensorFlow version, loads directly with BERT, about 200M in size
22 |
23 | ###### ** Recommended: RoBERTa-zh-Large, validated **
24 | RoBERTa-zh-Large: Google Drive or Baidu Netdisk, TensorFlow version, loads directly with BERT
25 |
26 | RoBERTa-zh-Large: Google Drive or Baidu Netdisk, PyTorch version, loads directly with the PyTorch version of BERT
27 |
28 | Training data for the 24/12-layer RoBERTa: 30G of raw text, nearly 300 million sentences and 10 billion Chinese characters (tokens), yielding 250 million training instances;
29 |
30 | it covers news, community Q&A, several encyclopedias, and more;
31 |
32 | this project uses the same training data as the 24-layer Chinese pre-trained XLNet model, the XLNet_zh project.
33 |
34 | RoBERTa_zh_L12: Google Drive or Baidu Netdisk, TensorFlow version, loads directly with BERT
35 |
36 | RoBERTa_zh_L12: Google Drive or Baidu Netdisk, PyTorch version, loads directly with the PyTorch version of BERT
37 |
38 | ---------------------------------------------------------------
39 |
40 | Roberta_l24_zh_base: TensorFlow version, loads directly with BERT
41 |
42 | Training data for the 24-layer base version: 10G of text, including news, community Q&A, several encyclopedias, etc.
43 |
44 |
45 |
46 | What is RoBERTa:
47 | -------------------------------------------------
48 | A robustly optimized method for pretraining natural language processing (NLP) systems that improves on Bidirectional Encoder Representations from Transformers, or BERT, the self-supervised method released by Google in 2018.
49 |
50 | RoBERTa produces state-of-the-art results on the widely used NLP benchmark, General Language Understanding Evaluation (GLUE). The model delivered state-of-the-art performance on the MNLI, QNLI, RTE, STS-B, and RACE tasks and a sizable performance improvement on the GLUE benchmark. With a score of 88.5, RoBERTa reached the top position on the GLUE leaderboard, matching the performance of the previous leader, XLNet-Large.
51 |
52 | (Introduction from Facebook blog)
53 |
54 | Release Plan:
55 | -------------------------------------------------
56 | 1. 24-layer RoBERTa model (roberta_l24_zh), trained on 30G of text, September 8
57 |
58 | 2. 12-layer RoBERTa model (roberta_l12_zh), trained on 30G of text, September 8
59 |
60 | 3. 6-layer RoBERTa model (roberta_l6_zh), trained on 30G of text, September 8
61 |
62 | 4. PyTorch version of the models (roberta_l6_zh_pytorch), September 8
63 |
64 | 5. 30G Chinese corpus in pre-training format, ready for training (bert, xlnet, gpt2), TBD
65 |
66 | 6. Test-set evaluation and comparison, September 14
67 |
68 | Performance Evaluation and Comparison
69 | -------------------------------------------------
70 | ### Internet News Sentiment Analysis: CCF-Sentiment-Analysis
71 |
72 | | Model | Online F1 |
73 | | :------- | :---------: |
74 | | BERT | 80.3 |
75 | | Bert-wwm-ext | 80.5 |
76 | | XLNet | 79.6 |
77 | | Roberta-mid | 80.5 |
78 | | Roberta-large (max_seq_length=512, split_num=1) | 81.25 |
79 |
80 | Note: results come from guoday's open-source project; see the CCF Internet News Sentiment Analysis competition for the dataset and task description.
81 |
82 | ### Natural Language Inference: XNLI
83 |
84 | | Model | Dev | Test |
85 | | :------- | :---------: | :---------: |
86 | | BERT | 77.8 (77.4) | 77.8 (77.5) |
87 | | ERNIE | 79.7 (79.4) | 78.6 (78.2) |
88 | | BERT-wwm | 79.0 (78.4) | 78.2 (78.0) |
89 | | BERT-wwm-ext | 79.4 (78.6) | 78.7 (78.3) |
90 | | XLNet | 79.2 | 78.7 |
91 | | RoBERTa-zh-base | 79.8 |78.8 |
92 | | **RoBERTa-zh-Large** | **80.2 (80.0)** | **79.9 (79.5)** |
93 |
94 | Note: RoBERTa_l24_zh was only run twice; performance may improve further;
95 |
96 | BERT-wwm-ext comes from here; XLNet comes from here; RoBERTa-zh-base refers to the 12-layer Chinese RoBERTa model.
97 |
98 | ### Sentence Pair Matching: LCQMC
99 |
100 | | Model | Dev | Test |
101 | | :------- | :---------: | :---------: |
102 | | BERT | 89.4(88.4) | 86.9(86.4) |
103 | | ERNIE | 89.8 (89.6) | **87.2** (87.0) |
104 | | BERT-wwm |89.4 (89.2) | 87.0 (86.8) |
105 | | BERT-wwm-ext | - |- |
106 | | RoBERTa-zh-base | 88.7 | 87.0 |
107 | | **RoBERTa-zh-Large** | **89.9**(89.6) | **87.2**(86.7) |
108 | | RoBERTa-zh-Large(20w_steps) | 89.7| 87.0 |
109 |
110 | Note: RoBERTa_l24_zh was only run twice; performance may improve further. The number of training epochs was kept consistent with the paper.
111 |
112 | ### Reading Comprehension Evaluation
113 | For reading-comprehension tasks, the current best hyperparameters for both BERT and RoBERTa are epochs=2, batch=32, lr=3e-5, warmup=0.1.
114 |
115 | #### cmrc2018 (Reading Comprehension)
116 |
117 | | models | DEV |
118 | | ------ | ------ |
119 | | sibert_base | F1:87.521(88.628) EM:67.381(69.152) |
120 | | sialbert_middle | F1:87.6956(87.878) EM:67.897(68.624) |
121 | | HFL (HIT-iFLYTEK) roberta_wwm_ext_base | F1:87.521(88.628) EM:67.381(69.152) |
122 | | brightmart roberta_middle | F1:86.841(87.242) EM:67.195(68.313) |
123 | | brightmart roberta_large | **F1:88.608(89.431) EM:69.935(72.538)** |
124 |
125 | #### DRCD (Reading Comprehension)
126 |
127 | | models | DEV |
128 | | ------ | ------ |
129 | | siBert_base | F1:93.343(93.524) EM:87.968(88.28) |
130 | | siALBert_middle | F1:93.865(93.975) EM:88.723(88.961) |
131 | | HFL (HIT-iFLYTEK) roberta_wwm_ext_base | F1:94.257(94.48) EM:89.291(89.642) |
132 | | brightmart roberta_large | **F1:94.933(95.057) EM:90.113(90.238)** |
133 |
134 | #### CJRC (reading comprehension with yes/no/unknown answers)
135 |
136 | | models | DEV |
137 | | ------ | ------ |
138 | | siBert_base | F1:80.714(81.14) EM:64.44(65.04) |
139 | | siALBert_middle | F1:80.9838(81.299) EM:63.796(64.202) |
140 | | HFL (HIT-iFLYTEK) roberta_wwm_ext_base | F1:81.510(81.684) EM:64.924(65.574) |
141 | | brightmart roberta_large | F1:80.16(80.475) EM:65.249(66.133) |
142 |
143 | The reading-comprehension comparison results come from bert_cn_finetune.
144 |
145 | Entries marked with ? will be updated with concrete values soon.
146 |
147 | RoBERTa Chinese Version
148 | -------------------------------------------------
149 | The Chinese pre-trained RoBERTa model in this project refers to a model trained according to the main ideas of the RoBERTa paper, namely:
150 |
151 | 1. Improved data generation and training objective: the next sentence prediction task is removed, and input data is taken contiguously from a single document (see: Model Input Format and Next Sentence Prediction, DOC-SENTENCES).
152 |
153 | 2. Larger and more diverse data: 30G of Chinese text containing 300 million sentences and 10 billion characters (tokens), drawn from news, community discussions and several encyclopedias, covering hundreds of thousands of topics,
154 |
155 | so the data is diverse (for even more diversity, web books, novels, story-style literature, Weibo posts, etc. could be added).
156 |
157 | 3. Longer training: nearly 200K training steps in total, seeing roughly 1.6 billion training instances; trained for 24 hours on a Cloud TPU v3-256, which is roughly equivalent to one month of training on a TPU v3-8 (128G memory).
158 |
159 | 4. Larger batches: a very large batch size (8k) was used.
160 |
161 | 5. Adjusted optimizer and other hyperparameters.
162 |
163 | In addition, the Chinese models in this project use whole word masking: if some of the WordPiece subwords of a complete word are masked, the remaining subwords of that word are masked as well (see the example table and sketch below).
164 |
165 | This project does not implement dynamic masking directly. By duplicating each training example several times, giving each copy a different mask, and increasing the number of copies, an effect similar to dynamic masking is obtained indirectly, as in the sketch below.
166 |
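Below is a minimal sketch of that idea (illustrative only, not this project's actual code; `mask_copy` is a made-up helper): each copy of an example receives an independent random mask, so across copies the model sees different masked positions. In this repository the same effect is controlled by the `--dupe_factor` flag of create_pretraining_data.py.

```python
# Sketch: approximate dynamic masking by duplicating each example with a different random mask.
import random

def mask_copy(tokens, masked_lm_prob, rng):
    """Return a copy of `tokens` with roughly masked_lm_prob of the positions replaced by [MASK]."""
    out = list(tokens)
    n = max(1, int(round(len(tokens) * masked_lm_prob)))
    for idx in rng.sample(range(len(tokens)), n):
        out[idx] = "[MASK]"
    return out

tokens = list("使用语言模型来预测下一个词")
rng = random.Random(12345)
dupe_factor = 3  # create_pretraining_data.py exposes the same knob as --dupe_factor
for _ in range(dupe_factor):
    print(" ".join(mask_copy(tokens, 0.10, rng)))  # each copy is masked at different positions
```
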
167 | ##### Instructions for Use
168 |
169 | The models in this project were trained with a sequence length of 256, so they should work well for inputs within that range; if your task has longer inputs (e.g., sequence length 512), performance may be affected.
170 |
171 | Some users have split long inputs with a sliding window and still obtained good results; see #issue-16 and the sketch below.
172 |
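A rough sketch of that sliding-window idea (an assumed approach based on the issue discussion, not code from this repository; `classify()` stands for your fine-tuned model):

```python
# Sketch: split an input longer than the training length into overlapping windows,
# score each window separately, then aggregate the per-window predictions.
def sliding_windows(tokens, window=256, stride=128):
    """Yield overlapping chunks of at most `window` tokens, advancing `stride` tokens per step."""
    if len(tokens) <= window:
        yield tokens
        return
    for start in range(0, len(tokens) - stride, stride):
        yield tokens[start:start + window]

long_text = list("某篇长度远超过256的文档。") * 100
chunks = list(sliding_windows(long_text))
# scores = [classify(chunk) for chunk in chunks]   # classify() = your fine-tuned model, not shown here
# final_score = sum(scores) / len(scores)          # e.g. average (or max) over the windows
```
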
173 | ##### Chinese Whole Word Mask
174 |
175 | | Description | Example |
176 | | :------- | :--------- |
177 | | Original text | 使用语言模型来预测下一个词的probability。 |
178 | | Word-segmented text | 使用 语言 模型 来 预测 下 一个 词 的 probability 。 |
179 | | Original masked input | 使 用 语 言 [MASK] 型 来 [MASK] 测 下 一 个 词 的 pro [MASK] ##lity 。 |
180 | | Whole-word masked input | 使 用 语 言 [MASK] [MASK] 来 [MASK] [MASK] 下 一 个 词 的 [MASK] [MASK] [MASK] 。 |
181 |
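The marking step behind the whole-word-mask row can be sketched as follows (a simplified version of `get_new_segment()` in create_pretraining_data.py, shown only for illustration; the real function also handles non-Chinese characters):

```python
# Sketch: use jieba word boundaries to prefix word-internal characters with "##",
# so that masking can later cover all characters of a word together.
import jieba

chars = list("使用语言模型来预测下一个词")
words = jieba.lcut("".join(chars))            # e.g. ['使用', '语言', '模型', '来', '预测', '下', '一个', '词']
marked = []
for w in words:
    marked.append(w[0])                        # first character of a word keeps its original form
    marked.extend("##" + c for c in w[1:])     # remaining characters get the "##" prefix
print(marked)  # e.g. ['使', '##用', '语', '##言', '模', '##型', '来', ...]
# A "##"-prefixed character is masked together with the character before it, which is what
# produces the "whole-word masked input" row of the table above.
```
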
182 | Loading the model (using the sentence pair matching task LCQMC as an example)
183 | -------------------------------------------------
184 |
185 | Download the LCQMC dataset, which contains training, dev and test sets. The training set contains 240K colloquial Chinese sentence pairs labeled 1 or 0, where 1 means the two sentences are semantically similar and 0 means they are not.
186 |
187 | TensorFlow version:
188 |
189 | 1. Clone this project: git clone https://github.com/brightmart/roberta_zh
190 |
191 | 2. Enter the project directory (roberta_zh).
192 |
193 | Assume you have downloaded the RoBERTa pre-trained model and unpacked it into the roberta_zh_large directory of this project, i.e., roberta_zh/roberta_zh_large.
194 |
195 | Run the command:
196 |
197 | export BERT_BASE_DIR=./roberta_zh_large
198 | export MY_DATA_DIR=./data/lcqmc
199 | python run_classifier.py \
200 | --task_name=lcqmc_pair \
201 | --do_train=true \
202 | --do_eval=true \
203 | --data_dir=$MY_DATA_DIR \
204 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
205 | --bert_config_file=$BERT_BASE_DIR/bert_config_large.json \
206 | --init_checkpoint=$BERT_BASE_DIR/roberta_zh_large_model.ckpt \
207 | --max_seq_length=128 \
208 | --train_batch_size=64 \
209 | --learning_rate=2e-5 \
210 | --num_train_epochs=3 \
211 | --output_dir=./checkpoint_lcqmc
212 |
213 | Note: task_name is lcqmc_pair. A processor has already been added to run_classifier.py and registered in processors; it selects the LCQMC task and loads the training and dev data. A minimal sketch of such a processor is shown below.
214 |
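For reference, a processor for a sentence-pair task looks roughly like this (a minimal example written against the standard BERT run_classifier.py helpers such as DataProcessor and InputExample; the exact class name, file names and TSV layout in this repository may differ):

```python
# Minimal sketch of a sentence-pair processor like the one registered as "lcqmc_pair".
# It relies on DataProcessor, InputExample, self._read_tsv and tokenization from run_classifier.py.
class LCQMCPairProcessor(DataProcessor):
  """Processor for the LCQMC sentence pair matching data set (illustrative)."""

  def get_train_examples(self, data_dir):
    return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.txt")), "train")

  def get_dev_examples(self, data_dir):
    return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.txt")), "dev")

  def get_labels(self):
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    examples = []
    for i, line in enumerate(lines):            # each line: sentence_a \t sentence_b \t label
      guid = "%s-%d" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[0])
      text_b = tokenization.convert_to_unicode(line[1])
      label = tokenization.convert_to_unicode(line[2])
      examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

# ...and registered in main():  processors = {..., "lcqmc_pair": LCQMCPairProcessor}
```
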
215 | For loading the PyTorch version, see issue 9 for now; more detailed instructions will be provided soon. A rough example is sketched below.
216 |
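Since the PyTorch weights are stated to load directly with the PyTorch version of BERT, loading should look roughly like the sketch below (assuming a recent version of the Hugging Face transformers library; the model directory path is a placeholder):

```python
# Sketch: load the PyTorch RoBERTa-zh weights with the ordinary BERT classes
# (the model keeps BERT's architecture and vocab, so no RoBERTa-specific class is needed).
import torch
from transformers import BertTokenizer, BertModel

model_dir = "./roberta_zh_large_pytorch"   # placeholder: folder with config, vocab.txt and weights
tokenizer = BertTokenizer.from_pretrained(model_dir)
model = BertModel.from_pretrained(model_dir)
model.eval()

inputs = tokenizer("使用语言模型来预测下一个词的probability。", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)     # [1, seq_length, hidden_size]
```
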
217 | Pre-training
218 | -------------------------------------------------
219 | #### 1) Data for pre-training
220 | You can train on data from your task's domain, or select a domain-relevant subset from a general corpus.
221 |
222 | For general corpora, see nlp_chinese_corpus: it contains several datasets, each with tens of millions of sentences.
223 |
224 | #### 2) Generate data for pre-training
225 | This step takes text contiguously from a single document following the DOC-SENTENCES format, and applies whole word masking.
226 |
227 | Shell script: batch-convert multiple txt files into tfrecord files.
228 |
229 | For example, to convert the 1st through 10th txt files into tfrecord files:
230 |
231 | nohup bash create_pretrain_data.sh 1 10 &
232 |
233 | Note: in our experiments, whole word masking with a 15% ratio made the model hard to train and slow to converge, so we used a 10% ratio instead;
234 |
235 | #### 3) Run the pre-training command
236 | The next sentence prediction task is removed.
237 |
238 | export BERT_BASE_DIR=
239 | nohup python3 run_pretraining.py --input_file=./tf_records_all/tf*.tfrecord \
240 | --output_dir=my_new_model_path --do_train=True --do_eval=True --bert_config_file=$BERT_BASE_DIR/bert_config.json \
241 | --train_batch_size=8192 --max_seq_length=256 --max_predictions_per_seq=23 \
242 | --num_train_steps=200000 --num_warmup_steps=10000 --learning_rate=1e-4 \
243 | --save_checkpoints_steps=3000 --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt &
244 |
245 | Note: if you train from scratch, you do not need to specify init_checkpoint;
246 | if you continue training from an existing model, set BERT_BASE_DIR and make sure bert_config_file and init_checkpoint point to the corresponding files;
247 | domain-specific pre-training does not need to run for particularly long.
248 |
249 | Learning Curve
250 | -------------------------------------------------
251 |
252 |
253 | GPU Memory Requirements: Trade-off between Batch Size and Sequence Length
254 | -------------------------------------------------
255 |
256 | System | Seq Length | Max Batch Size
257 | ------------ | ---------- | --------------
258 | `RoBERTa-Base` | 64 | 64
259 | ... | 128 | 32
260 | ... | 256 | 16
261 | ... | 320 | 14
262 | ... | 384 | 12
263 | ... | 512 | 6
264 | `RoBERTa-Large` | 64 | 12
265 | ... | 128 | 6
266 | ... | 256 | 2
267 | ... | 320 | 1
268 | ... | 384 | 0
269 | ... | 512 | 0
270 |
271 |
272 |
273 | #### QQ group for technical discussion: 836811304
274 |
275 | If you have any questions, you can open an issue or send me an email: brightmart@hotmail.com;
276 |
277 | You can also send a pull request to report your performance on your task, or to add instructions on how to load the models in PyTorch, and so on.
278 |
279 | If you have ideas for pre-training a better-performing Chinese model, please also let me know.
280 |
281 | Please report the accuracy on your task and how it compares with other models.
282 |
283 |
284 | Project contributors also include:
285 | -------------------------------------------------
286 | skyhawk1990
287 |
288 |
289 | ##### Research supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC)
290 |
291 |
292 |
293 |
294 | Reference
295 | -------------------------------------------------
296 | 1. RoBERTa: A Robustly Optimized BERT Pretraining Approach
297 |
298 | 2. Pre-Training with Whole Word Masking for Chinese BERT
299 |
300 | 3. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
301 |
302 | 4. LCQMC: A Large-scale Chinese Question Matching Corpus
303 |
--------------------------------------------------------------------------------
/create_pretrain_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo $1,$2  # start and end indices of the news2016zh_*.txt files to convert
3 |
4 | for((i=$1;i<=$2;i++));
5 | do
6 | python3 create_pretraining_data.py --do_whole_word_mask=True --input_file=./raw_text/news2016zh_$i.txt \
7 | --output_file=./tf_records_all/tf_news2016zh_$i.tfrecord --vocab_file=./resources/vocab.txt \
8 | --do_lower_case=True --max_seq_length=256 --max_predictions_per_seq=23 --masked_lm_prob=0.10 --random_seed=12345 --dupe_factor=5
9 | done
10 |
--------------------------------------------------------------------------------
/create_pretraining_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Create masked LM/next sentence masked_lm TF examples for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import random
23 | import re
24 | import tokenization
25 | import tensorflow as tf
26 | import jieba
27 |
28 | flags = tf.flags
29 |
30 | FLAGS = flags.FLAGS
31 |
32 | flags.DEFINE_string("input_file", None,
33 | "Input raw text file (or comma-separated list of files).")
34 |
35 | flags.DEFINE_string(
36 | "output_file", None,
37 | "Output TF example file (or comma-separated list of files).")
38 |
39 | flags.DEFINE_string("vocab_file", None,
40 | "The vocabulary file that the BERT model was trained on.")
41 |
42 | flags.DEFINE_bool(
43 | "do_lower_case", True,
44 | "Whether to lower case the input text. Should be True for uncased "
45 | "models and False for cased models.")
46 |
47 | flags.DEFINE_bool(
48 | "do_whole_word_mask", False,
49 | "Whether to use whole word masking rather than per-WordPiece masking.")
50 |
51 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
52 |
53 | flags.DEFINE_integer("max_predictions_per_seq", 20,
54 | "Maximum number of masked LM predictions per sequence.")
55 |
56 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
57 |
58 | flags.DEFINE_integer(
59 | "dupe_factor", 10,
60 | "Number of times to duplicate the input data (with different masks).")
61 |
62 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
63 |
64 | flags.DEFINE_float(
65 | "short_seq_prob", 0.1,
66 | "Probability of creating sequences which are shorter than the "
67 | "maximum length.")
68 |
69 |
70 | class TrainingInstance(object):
71 | """A single training instance (sentence pair)."""
72 |
73 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
74 | is_random_next):
75 | self.tokens = tokens
76 | self.segment_ids = segment_ids
77 | self.is_random_next = is_random_next
78 | self.masked_lm_positions = masked_lm_positions
79 | self.masked_lm_labels = masked_lm_labels
80 |
81 | def __str__(self):
82 | s = ""
83 | s += "tokens: %s\n" % (" ".join(
84 | [tokenization.printable_text(x) for x in self.tokens]))
85 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
86 | s += "is_random_next: %s\n" % self.is_random_next
87 | s += "masked_lm_positions: %s\n" % (" ".join(
88 | [str(x) for x in self.masked_lm_positions]))
89 | s += "masked_lm_labels: %s\n" % (" ".join(
90 | [tokenization.printable_text(x) for x in self.masked_lm_labels]))
91 | s += "\n"
92 | return s
93 |
94 | def __repr__(self):
95 | return self.__str__()
96 |
97 |
98 | def write_instance_to_example_files(instances, tokenizer, max_seq_length,
99 | max_predictions_per_seq, output_files):
100 | """Create TF example files from `TrainingInstance`s."""
101 | writers = []
102 | for output_file in output_files:
103 | writers.append(tf.python_io.TFRecordWriter(output_file))
104 |
105 | writer_index = 0
106 |
107 | total_written = 0
108 | for (inst_index, instance) in enumerate(instances):
109 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
110 | input_mask = [1] * len(input_ids)
111 | segment_ids = list(instance.segment_ids)
112 | assert len(input_ids) <= max_seq_length
113 |
114 | while len(input_ids) < max_seq_length:
115 | input_ids.append(0)
116 | input_mask.append(0)
117 | segment_ids.append(0)
118 |
119 | assert len(input_ids) == max_seq_length
120 | assert len(input_mask) == max_seq_length
121 | # print("length of segment_ids:",len(segment_ids),"max_seq_length:", max_seq_length)
122 | assert len(segment_ids) == max_seq_length
123 |
124 | masked_lm_positions = list(instance.masked_lm_positions)
125 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
126 | masked_lm_weights = [1.0] * len(masked_lm_ids)
127 |
128 | while len(masked_lm_positions) < max_predictions_per_seq:
129 | masked_lm_positions.append(0)
130 | masked_lm_ids.append(0)
131 | masked_lm_weights.append(0.0)
132 |
133 | next_sentence_label = 1 if instance.is_random_next else 0
134 |
135 | features = collections.OrderedDict()
136 | features["input_ids"] = create_int_feature(input_ids)
137 | features["input_mask"] = create_int_feature(input_mask)
138 | features["segment_ids"] = create_int_feature(segment_ids)
139 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
140 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
141 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
142 | features["next_sentence_labels"] = create_int_feature([next_sentence_label])
143 |
144 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
145 |
146 | writers[writer_index].write(tf_example.SerializeToString())
147 | writer_index = (writer_index + 1) % len(writers)
148 |
149 | total_written += 1
150 |
151 | if inst_index < 20:
152 | tf.logging.info("*** Example ***")
153 | tf.logging.info("tokens: %s" % " ".join(
154 | [tokenization.printable_text(x) for x in instance.tokens]))
155 |
156 | for feature_name in features.keys():
157 | feature = features[feature_name]
158 | values = []
159 | if feature.int64_list.value:
160 | values = feature.int64_list.value
161 | elif feature.float_list.value:
162 | values = feature.float_list.value
163 | tf.logging.info(
164 | "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
165 |
166 | for writer in writers:
167 | writer.close()
168 |
169 | tf.logging.info("Wrote %d total instances", total_written)
170 |
171 |
172 | def create_int_feature(values):
173 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
174 | return feature
175 |
176 |
177 | def create_float_feature(values):
178 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
179 | return feature
180 |
181 |
182 | def create_training_instances(input_files, tokenizer, max_seq_length,
183 | dupe_factor, short_seq_prob, masked_lm_prob,
184 | max_predictions_per_seq, rng):
185 | """Create `TrainingInstance`s from raw text."""
186 | all_documents = [[]]
187 |
188 | # Input file format:
189 | # (1) One sentence per line. These should ideally be actual sentences, not
190 | # entire paragraphs or arbitrary spans of text. (Because we use the
191 | # sentence boundaries for the "next sentence prediction" task).
192 | # (2) Blank lines between documents. Document boundaries are needed so
193 | # that the "next sentence prediction" task doesn't span between documents.
194 | print("create_training_instances.started...")
195 | for input_file in input_files:
196 | with tf.gfile.GFile(input_file, "r") as reader:
197 | while True:
198 | line = tokenization.convert_to_unicode(reader.readline().replace("",""))# .replace("”","")) # optionally strip the "”" character.
199 | if not line:
200 | break
201 | line = line.strip()
202 |
203 | # Empty lines are used as document delimiters
204 | if not line:
205 | all_documents.append([])
206 | tokens = tokenizer.tokenize(line)
207 | if tokens:
208 | all_documents[-1].append(tokens)
209 |
210 | # Remove empty documents
211 | all_documents = [x for x in all_documents if x]
212 | rng.shuffle(all_documents)
213 |
214 | vocab_words = list(tokenizer.vocab.keys())
215 | instances = []
216 | for _ in range(dupe_factor):
217 | for document_index in range(len(all_documents)):
218 | instances.extend(
219 | create_instances_from_document(
220 | all_documents, document_index, max_seq_length, short_seq_prob,
221 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
222 |
223 | rng.shuffle(instances)
224 | print("create_training_instances.ended...")
225 |
226 | return instances
227 |
228 |
229 | def _is_chinese_char(cp):
230 | """Checks whether CP is the codepoint of a CJK character."""
231 | # This defines a "chinese character" as anything in the CJK Unicode block:
232 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
233 | #
234 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
235 | # despite its name. The modern Korean Hangul alphabet is a different block,
236 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
237 | # space-separated words, so they are not treated specially and handled
238 | # like the all of the other languages.
239 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
240 | (cp >= 0x3400 and cp <= 0x4DBF) or #
241 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
242 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
243 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
244 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
245 | (cp >= 0xF900 and cp <= 0xFAFF) or #
246 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
247 | return True
248 | return False
249 |
250 | def get_new_segment(segment): # newly added method ####
251 | """
252 | Takes a sentence and returns a processed version: to support Chinese whole word masking, characters belonging to a word produced by the segmenter are marked with a special prefix ("##"), so later processing knows which characters belong to the same word.
253 | :param segment: a sentence (as a list of characters/tokens)
254 | :return: the processed sentence
255 | """
256 | seq_cws = jieba.lcut("".join(segment))
257 | seq_cws_dict = {x: 1 for x in seq_cws}
258 | new_segment = []
259 | i = 0
260 | while i < len(segment):
261 | if len(re.findall('[\u4E00-\u9FA5]', segment[i]))==0: # not a Chinese character, keep the original as is.
262 | new_segment.append(segment[i])
263 | i += 1
264 | continue
265 |
266 | has_add = False
267 | for length in range(3,0,-1):
268 | if i+length>len(segment):
269 | continue
270 | if ''.join(segment[i:i+length]) in seq_cws_dict:
271 | new_segment.append(segment[i])
272 | for l in range(1, length):
273 | new_segment.append('##' + segment[i+l])
274 | i += length
275 | has_add = True
276 | break
277 | if not has_add:
278 | new_segment.append(segment[i])
279 | i += 1
280 | return new_segment
281 |
282 | def get_raw_instance(document,max_sequence_length): # newly added method
283 | """
284 | Build preliminary training instances: split a whole document into chunks of at most max_sequence_length and return them as a list.
285 | :param document: a whole document (a list of tokenized sentences)
286 | :param max_sequence_length:
287 | :return: a list. each element is a sequence of text
288 | """
289 | max_sequence_length_allowed=max_sequence_length-2
290 | document = [seq for seq in document if len(seq)<max_sequence_length_allowed]
291 | sizes = [len(seq) for seq in document]
292 |
293 | result_list = []
294 | curr_seq = [] # the sequence currently being built
295 | sz_idx = 0
296 | while sz_idx < len(sizes):
297 | # If the current sequence plus the next sentence still fits within the limit, merge them; otherwise emit the current sequence and start a new one.
298 | if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed:
299 | curr_seq += document[sz_idx]
300 | sz_idx += 1
301 | else:
302 | result_list.append(curr_seq)
303 | curr_seq = []
304 | # Handle the last sequence: keep it only if it is long enough.
305 | if len(curr_seq)>max_sequence_length_allowed/2: # /2
306 | result_list.append(curr_seq)
307 |
308 | # # compute how many chunks can be obtained in total
309 | # num_instance=int(len(big_list)/max_sequence_length_allowed)+1
310 | # print("num_instance:",num_instance)
311 | # # split into that many chunks and add them to the list
312 | # result_list=[]
313 | # for j in range(num_instance):
314 | # index=j*max_sequence_length_allowed
315 | # end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1
316 | # result_list.append(big_list[index:end_index])
317 | return result_list
318 |
319 | def create_instances_from_document( # newly added method
320 | # Goal: follow the RoBERTa approach (DOC-SENTENCES) and drop the NSP task: take text contiguously from one document until the maximum length is reached; if text were taken from the next document, a separator would be added.
321 | # A document is a whole passage containing multiple sentences; each sentence is called a segment.
322 | # Given a document (a whole passage), generate a number of instances.
323 | all_documents, document_index, max_seq_length, short_seq_prob,
324 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
325 | """Creates `TrainingInstance`s for a single document."""
326 | document = all_documents[document_index]
327 |
328 | # Account for [CLS], [SEP], [SEP]
329 | max_num_tokens = max_seq_length - 3
330 |
331 | # We *usually* want to fill up the entire sequence since we are padding
332 | # to `max_seq_length` anyways, so short sequences are generally wasted
333 | # computation. However, we *sometimes*
334 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
335 | # sequences to minimize the mismatch between pre-training and fine-tuning.
336 | # The `target_seq_length` is just a rough target however, whereas
337 | # `max_seq_length` is a hard limit.
338 |
339 | #target_seq_length = max_num_tokens
340 | #if rng.random() < short_seq_prob:
341 | # target_seq_length = rng.randint(2, max_num_tokens)
342 |
343 | instances = []
344 | raw_text_list_list=get_raw_instance(document, max_seq_length) # a document is a whole passage containing multiple sentences; each sentence is called a segment.
345 | for j, raw_text_list in enumerate(raw_text_list_list):
346 | ####################################################################################################################
347 | raw_text_list = get_new_segment(raw_text_list) # Chinese whole word mask using word segmentation, i.e. add "##" where needed
348 | # 1. set tokens and segment_ids
349 | is_random_next=True # this will not be used, so its value doesn't matter
350 | tokens = []
351 | segment_ids = []
352 | tokens.append("[CLS]")
353 | segment_ids.append(0)
354 | for token in raw_text_list:
355 | tokens.append(token)
356 | segment_ids.append(0)
357 | tokens.append("[SEP]")
358 | segment_ids.append(0)
359 | ################################################################################################################
360 | # 2. call the original method
361 | (tokens, masked_lm_positions,
362 | masked_lm_labels) = create_masked_lm_predictions(
363 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
364 | instance = TrainingInstance(
365 | tokens=tokens,
366 | segment_ids=segment_ids,
367 | is_random_next=is_random_next,
368 | masked_lm_positions=masked_lm_positions,
369 | masked_lm_labels=masked_lm_labels)
370 | instances.append(instance)
371 |
372 | return instances
373 |
374 |
375 |
376 | def create_instances_from_document_original(
377 | all_documents, document_index, max_seq_length, short_seq_prob,
378 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
379 | """Creates `TrainingInstance`s for a single document."""
380 | document = all_documents[document_index]
381 |
382 | # Account for [CLS], [SEP], [SEP]
383 | max_num_tokens = max_seq_length - 3
384 |
385 | # We *usually* want to fill up the entire sequence since we are padding
386 | # to `max_seq_length` anyways, so short sequences are generally wasted
387 | # computation. However, we *sometimes*
388 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
389 | # sequences to minimize the mismatch between pre-training and fine-tuning.
390 | # The `target_seq_length` is just a rough target however, whereas
391 | # `max_seq_length` is a hard limit.
392 | target_seq_length = max_num_tokens
393 | if rng.random() < short_seq_prob:
394 | target_seq_length = rng.randint(2, max_num_tokens)
395 |
396 | # We DON'T just concatenate all of the tokens from a document into a long
397 | # sequence and choose an arbitrary split point because this would make the
398 | # next sentence prediction task too easy. Instead, we split the input into
399 | # segments "A" and "B" based on the actual "sentences" provided by the user
400 | # input.
401 | instances = []
402 | current_chunk = []
403 | current_length = 0
404 | i = 0
405 | print("document_index:",document_index,"document:",type(document)," ;document:",document) # document即一整段话,包含多个句子。每个句子叫做segment.
406 | while i < len(document):
407 | segment = document[i] # 取到一个部分(可能是一段话)
408 | print("i:",i," ;segment:",segment)
409 | ####################################################################################################################
410 | segment = get_new_segment(segment) # 结合分词的中文的whole mask设置即在需要的地方加上“##”
411 | ###################################################################################################################
412 | current_chunk.append(segment)
413 | current_length += len(segment)
414 | print("#####condition:",i == len(document) - 1 or current_length >= target_seq_length)
415 | if i == len(document) - 1 or current_length >= target_seq_length:
416 | if current_chunk:
417 | # `a_end` is how many segments from `current_chunk` go into the `A`
418 | # (first) sentence.
419 | a_end = 1
420 | if len(current_chunk) >= 2:
421 | a_end = rng.randint(1, len(current_chunk) - 1)
422 |
423 | tokens_a = []
424 | for j in range(a_end):
425 | tokens_a.extend(current_chunk[j])
426 |
427 | tokens_b = []
428 | # Random next
429 | is_random_next = False
430 | if len(current_chunk) == 1 or rng.random() < 0.5:
431 | is_random_next = True
432 | target_b_length = target_seq_length - len(tokens_a)
433 |
434 | # This should rarely go for more than one iteration for large
435 | # corpora. However, just to be careful, we try to make sure that
436 | # the random document is not the same as the document
437 | # we're processing.
438 | for _ in range(10):
439 | random_document_index = rng.randint(0, len(all_documents) - 1)
440 | if random_document_index != document_index:
441 | break
442 |
443 | random_document = all_documents[random_document_index]
444 | random_start = rng.randint(0, len(random_document) - 1)
445 | for j in range(random_start, len(random_document)):
446 | tokens_b.extend(random_document[j])
447 | if len(tokens_b) >= target_b_length:
448 | break
449 | # We didn't actually use these segments so we "put them back" so
450 | # they don't go to waste.
451 | num_unused_segments = len(current_chunk) - a_end
452 | i -= num_unused_segments
453 | # Actual next
454 | else:
455 | is_random_next = False
456 | for j in range(a_end, len(current_chunk)):
457 | tokens_b.extend(current_chunk[j])
458 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
459 |
460 | assert len(tokens_a) >= 1
461 | assert len(tokens_b) >= 1
462 |
463 | tokens = []
464 | segment_ids = []
465 | tokens.append("[CLS]")
466 | segment_ids.append(0)
467 | for token in tokens_a:
468 | tokens.append(token)
469 | segment_ids.append(0)
470 |
471 | tokens.append("[SEP]")
472 | segment_ids.append(0)
473 |
474 | for token in tokens_b:
475 | tokens.append(token)
476 | segment_ids.append(1)
477 | tokens.append("[SEP]")
478 | segment_ids.append(1)
479 |
480 | (tokens, masked_lm_positions,
481 | masked_lm_labels) = create_masked_lm_predictions(
482 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
483 | instance = TrainingInstance(
484 | tokens=tokens,
485 | segment_ids=segment_ids,
486 | is_random_next=is_random_next,
487 | masked_lm_positions=masked_lm_positions,
488 | masked_lm_labels=masked_lm_labels)
489 | instances.append(instance)
490 | current_chunk = []
491 | current_length = 0
492 | i += 1
493 |
494 | return instances
495 |
496 |
497 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
498 | ["index", "label"])
499 |
500 |
501 | def create_masked_lm_predictions(tokens, masked_lm_prob,
502 | max_predictions_per_seq, vocab_words, rng):
503 | """Creates the predictions for the masked LM objective."""
504 |
505 | cand_indexes = []
506 | for (i, token) in enumerate(tokens):
507 | if token == "[CLS]" or token == "[SEP]":
508 | continue
509 | # Whole Word Masking means that if we mask all of the wordpieces
510 | # corresponding to an original word. When a word has been split into
511 | # WordPieces, the first token does not have any marker and any subsequent
512 | # tokens are prefixed with ##. So whenever we see the ## token, we
513 | # append it to the previous set of word indexes.
514 | #
515 | # Note that Whole Word Masking does *not* change the training code
516 | # at all -- we still predict each WordPiece independently, softmaxed
517 | # over the entire vocabulary.
518 | if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
519 | token.startswith("##")):
520 | cand_indexes[-1].append(i)
521 | else:
522 | cand_indexes.append([i])
523 |
524 | rng.shuffle(cand_indexes)
525 |
526 | output_tokens = [t[2:] if len(re.findall('##[\u4E00-\u9FA5]', t))>0 else t for t in tokens]
527 |
528 | num_to_predict = min(max_predictions_per_seq,
529 | max(1, int(round(len(tokens) * masked_lm_prob))))
530 |
531 | masked_lms = []
532 | covered_indexes = set()
533 | for index_set in cand_indexes:
534 | if len(masked_lms) >= num_to_predict:
535 | break
536 | # If adding a whole-word mask would exceed the maximum number of
537 | # predictions, then just skip this candidate.
538 | if len(masked_lms) + len(index_set) > num_to_predict:
539 | continue
540 | is_any_index_covered = False
541 | for index in index_set:
542 | if index in covered_indexes:
543 | is_any_index_covered = True
544 | break
545 | if is_any_index_covered:
546 | continue
547 | for index in index_set:
548 | covered_indexes.add(index)
549 |
550 | masked_token = None
551 | # 80% of the time, replace with [MASK]
552 | if rng.random() < 0.8:
553 | masked_token = "[MASK]"
554 | else:
555 | # 10% of the time, keep original
556 | if rng.random() < 0.5:
557 | masked_token = tokens[index][2:] if len(re.findall('##[\u4E00-\u9FA5]', tokens[index]))>0 else tokens[index]
558 | # 10% of the time, replace with random word
559 | else:
560 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
561 |
562 | output_tokens[index] = masked_token
563 |
564 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
565 | assert len(masked_lms) <= num_to_predict
566 | masked_lms = sorted(masked_lms, key=lambda x: x.index)
567 |
568 | masked_lm_positions = []
569 | masked_lm_labels = []
570 | for p in masked_lms:
571 | masked_lm_positions.append(p.index)
572 | masked_lm_labels.append(p.label)
573 |
574 | # tf.logging.info('%s' % (tokens))
575 | # tf.logging.info('%s' % (output_tokens))
576 | return (output_tokens, masked_lm_positions, masked_lm_labels)
577 |
578 |
579 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
580 | """Truncates a pair of sequences to a maximum sequence length."""
581 | while True:
582 | total_length = len(tokens_a) + len(tokens_b)
583 | if total_length <= max_num_tokens:
584 | break
585 |
586 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
587 | assert len(trunc_tokens) >= 1
588 |
589 | # We want to sometimes truncate from the front and sometimes from the
590 | # back to add more randomness and avoid biases.
591 | if rng.random() < 0.5:
592 | del trunc_tokens[0]
593 | else:
594 | trunc_tokens.pop()
595 |
596 |
597 | def main(_):
598 | tf.logging.set_verbosity(tf.logging.INFO)
599 |
600 | tokenizer = tokenization.FullTokenizer(
601 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
602 |
603 | input_files = []
604 | for input_pattern in FLAGS.input_file.split(","):
605 | input_files.extend(tf.gfile.Glob(input_pattern))
606 |
607 | tf.logging.info("*** Reading from input files ***")
608 | for input_file in input_files:
609 | tf.logging.info(" %s", input_file)
610 |
611 | rng = random.Random(FLAGS.random_seed)
612 | instances = create_training_instances(
613 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
614 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
615 | rng)
616 |
617 | output_files = FLAGS.output_file.split(",")
618 | tf.logging.info("*** Writing to output files ***")
619 | for output_file in output_files:
620 | tf.logging.info(" %s", output_file)
621 |
622 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
623 | FLAGS.max_predictions_per_seq, output_files)
624 |
625 |
626 | if __name__ == "__main__":
627 | flags.mark_flag_as_required("input_file")
628 | flags.mark_flag_as_required("output_file")
629 | flags.mark_flag_as_required("vocab_file")
630 | tf.app.run()
--------------------------------------------------------------------------------
/modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The main BERT model and related functions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import copy
23 | import json
24 | import math
25 | import re
26 | import numpy as np
27 | import six
28 | import tensorflow as tf
29 |
30 |
31 | class BertConfig(object):
32 | """Configuration for `BertModel`."""
33 |
34 | def __init__(self,
35 | vocab_size,
36 | hidden_size=768,
37 | num_hidden_layers=12,
38 | num_attention_heads=12,
39 | intermediate_size=3072,
40 | hidden_act="gelu",
41 | hidden_dropout_prob=0.1,
42 | attention_probs_dropout_prob=0.1,
43 | max_position_embeddings=512,
44 | type_vocab_size=16,
45 | initializer_range=0.02):
46 | """Constructs BertConfig.
47 |
48 | Args:
49 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
50 | hidden_size: Size of the encoder layers and the pooler layer.
51 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
52 | num_attention_heads: Number of attention heads for each attention layer in
53 | the Transformer encoder.
54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
55 | layer in the Transformer encoder.
56 | hidden_act: The non-linear activation function (function or string) in the
57 | encoder and pooler.
58 | hidden_dropout_prob: The dropout probability for all fully connected
59 | layers in the embeddings, encoder, and pooler.
60 | attention_probs_dropout_prob: The dropout ratio for the attention
61 | probabilities.
62 | max_position_embeddings: The maximum sequence length that this model might
63 | ever be used with. Typically set this to something large just in case
64 | (e.g., 512 or 1024 or 2048).
65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
66 | `BertModel`.
67 | initializer_range: The stdev of the truncated_normal_initializer for
68 | initializing all weight matrices.
69 | """
70 | self.vocab_size = vocab_size
71 | self.hidden_size = hidden_size
72 | self.num_hidden_layers = num_hidden_layers
73 | self.num_attention_heads = num_attention_heads
74 | self.hidden_act = hidden_act
75 | self.intermediate_size = intermediate_size
76 | self.hidden_dropout_prob = hidden_dropout_prob
77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
78 | self.max_position_embeddings = max_position_embeddings
79 | self.type_vocab_size = type_vocab_size
80 | self.initializer_range = initializer_range
81 |
82 | @classmethod
83 | def from_dict(cls, json_object):
84 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
85 | config = BertConfig(vocab_size=None)
86 | for (key, value) in six.iteritems(json_object):
87 | config.__dict__[key] = value
88 | return config
89 |
90 | @classmethod
91 | def from_json_file(cls, json_file):
92 | """Constructs a `BertConfig` from a json file of parameters."""
93 | with tf.gfile.GFile(json_file, "r") as reader:
94 | text = reader.read()
95 | return cls.from_dict(json.loads(text))
96 |
97 | def to_dict(self):
98 | """Serializes this instance to a Python dictionary."""
99 | output = copy.deepcopy(self.__dict__)
100 | return output
101 |
102 | def to_json_string(self):
103 | """Serializes this instance to a JSON string."""
104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
105 |
106 |
107 | class BertModel(object):
108 | """BERT model ("Bidirectional Encoder Representations from Transformers").
109 |
110 | Example usage:
111 |
112 | ```python
113 | # Already been converted into WordPiece token ids
114 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
115 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
116 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
117 |
118 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
119 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
120 |
121 | model = modeling.BertModel(config=config, is_training=True,
122 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
123 |
124 | label_embeddings = tf.get_variable(...)
125 | pooled_output = model.get_pooled_output()
126 | logits = tf.matmul(pooled_output, label_embeddings)
127 | ...
128 | ```
129 | """
130 |
131 | def __init__(self,
132 | config,
133 | is_training,
134 | input_ids,
135 | input_mask=None,
136 | token_type_ids=None,
137 | use_one_hot_embeddings=False,
138 | scope=None):
139 | """Constructor for BertModel.
140 |
141 | Args:
142 | config: `BertConfig` instance.
143 | is_training: bool. true for training model, false for eval model. Controls
144 | whether dropout will be applied.
145 | input_ids: int32 Tensor of shape [batch_size, seq_length].
146 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
147 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
148 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
149 | embeddings or tf.embedding_lookup() for the word embeddings.
150 | scope: (optional) variable scope. Defaults to "bert".
151 |
152 | Raises:
153 | ValueError: The config is invalid or one of the input tensor shapes
154 | is invalid.
155 | """
156 | config = copy.deepcopy(config)
157 | if not is_training:
158 | config.hidden_dropout_prob = 0.0
159 | config.attention_probs_dropout_prob = 0.0
160 |
161 | input_shape = get_shape_list(input_ids, expected_rank=2)
162 | batch_size = input_shape[0]
163 | seq_length = input_shape[1]
164 |
165 | if input_mask is None:
166 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
167 |
168 | if token_type_ids is None:
169 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
170 |
171 | with tf.variable_scope(scope, default_name="bert"):
172 | with tf.variable_scope("embeddings"):
173 | # Perform embedding lookup on the word ids.
174 | (self.embedding_output, self.embedding_table) = embedding_lookup(
175 | input_ids=input_ids,
176 | vocab_size=config.vocab_size,
177 | embedding_size=config.hidden_size,
178 | initializer_range=config.initializer_range,
179 | word_embedding_name="word_embeddings",
180 | use_one_hot_embeddings=use_one_hot_embeddings)
181 |
182 | # Add positional embeddings and token type embeddings, then layer
183 | # normalize and perform dropout.
184 | self.embedding_output = embedding_postprocessor(
185 | input_tensor=self.embedding_output,
186 | use_token_type=True,
187 | token_type_ids=token_type_ids,
188 | token_type_vocab_size=config.type_vocab_size,
189 | token_type_embedding_name="token_type_embeddings",
190 | use_position_embeddings=True,
191 | position_embedding_name="position_embeddings",
192 | initializer_range=config.initializer_range,
193 | max_position_embeddings=config.max_position_embeddings,
194 | dropout_prob=config.hidden_dropout_prob)
195 |
196 | with tf.variable_scope("encoder"):
197 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
198 | # mask of shape [batch_size, seq_length, seq_length] which is used
199 | # for the attention scores.
200 | attention_mask = create_attention_mask_from_input_mask(
201 | input_ids, input_mask)
202 |
203 | # Run the stacked transformer.
204 | # `sequence_output` shape = [batch_size, seq_length, hidden_size].
205 | self.all_encoder_layers = transformer_model(
206 | input_tensor=self.embedding_output,
207 | attention_mask=attention_mask,
208 | hidden_size=config.hidden_size,
209 | num_hidden_layers=config.num_hidden_layers,
210 | num_attention_heads=config.num_attention_heads,
211 | intermediate_size=config.intermediate_size,
212 | intermediate_act_fn=get_activation(config.hidden_act),
213 | hidden_dropout_prob=config.hidden_dropout_prob,
214 | attention_probs_dropout_prob=config.attention_probs_dropout_prob,
215 | initializer_range=config.initializer_range,
216 | do_return_all_layers=True)
217 |
218 | self.sequence_output = self.all_encoder_layers[-1] # [batch_size, seq_length, hidden_size]
219 | # The "pooler" converts the encoded sequence tensor of shape
220 | # [batch_size, seq_length, hidden_size] to a tensor of shape
221 | # [batch_size, hidden_size]. This is necessary for segment-level
222 | # (or segment-pair-level) classification tasks where we need a fixed
223 | # dimensional representation of the segment.
224 | with tf.variable_scope("pooler"):
225 | # We "pool" the model by simply taking the hidden state corresponding
226 | # to the first token. We assume that this has been pre-trained
227 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
228 | self.pooled_output = tf.layers.dense(
229 | first_token_tensor,
230 | config.hidden_size,
231 | activation=tf.tanh,
232 | kernel_initializer=create_initializer(config.initializer_range))
233 |
234 | def get_pooled_output(self):
235 | return self.pooled_output
236 |
237 | def get_sequence_output(self):
238 | """Gets final hidden layer of encoder.
239 |
240 | Returns:
241 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
242 | to the final hidden of the transformer encoder.
243 | """
244 | return self.sequence_output
245 |
246 | def get_all_encoder_layers(self):
247 | return self.all_encoder_layers
248 |
249 | def get_embedding_output(self):
250 | """Gets output of the embedding lookup (i.e., input to the transformer).
251 |
252 | Returns:
253 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
254 | to the output of the embedding layer, after summing the word
255 | embeddings with the positional embeddings and the token type embeddings,
256 | then performing layer normalization. This is the input to the transformer.
257 | """
258 | return self.embedding_output
259 |
260 | def get_embedding_table(self):
261 | return self.embedding_table
262 |
263 |
264 | def gelu(x):
265 | """Gaussian Error Linear Unit.
266 |
267 | This is a smoother version of the RELU.
268 | Original paper: https://arxiv.org/abs/1606.08415
269 | Args:
270 | x: float Tensor to perform activation.
271 |
272 | Returns:
273 | `x` with the GELU activation applied.
274 | """
275 | cdf = 0.5 * (1.0 + tf.tanh(
276 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
277 | return x * cdf
278 |
279 |
280 | def get_activation(activation_string):
281 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
282 |
283 | Args:
284 | activation_string: String name of the activation function.
285 |
286 | Returns:
287 | A Python function corresponding to the activation function. If
288 | `activation_string` is None, empty, or "linear", this will return None.
289 | If `activation_string` is not a string, it will return `activation_string`.
290 |
291 | Raises:
292 | ValueError: The `activation_string` does not correspond to a known
293 | activation.
294 | """
295 |
296 | # We assume that anything that's not a string is already an activation
297 | # function, so we just return it.
298 | if not isinstance(activation_string, six.string_types):
299 | return activation_string
300 |
301 | if not activation_string:
302 | return None
303 |
304 | act = activation_string.lower()
305 | if act == "linear":
306 | return None
307 | elif act == "relu":
308 | return tf.nn.relu
309 | elif act == "gelu":
310 | return gelu
311 | elif act == "tanh":
312 | return tf.tanh
313 | else:
314 | raise ValueError("Unsupported activation: %s" % act)
315 |
316 |
317 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
318 | """Compute the union of the current variables and checkpoint variables."""
319 | assignment_map = {}
320 | initialized_variable_names = {}
321 |
322 | name_to_variable = collections.OrderedDict()
323 | for var in tvars:
324 | name = var.name
325 | m = re.match("^(.*):\\d+$", name)
326 | if m is not None:
327 | name = m.group(1)
328 | name_to_variable[name] = var
329 |
330 | init_vars = tf.train.list_variables(init_checkpoint)
331 |
332 | assignment_map = collections.OrderedDict()
333 | for x in init_vars:
334 | (name, var) = (x[0], x[1])
335 | if name not in name_to_variable:
336 | continue
337 | assignment_map[name] = name
338 | initialized_variable_names[name] = 1
339 | initialized_variable_names[name + ":0"] = 1
340 |
341 | return (assignment_map, initialized_variable_names)
342 |
343 |
344 | def dropout(input_tensor, dropout_prob):
345 | """Perform dropout.
346 |
347 | Args:
348 | input_tensor: float Tensor.
349 | dropout_prob: Python float. The probability of dropping out a value (NOT of
350 | *keeping* a dimension as in `tf.nn.dropout`).
351 |
352 | Returns:
353 | A version of `input_tensor` with dropout applied.
354 | """
355 | if dropout_prob is None or dropout_prob == 0.0:
356 | return input_tensor
357 |
358 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
359 | return output
360 |
361 |
362 | def layer_norm(input_tensor, name=None):
363 | """Run layer normalization on the last dimension of the tensor."""
364 | return tf.contrib.layers.layer_norm(
365 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
366 |
367 |
368 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
369 | """Runs layer normalization followed by dropout."""
370 | output_tensor = layer_norm(input_tensor, name)
371 | output_tensor = dropout(output_tensor, dropout_prob)
372 | return output_tensor
373 |
374 |
375 | def create_initializer(initializer_range=0.02):
376 | """Creates a `truncated_normal_initializer` with the given range."""
377 | return tf.truncated_normal_initializer(stddev=initializer_range)
378 |
379 |
380 | def embedding_lookup(input_ids,
381 | vocab_size,
382 | embedding_size=128,
383 | initializer_range=0.02,
384 | word_embedding_name="word_embeddings",
385 | use_one_hot_embeddings=False):
386 | """Looks up words embeddings for id tensor.
387 |
388 | Args:
389 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
390 | ids.
391 | vocab_size: int. Size of the embedding vocabulary.
392 | embedding_size: int. Width of the word embeddings.
393 | initializer_range: float. Embedding initialization range.
394 | word_embedding_name: string. Name of the embedding table.
395 | use_one_hot_embeddings: bool. If True, use one-hot method for word
396 | embeddings. If False, use `tf.gather()`.
397 |
398 | Returns:
399 | float Tensor of shape [batch_size, seq_length, embedding_size].
400 | """
401 | # This function assumes that the input is of shape [batch_size, seq_length,
402 | # num_inputs].
403 | #
404 | # If the input is a 2D tensor of shape [batch_size, seq_length], we
405 | # reshape to [batch_size, seq_length, 1].
406 | if input_ids.shape.ndims == 2:
407 | input_ids = tf.expand_dims(input_ids, axis=[-1])
408 |
409 | embedding_table = tf.get_variable(
410 | name=word_embedding_name,
411 | shape=[vocab_size, embedding_size],
412 | initializer=create_initializer(initializer_range))
413 |
414 | flat_input_ids = tf.reshape(input_ids, [-1])
415 | if use_one_hot_embeddings:
416 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
417 | output = tf.matmul(one_hot_input_ids, embedding_table)
418 | else:
419 | output = tf.gather(embedding_table, flat_input_ids)
420 |
421 | input_shape = get_shape_list(input_ids)
422 |
423 | output = tf.reshape(output,
424 | input_shape[0:-1] + [input_shape[-1] * embedding_size])
425 | return (output, embedding_table)
426 |
427 |
428 | def embedding_postprocessor(input_tensor,
429 | use_token_type=False,
430 | token_type_ids=None,
431 | token_type_vocab_size=16,
432 | token_type_embedding_name="token_type_embeddings",
433 | use_position_embeddings=True,
434 | position_embedding_name="position_embeddings",
435 | initializer_range=0.02,
436 | max_position_embeddings=512,
437 | dropout_prob=0.1):
438 | """Performs various post-processing on a word embedding tensor.
439 |
440 | Args:
441 | input_tensor: float Tensor of shape [batch_size, seq_length,
442 | embedding_size].
443 | use_token_type: bool. Whether to add embeddings for `token_type_ids`.
444 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
445 | Must be specified if `use_token_type` is True.
446 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
447 | token_type_embedding_name: string. The name of the embedding table variable
448 | for token type ids.
449 | use_position_embeddings: bool. Whether to add position embeddings for the
450 | position of each token in the sequence.
451 | position_embedding_name: string. The name of the embedding table variable
452 | for positional embeddings.
453 | initializer_range: float. Range of the weight initialization.
454 | max_position_embeddings: int. Maximum sequence length that might ever be
455 | used with this model. This can be longer than the sequence length of
456 | input_tensor, but cannot be shorter.
457 | dropout_prob: float. Dropout probability applied to the final output tensor.
458 |
459 | Returns:
460 | float tensor with same shape as `input_tensor`.
461 |
462 | Raises:
463 | ValueError: One of the tensor shapes or input values is invalid.
464 | """
465 | input_shape = get_shape_list(input_tensor, expected_rank=3)
466 | batch_size = input_shape[0]
467 | seq_length = input_shape[1]
468 | width = input_shape[2]
469 |
470 | output = input_tensor
471 |
472 | if use_token_type:
473 | if token_type_ids is None:
474 | raise ValueError("`token_type_ids` must be specified if"
475 | "`use_token_type` is True.")
476 | token_type_table = tf.get_variable(
477 | name=token_type_embedding_name,
478 | shape=[token_type_vocab_size, width],
479 | initializer=create_initializer(initializer_range))
480 | # This vocab will be small so we always do one-hot here, since it is always
481 | # faster for a small vocabulary.
482 | flat_token_type_ids = tf.reshape(token_type_ids, [-1])
483 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
484 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
485 | token_type_embeddings = tf.reshape(token_type_embeddings,
486 | [batch_size, seq_length, width])
487 | output += token_type_embeddings
488 |
489 | if use_position_embeddings:
490 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
491 | with tf.control_dependencies([assert_op]):
492 | full_position_embeddings = tf.get_variable(
493 | name=position_embedding_name,
494 | shape=[max_position_embeddings, width],
495 | initializer=create_initializer(initializer_range))
496 | # Since the position embedding table is a learned variable, we create it
497 | # using a (long) sequence length `max_position_embeddings`. The actual
498 | # sequence length might be shorter than this, for faster training of
499 | # tasks that do not have long sequences.
500 | #
501 | # So `full_position_embeddings` is effectively an embedding table
502 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
503 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
504 | # perform a slice.
505 | position_embeddings = tf.slice(full_position_embeddings, [0, 0],
506 | [seq_length, -1])
507 | num_dims = len(output.shape.as_list())
508 |
509 | # Only the last two dimensions are relevant (`seq_length` and `width`), so
510 | # we broadcast among the first dimensions, which is typically just
511 | # the batch size.
512 | position_broadcast_shape = []
513 | for _ in range(num_dims - 2):
514 | position_broadcast_shape.append(1)
515 | position_broadcast_shape.extend([seq_length, width])
516 | position_embeddings = tf.reshape(position_embeddings,
517 | position_broadcast_shape)
518 | output += position_embeddings
519 |
520 | output = layer_norm_and_dropout(output, dropout_prob)
521 | return output
522 |
523 |
524 | def create_attention_mask_from_input_mask(from_tensor, to_mask):
525 | """Create 3D attention mask from a 2D tensor mask.
526 |
527 | Args:
528 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
529 | to_mask: int32 Tensor of shape [batch_size, to_seq_length].
530 |
531 | Returns:
532 | float Tensor of shape [batch_size, from_seq_length, to_seq_length].
533 | """
534 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
535 | batch_size = from_shape[0]
536 | from_seq_length = from_shape[1]
537 |
538 | to_shape = get_shape_list(to_mask, expected_rank=2)
539 | to_seq_length = to_shape[1]
540 |
541 | to_mask = tf.cast(
542 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
543 |
544 | # We don't assume that `from_tensor` is a mask (although it could be). We
545 |   # don't actually care if we attend *from* padding tokens (only *to* padding
546 |   # tokens), so we create a tensor of all ones.
547 | #
548 | # `broadcast_ones` = [batch_size, from_seq_length, 1]
549 | broadcast_ones = tf.ones(
550 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32)
551 |
552 | # Here we broadcast along two dimensions to create the mask.
553 | mask = broadcast_ones * to_mask
554 |
555 | return mask
556 |
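# --- Illustrative sketch (not part of the original source) ---
# How the broadcast above behaves for one toy batch (values are assumptions):
# ones[B, F, 1] * to_mask[B, 1, T] -> mask[B, F, T], i.e. every "from" position
# gets the same row saying which "to" positions are real tokens.
import numpy as np

to_mask = np.array([[1, 1, 1, 0]], dtype=np.float32)   # [B=1, T=4], last token is padding
broadcast_ones = np.ones((1, 3, 1), dtype=np.float32)  # [B=1, F=3, 1]
mask = broadcast_ones * to_mask.reshape(1, 1, 4)       # [1, 3, 4]
print(mask[0])  # each of the 3 "from" rows is [1. 1. 1. 0.]
# --- end sketch ---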
557 |
558 | def attention_layer(from_tensor,
559 | to_tensor,
560 | attention_mask=None,
561 | num_attention_heads=1,
562 | size_per_head=512,
563 | query_act=None,
564 | key_act=None,
565 | value_act=None,
566 | attention_probs_dropout_prob=0.0,
567 | initializer_range=0.02,
568 | do_return_2d_tensor=False,
569 | batch_size=None,
570 | from_seq_length=None,
571 | to_seq_length=None):
572 | """Performs multi-headed attention from `from_tensor` to `to_tensor`.
573 |
574 | This is an implementation of multi-headed attention based on "Attention
575 | is all you Need". If `from_tensor` and `to_tensor` are the same, then
576 | this is self-attention. Each timestep in `from_tensor` attends to the
577 |   corresponding sequence in `to_tensor`, and returns a fixed-width vector.
578 |
579 | This function first projects `from_tensor` into a "query" tensor and
580 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list
581 | of tensors of length `num_attention_heads`, where each tensor is of shape
582 | [batch_size, seq_length, size_per_head].
583 |
584 | Then, the query and key tensors are dot-producted and scaled. These are
585 | softmaxed to obtain attention probabilities. The value tensors are then
586 | interpolated by these probabilities, then concatenated back to a single
587 | tensor and returned.
588 |
589 |   In practice, the multi-headed attention is done with transposes and
590 | reshapes rather than actual separate tensors.
591 |
592 | Args:
593 | from_tensor: float Tensor of shape [batch_size, from_seq_length,
594 | from_width].
595 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
596 | attention_mask: (optional) int32 Tensor of shape [batch_size,
597 | from_seq_length, to_seq_length]. The values should be 1 or 0. The
598 | attention scores will effectively be set to -infinity for any positions in
599 | the mask that are 0, and will be unchanged for positions that are 1.
600 | num_attention_heads: int. Number of attention heads.
601 | size_per_head: int. Size of each attention head.
602 | query_act: (optional) Activation function for the query transform.
603 | key_act: (optional) Activation function for the key transform.
604 | value_act: (optional) Activation function for the value transform.
605 | attention_probs_dropout_prob: (optional) float. Dropout probability of the
606 | attention probabilities.
607 | initializer_range: float. Range of the weight initializer.
608 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
609 | * from_seq_length, num_attention_heads * size_per_head]. If False, the
610 | output will be of shape [batch_size, from_seq_length, num_attention_heads
611 | * size_per_head].
612 | batch_size: (Optional) int. If the input is 2D, this might be the batch size
613 | of the 3D version of the `from_tensor` and `to_tensor`.
614 | from_seq_length: (Optional) If the input is 2D, this might be the seq length
615 | of the 3D version of the `from_tensor`.
616 | to_seq_length: (Optional) If the input is 2D, this might be the seq length
617 | of the 3D version of the `to_tensor`.
618 |
619 | Returns:
620 | float Tensor of shape [batch_size, from_seq_length,
621 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
622 | true, this will be of shape [batch_size * from_seq_length,
623 | num_attention_heads * size_per_head]).
624 |
625 | Raises:
626 | ValueError: Any of the arguments or tensor shapes are invalid.
627 | """
628 |
629 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
630 | seq_length, width):
631 | output_tensor = tf.reshape(
632 | input_tensor, [batch_size, seq_length, num_attention_heads, width])
633 |
634 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
635 | return output_tensor
636 |
637 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
638 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
639 |
640 | if len(from_shape) != len(to_shape):
641 | raise ValueError(
642 | "The rank of `from_tensor` must match the rank of `to_tensor`.")
643 |
644 | if len(from_shape) == 3:
645 | batch_size = from_shape[0]
646 | from_seq_length = from_shape[1]
647 | to_seq_length = to_shape[1]
648 | elif len(from_shape) == 2:
649 | if (batch_size is None or from_seq_length is None or to_seq_length is None):
650 | raise ValueError(
651 | "When passing in rank 2 tensors to attention_layer, the values "
652 | "for `batch_size`, `from_seq_length`, and `to_seq_length` "
653 | "must all be specified.")
654 |
655 | # Scalar dimensions referenced here:
656 | # B = batch size (number of sequences)
657 | # F = `from_tensor` sequence length
658 | # T = `to_tensor` sequence length
659 | # N = `num_attention_heads`
660 | # H = `size_per_head`
661 |
662 | from_tensor_2d = reshape_to_matrix(from_tensor)
663 | to_tensor_2d = reshape_to_matrix(to_tensor)
664 |
665 | # `query_layer` = [B*F, N*H]
666 | query_layer = tf.layers.dense(
667 | from_tensor_2d,
668 | num_attention_heads * size_per_head,
669 | activation=query_act,
670 | name="query",
671 | kernel_initializer=create_initializer(initializer_range))
672 |
673 | # `key_layer` = [B*T, N*H]
674 | key_layer = tf.layers.dense(
675 | to_tensor_2d,
676 | num_attention_heads * size_per_head,
677 | activation=key_act,
678 | name="key",
679 | kernel_initializer=create_initializer(initializer_range))
680 |
681 | # `value_layer` = [B*T, N*H]
682 | value_layer = tf.layers.dense(
683 | to_tensor_2d,
684 | num_attention_heads * size_per_head,
685 | activation=value_act,
686 | name="value",
687 | kernel_initializer=create_initializer(initializer_range))
688 |
689 | # `query_layer` = [B, N, F, H]
690 | query_layer = transpose_for_scores(query_layer, batch_size,
691 | num_attention_heads, from_seq_length,
692 | size_per_head)
693 |
694 | # `key_layer` = [B, N, T, H]
695 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
696 | to_seq_length, size_per_head)
697 |
698 | # Take the dot product between "query" and "key" to get the raw
699 | # attention scores.
700 | # `attention_scores` = [B, N, F, T]
701 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
702 | attention_scores = tf.multiply(attention_scores,
703 | 1.0 / math.sqrt(float(size_per_head)))
704 |
705 | if attention_mask is not None:
706 | # `attention_mask` = [B, 1, F, T]
707 | attention_mask = tf.expand_dims(attention_mask, axis=[1])
708 |
709 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
710 | # masked positions, this operation will create a tensor which is 0.0 for
711 | # positions we want to attend and -10000.0 for masked positions.
712 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
713 |
714 | # Since we are adding it to the raw scores before the softmax, this is
715 | # effectively the same as removing these entirely.
716 | attention_scores += adder
717 |
718 | # Normalize the attention scores to probabilities.
719 | # `attention_probs` = [B, N, F, T]
720 | attention_probs = tf.nn.softmax(attention_scores)
721 |
722 | # This is actually dropping out entire tokens to attend to, which might
723 | # seem a bit unusual, but is taken from the original Transformer paper.
724 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
725 |
726 | # `value_layer` = [B, T, N, H]
727 | value_layer = tf.reshape(
728 | value_layer,
729 | [batch_size, to_seq_length, num_attention_heads, size_per_head])
730 |
731 | # `value_layer` = [B, N, T, H]
732 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
733 |
734 | # `context_layer` = [B, N, F, H]
735 | context_layer = tf.matmul(attention_probs, value_layer)
736 |
737 | # `context_layer` = [B, F, N, H]
738 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
739 |
740 | if do_return_2d_tensor:
741 | # `context_layer` = [B*F, N*H]
742 | context_layer = tf.reshape(
743 | context_layer,
744 | [batch_size * from_seq_length, num_attention_heads * size_per_head])
745 | else:
746 | # `context_layer` = [B, F, N*H]
747 | context_layer = tf.reshape(
748 | context_layer,
749 | [batch_size, from_seq_length, num_attention_heads * size_per_head])
750 |
751 | return context_layer
752 |
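# --- Illustrative sketch (not part of the original source) ---
# Scaled dot-product attention for a single head, mirroring the math used in
# `attention_layer`: scores = Q.K^T / sqrt(H), masked positions get -10000 added,
# a softmax turns scores into probabilities, and the context is probs.V.
# Shapes and values are assumptions for illustration only.
import numpy as np

F, T, H = 3, 4, 8                          # from-length, to-length, size_per_head
rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(F, H)), rng.normal(size=(T, H)), rng.normal(size=(T, H))
attention_mask = np.array([1, 1, 1, 0])    # last "to" position is padding

scores = (q @ k.T) / np.sqrt(H)                # [F, T]
scores += (1.0 - attention_mask) * -10000.0    # push masked logits towards -inf
probs = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
context = probs @ v                            # [F, H]
print(probs[:, -1])  # ~0 everywhere: nothing attends to the padded position
# --- end sketch ---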
753 |
754 | def transformer_model(input_tensor,
755 | attention_mask=None,
756 | hidden_size=768,
757 | num_hidden_layers=12,
758 | num_attention_heads=12,
759 | intermediate_size=3072,
760 | intermediate_act_fn=gelu,
761 | hidden_dropout_prob=0.1,
762 | attention_probs_dropout_prob=0.1,
763 | initializer_range=0.02,
764 | do_return_all_layers=False):
765 | """Multi-headed, multi-layer Transformer from "Attention is All You Need".
766 |
767 | This is almost an exact implementation of the original Transformer encoder.
768 |
769 | See the original paper:
770 | https://arxiv.org/abs/1706.03762
771 |
772 | Also see:
773 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
774 |
775 | Args:
776 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
777 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
778 | seq_length], with 1 for positions that can be attended to and 0 in
779 | positions that should not be.
780 | hidden_size: int. Hidden size of the Transformer.
781 | num_hidden_layers: int. Number of layers (blocks) in the Transformer.
782 | num_attention_heads: int. Number of attention heads in the Transformer.
783 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed
784 | forward) layer.
785 | intermediate_act_fn: function. The non-linear activation function to apply
786 | to the output of the intermediate/feed-forward layer.
787 | hidden_dropout_prob: float. Dropout probability for the hidden layers.
788 | attention_probs_dropout_prob: float. Dropout probability of the attention
789 | probabilities.
790 | initializer_range: float. Range of the initializer (stddev of truncated
791 | normal).
792 | do_return_all_layers: Whether to also return all layers or just the final
793 | layer.
794 |
795 | Returns:
796 | float Tensor of shape [batch_size, seq_length, hidden_size], the final
797 | hidden layer of the Transformer.
798 |
799 | Raises:
800 | ValueError: A Tensor shape or parameter is invalid.
801 | """
802 | if hidden_size % num_attention_heads != 0:
803 | raise ValueError(
804 | "The hidden size (%d) is not a multiple of the number of attention "
805 | "heads (%d)" % (hidden_size, num_attention_heads))
806 |
807 | attention_head_size = int(hidden_size / num_attention_heads)
808 | input_shape = get_shape_list(input_tensor, expected_rank=3)
809 | batch_size = input_shape[0]
810 | seq_length = input_shape[1]
811 | input_width = input_shape[2]
812 |
813 | # The Transformer performs sum residuals on all layers so the input needs
814 | # to be the same as the hidden size.
815 | if input_width != hidden_size:
816 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
817 | (input_width, hidden_size))
818 |
819 | # We keep the representation as a 2D tensor to avoid re-shaping it back and
820 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
821 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
822 | # help the optimizer.
823 | prev_output = reshape_to_matrix(input_tensor)
824 |
825 | all_layer_outputs = []
826 | for layer_idx in range(num_hidden_layers):
827 | with tf.variable_scope("layer_%d" % layer_idx):
828 | layer_input = prev_output
829 |
830 | with tf.variable_scope("attention"):
831 | attention_heads = []
832 | with tf.variable_scope("self"):
833 | attention_head = attention_layer(
834 | from_tensor=layer_input,
835 | to_tensor=layer_input,
836 | attention_mask=attention_mask,
837 | num_attention_heads=num_attention_heads,
838 | size_per_head=attention_head_size,
839 | attention_probs_dropout_prob=attention_probs_dropout_prob,
840 | initializer_range=initializer_range,
841 | do_return_2d_tensor=True,
842 | batch_size=batch_size,
843 | from_seq_length=seq_length,
844 | to_seq_length=seq_length)
845 | attention_heads.append(attention_head)
846 |
847 | attention_output = None
848 | if len(attention_heads) == 1:
849 | attention_output = attention_heads[0]
850 | else:
851 | # In the case where we have other sequences, we just concatenate
852 | # them to the self-attention head before the projection.
853 | attention_output = tf.concat(attention_heads, axis=-1)
854 |
855 | # Run a linear projection of `hidden_size` then add a residual
856 | # with `layer_input`.
857 | with tf.variable_scope("output"):
858 | attention_output = tf.layers.dense(
859 | attention_output,
860 | hidden_size,
861 | kernel_initializer=create_initializer(initializer_range))
862 | attention_output = dropout(attention_output, hidden_dropout_prob)
863 | attention_output = layer_norm(attention_output + layer_input)
864 |
865 | # The activation is only applied to the "intermediate" hidden layer.
866 | with tf.variable_scope("intermediate"):
867 | intermediate_output = tf.layers.dense(
868 | attention_output,
869 | intermediate_size,
870 | activation=intermediate_act_fn,
871 | kernel_initializer=create_initializer(initializer_range))
872 |
873 | # Down-project back to `hidden_size` then add the residual.
874 | with tf.variable_scope("output"):
875 | layer_output = tf.layers.dense(
876 | intermediate_output,
877 | hidden_size,
878 | kernel_initializer=create_initializer(initializer_range))
879 | layer_output = dropout(layer_output, hidden_dropout_prob)
880 | layer_output = layer_norm(layer_output + attention_output)
881 | prev_output = layer_output
882 | all_layer_outputs.append(layer_output)
883 |
884 | if do_return_all_layers:
885 | final_outputs = []
886 | for layer_output in all_layer_outputs:
887 | final_output = reshape_from_matrix(layer_output, input_shape)
888 | final_outputs.append(final_output)
889 | return final_outputs
890 | else:
891 | final_output = reshape_from_matrix(prev_output, input_shape)
892 | return final_output
893 |
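# --- Illustrative sketch (not part of the original source) ---
# One encoder block of `transformer_model`, written out with NumPy for a toy
# input: self-attention output -> projection + residual + layer norm ->
# feed-forward (gelu) -> down-projection + residual + layer norm.
# The random weights and shapes are assumptions, not the trained parameters.
import numpy as np

def layer_norm_np(x, eps=1e-12):
  mu, var = x.mean(-1, keepdims=True), x.var(-1, keepdims=True)
  return (x - mu) / np.sqrt(var + eps)

def gelu_np(x):
  return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

hidden, intermediate = 8, 32
rng = np.random.default_rng(1)
layer_input = rng.normal(size=(5, hidden))       # [batch*seq, hidden] (kept 2D, as in the code)
attention_output = rng.normal(size=(5, hidden))  # stand-in for attention_layer(...)
w_proj = rng.normal(size=(hidden, hidden)) * 0.02
w_inter = rng.normal(size=(hidden, intermediate)) * 0.02
w_down = rng.normal(size=(intermediate, hidden)) * 0.02

attention_output = layer_norm_np(attention_output @ w_proj + layer_input)  # projection + residual
intermediate_output = gelu_np(attention_output @ w_inter)                  # feed-forward
layer_output = layer_norm_np(intermediate_output @ w_down + attention_output)
print(layer_output.shape)  # (5, 8)
# --- end sketch ---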
894 |
895 | def get_shape_list(tensor, expected_rank=None, name=None):
896 | """Returns a list of the shape of tensor, preferring static dimensions.
897 |
898 | Args:
899 | tensor: A tf.Tensor object to find the shape of.
900 | expected_rank: (optional) int. The expected rank of `tensor`. If this is
901 |       specified and the `tensor` has a different rank, an exception will be
902 | thrown.
903 | name: Optional name of the tensor for the error message.
904 |
905 | Returns:
906 | A list of dimensions of the shape of tensor. All static dimensions will
907 | be returned as python integers, and dynamic dimensions will be returned
908 | as tf.Tensor scalars.
909 | """
910 | if name is None:
911 | name = tensor.name
912 |
913 | if expected_rank is not None:
914 | assert_rank(tensor, expected_rank, name)
915 |
916 | shape = tensor.shape.as_list()
917 |
918 | non_static_indexes = []
919 | for (index, dim) in enumerate(shape):
920 | if dim is None:
921 | non_static_indexes.append(index)
922 |
923 | if not non_static_indexes:
924 | return shape
925 |
926 | dyn_shape = tf.shape(tensor)
927 | for index in non_static_indexes:
928 | shape[index] = dyn_shape[index]
929 | return shape
930 |
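# --- Illustrative usage note (not part of the original source; TF 1.x assumed) ---
# `get_shape_list` mixes static Python ints with dynamic tf.Tensor scalars, so the
# reshapes above work even when the batch dimension is unknown at graph-build time:
#
#   ids = tf.placeholder(tf.int32, shape=[None, 128])   # batch size unknown
#   shape = get_shape_list(ids, expected_rank=2)
#   # shape == [<dynamic batch-size Tensor>, 128]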
931 |
932 | def reshape_to_matrix(input_tensor):
933 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
934 | ndims = input_tensor.shape.ndims
935 | if ndims < 2:
936 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
937 | (input_tensor.shape))
938 | if ndims == 2:
939 | return input_tensor
940 |
941 | width = input_tensor.shape[-1]
942 | output_tensor = tf.reshape(input_tensor, [-1, width])
943 | return output_tensor
944 |
945 |
946 | def reshape_from_matrix(output_tensor, orig_shape_list):
947 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
948 | if len(orig_shape_list) == 2:
949 | return output_tensor
950 |
951 | output_shape = get_shape_list(output_tensor)
952 |
953 | orig_dims = orig_shape_list[0:-1]
954 | width = output_shape[-1]
955 |
956 | return tf.reshape(output_tensor, orig_dims + [width])
957 |
958 |
959 | def assert_rank(tensor, expected_rank, name=None):
960 | """Raises an exception if the tensor rank is not of the expected rank.
961 |
962 | Args:
963 | tensor: A tf.Tensor to check the rank of.
964 | expected_rank: Python integer or list of integers, expected rank.
965 | name: Optional name of the tensor for the error message.
966 |
967 | Raises:
968 | ValueError: If the expected shape doesn't match the actual shape.
969 | """
970 | if name is None:
971 | name = tensor.name
972 |
973 | expected_rank_dict = {}
974 | if isinstance(expected_rank, six.integer_types):
975 | expected_rank_dict[expected_rank] = True
976 | else:
977 | for x in expected_rank:
978 | expected_rank_dict[x] = True
979 |
980 | actual_rank = tensor.shape.ndims
981 | if actual_rank not in expected_rank_dict:
982 | scope_name = tf.get_variable_scope().name
983 | raise ValueError(
984 | "For the tensor `%s` in scope `%s`, the actual rank "
985 | "`%d` (shape = %s) is not equal to the expected rank `%s`" %
986 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
987 |
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 |       beta_2=0.98,  # 0.98 is used ONLY for pre-training; change to 0.999 when fine-tuning.
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 |
72 |   # tvars=find_train_variables(tvars) # freeze parameters from layer 0 to layer 9.
73 |
74 | grads = tf.gradients(loss, tvars)
75 |
76 | # This is how the model was pre-trained.
77 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
78 |
79 | train_op = optimizer.apply_gradients(
80 | zip(grads, tvars), global_step=global_step)
81 |
82 | # Normally the global step update is done inside of `apply_gradients`.
83 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
84 | # a different optimizer, you should probably take this line out.
85 | new_global_step = global_step + 1
86 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
87 | return train_op
88 |
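# --- Illustrative sketch (not part of the original source) ---
# The learning-rate schedule built above, in plain Python: linear warmup over
# `num_warmup_steps`, then linear (power=1.0 polynomial) decay to 0 across
# `num_train_steps`. The step values below are assumptions for illustration.
def lr_at_step(step, init_lr=1e-4, num_train_steps=1000, num_warmup_steps=100):
  decayed = init_lr * (1.0 - min(step, num_train_steps) / float(num_train_steps))
  warmup = init_lr * (step / float(num_warmup_steps))
  return warmup if step < num_warmup_steps else decayed

for s in (0, 50, 100, 500, 1000):
  print(s, lr_at_step(s))  # 0.0, 5e-05, 9e-05, 5e-05, 0.0
# --- end sketch ---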
89 | def find_train_variables(tvars):
90 | """
91 |   Get trainable variables so that only layers 9 through the last layer (plus the output heads) are trained.
92 |   :param tvars: a list of trainable variables
93 |   :return: a new, filtered list of trainable variables
94 | """
95 | # bert/encoder/layer_21, bert/encoder/layer_9, bert/encoder/layer_20/attention/output/dense/bias:0, bert/encoder/layer_20/attention/output/dense/kernel:
96 | tvars_result_list=[]
97 |
98 | for var in tvars:
99 |     if 'cls/predictions' in var.name or 'bert/pooler/dense' in var.name: # final output layers
100 | tvars_result_list.append(var)
101 |     else: # parameters in the second half of the network
102 | layer_number_list=re.findall("layer_(.+?)/", var.name)
103 |       if len(layer_number_list)>0 and layer_number_list[0].isdigit(): # matched a layer number
104 | layer_number=int(layer_number_list[0])
105 | if layer_number>=9:
106 | tvars_result_list.append(var)
107 |
108 | # print train variables
109 | for i,var_ in enumerate(tvars_result_list):
110 | print("####find_train_variables.i:",i, "variable name:",var_.name)
111 |
112 |   print("####find_train_variables:length of tvars_result_list:",len(tvars_result_list))
113 | return tvars_result_list
114 |
115 |
116 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
117 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
118 |
119 | def __init__(self,
120 | learning_rate,
121 | weight_decay_rate=0.0,
122 | beta_1=0.9,
123 | beta_2=0.999,
124 | epsilon=1e-6,
125 | exclude_from_weight_decay=None,
126 | name="AdamWeightDecayOptimizer"):
127 |     """Constructs an AdamWeightDecayOptimizer."""
128 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
129 |
130 | self.learning_rate = learning_rate
131 | self.weight_decay_rate = weight_decay_rate
132 | self.beta_1 = beta_1
133 | self.beta_2 = beta_2
134 | self.epsilon = epsilon
135 | self.exclude_from_weight_decay = exclude_from_weight_decay
136 |
137 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
138 | """See base class."""
139 | assignments = []
140 | for (grad, param) in grads_and_vars:
141 | if grad is None or param is None:
142 | continue
143 |
144 | param_name = self._get_variable_name(param.name)
145 |
146 | m = tf.get_variable(
147 | name=param_name + "/adam_m",
148 | shape=param.shape.as_list(),
149 | dtype=tf.float32,
150 | trainable=False,
151 | initializer=tf.zeros_initializer())
152 | v = tf.get_variable(
153 | name=param_name + "/adam_v",
154 | shape=param.shape.as_list(),
155 | dtype=tf.float32,
156 | trainable=False,
157 | initializer=tf.zeros_initializer())
158 |
159 | # Standard Adam update.
160 | next_m = (
161 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
162 | next_v = (
163 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
164 | tf.square(grad)))
165 |
166 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
167 |
168 | # Just adding the square of the weights to the loss function is *not*
169 | # the correct way of using L2 regularization/weight decay with Adam,
170 | # since that will interact with the m and v parameters in strange ways.
171 | #
172 |       # Instead we want to decay the weights in a manner that doesn't interact
173 | # with the m/v parameters. This is equivalent to adding the square
174 | # of the weights to the loss with plain (non-momentum) SGD.
175 | if self._do_use_weight_decay(param_name):
176 | update += self.weight_decay_rate * param
177 |
178 | update_with_lr = self.learning_rate * update
179 |
180 | next_param = param - update_with_lr
181 |
182 | assignments.extend(
183 | [param.assign(next_param),
184 | m.assign(next_m),
185 | v.assign(next_v)])
186 | return tf.group(*assignments, name=name)
187 |
188 | def _do_use_weight_decay(self, param_name):
189 | """Whether to use L2 weight decay for `param_name`."""
190 | if not self.weight_decay_rate:
191 | return False
192 | if self.exclude_from_weight_decay:
193 | for r in self.exclude_from_weight_decay:
194 | if re.search(r, param_name) is not None:
195 | return False
196 | return True
197 |
198 | def _get_variable_name(self, param_name):
199 | """Get the variable name from the tensor name."""
200 | m = re.match("^(.*):\\d+$", param_name)
201 | if m is not None:
202 | param_name = m.group(1)
203 | return param_name
204 |
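# --- Illustrative sketch (not part of the original source) ---
# One AdamWeightDecayOptimizer step for a single scalar parameter, in plain
# Python, showing how the decoupled weight decay is added to the Adam update
# *after* the m/v moments are computed (all values are assumptions).
import math

lr, beta_1, beta_2, eps, wd = 1e-3, 0.9, 0.98, 1e-6, 0.01
param, grad, m, v = 0.5, 0.2, 0.0, 0.0

m = beta_1 * m + (1.0 - beta_1) * grad       # first moment
v = beta_2 * v + (1.0 - beta_2) * grad ** 2  # second moment
update = m / (math.sqrt(v) + eps)
update += wd * param                         # decoupled decay (skipped for LayerNorm/bias)
param -= lr * update
print(param)  # ~0.499288
# --- end sketch ---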
--------------------------------------------------------------------------------
/optimization_finetuning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 |       beta_2=0.999,  # 0.999 is used for fine-tuning; 0.98 is used only for pre-training.
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 | # Normally the global step update is done inside of `apply_gradients`.
80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
81 | # a different optimizer, you should probably take this line out.
82 | new_global_step = global_step + 1
83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
84 | return train_op
85 |
86 |
87 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
88 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
89 |
90 | def __init__(self,
91 | learning_rate,
92 | weight_decay_rate=0.0,
93 | beta_1=0.9,
94 | beta_2=0.999,
95 | epsilon=1e-6,
96 | exclude_from_weight_decay=None,
97 | name="AdamWeightDecayOptimizer"):
98 |     """Constructs an AdamWeightDecayOptimizer."""
99 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
100 |
101 | self.learning_rate = learning_rate
102 | self.weight_decay_rate = weight_decay_rate
103 | self.beta_1 = beta_1
104 | self.beta_2 = beta_2
105 | self.epsilon = epsilon
106 | self.exclude_from_weight_decay = exclude_from_weight_decay
107 |
108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
109 | """See base class."""
110 | assignments = []
111 | for (grad, param) in grads_and_vars:
112 | if grad is None or param is None:
113 | continue
114 |
115 | param_name = self._get_variable_name(param.name)
116 |
117 | m = tf.get_variable(
118 | name=param_name + "/adam_m",
119 | shape=param.shape.as_list(),
120 | dtype=tf.float32,
121 | trainable=False,
122 | initializer=tf.zeros_initializer())
123 | v = tf.get_variable(
124 | name=param_name + "/adam_v",
125 | shape=param.shape.as_list(),
126 | dtype=tf.float32,
127 | trainable=False,
128 | initializer=tf.zeros_initializer())
129 |
130 | # Standard Adam update.
131 | next_m = (
132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
133 | next_v = (
134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
135 | tf.square(grad)))
136 |
137 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
138 |
139 | # Just adding the square of the weights to the loss function is *not*
140 | # the correct way of using L2 regularization/weight decay with Adam,
141 | # since that will interact with the m and v parameters in strange ways.
142 | #
143 |       # Instead we want to decay the weights in a manner that doesn't interact
144 | # with the m/v parameters. This is equivalent to adding the square
145 | # of the weights to the loss with plain (non-momentum) SGD.
146 | if self._do_use_weight_decay(param_name):
147 | update += self.weight_decay_rate * param
148 |
149 | update_with_lr = self.learning_rate * update
150 |
151 | next_param = param - update_with_lr
152 |
153 | assignments.extend(
154 | [param.assign(next_param),
155 | m.assign(next_m),
156 | v.assign(next_v)])
157 | return tf.group(*assignments, name=name)
158 |
159 | def _do_use_weight_decay(self, param_name):
160 | """Whether to use L2 weight decay for `param_name`."""
161 | if not self.weight_decay_rate:
162 | return False
163 | if self.exclude_from_weight_decay:
164 | for r in self.exclude_from_weight_decay:
165 | if re.search(r, param_name) is not None:
166 | return False
167 | return True
168 |
169 | def _get_variable_name(self, param_name):
170 | """Get the variable name from the tensor name."""
171 | m = re.match("^(.*):\\d+$", param_name)
172 | if m is not None:
173 | param_name = m.group(1)
174 | return param_name
175 |
--------------------------------------------------------------------------------
/resources/RoBERTa_zh_Large_Learning_Curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brightmart/roberta_zh/438476f7da1661faf45e5b8da9f55df403e44997/resources/RoBERTa_zh_Large_Learning_Curve.png
--------------------------------------------------------------------------------
/run_classifier.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """BERT finetuning runner."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import csv
23 | import os
24 | import modeling
25 | import optimization_finetuning as optimization
26 | import tokenization
27 | import tensorflow as tf
28 | # from loss import bi_tempered_logistic_loss
29 |
30 | flags = tf.flags
31 |
32 | FLAGS = flags.FLAGS
33 |
34 | ## Required parameters
35 | flags.DEFINE_string(
36 | "data_dir", None,
37 | "The input data dir. Should contain the .tsv files (or other data files) "
38 | "for the task.")
39 |
40 | flags.DEFINE_string(
41 | "bert_config_file", None,
42 | "The config json file corresponding to the pre-trained BERT model. "
43 | "This specifies the model architecture.")
44 |
45 | flags.DEFINE_string("task_name", None, "The name of the task to train.")
46 |
47 | flags.DEFINE_string("vocab_file", None,
48 | "The vocabulary file that the BERT model was trained on.")
49 |
50 | flags.DEFINE_string(
51 | "output_dir", None,
52 | "The output directory where the model checkpoints will be written.")
53 |
54 | ## Other parameters
55 |
56 | flags.DEFINE_string(
57 | "init_checkpoint", None,
58 | "Initial checkpoint (usually from a pre-trained BERT model).")
59 |
60 | flags.DEFINE_bool(
61 | "do_lower_case", True,
62 | "Whether to lower case the input text. Should be True for uncased "
63 | "models and False for cased models.")
64 |
65 | flags.DEFINE_integer(
66 | "max_seq_length", 128,
67 | "The maximum total input sequence length after WordPiece tokenization. "
68 | "Sequences longer than this will be truncated, and sequences shorter "
69 | "than this will be padded.")
70 |
71 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
72 |
73 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
74 |
75 | flags.DEFINE_bool(
76 | "do_predict", False,
77 | "Whether to run the model in inference mode on the test set.")
78 |
79 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
80 |
81 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
82 |
83 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
84 |
85 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
86 |
87 | flags.DEFINE_float("num_train_epochs", 3.0,
88 | "Total number of training epochs to perform.")
89 |
90 | flags.DEFINE_float(
91 | "warmup_proportion", 0.1,
92 | "Proportion of training to perform linear learning rate warmup for. "
93 | "E.g., 0.1 = 10% of training.")
94 |
95 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
96 | "How often to save the model checkpoint.")
97 |
98 | flags.DEFINE_integer("iterations_per_loop", 1000,
99 | "How many steps to make in each estimator call.")
100 |
101 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
102 |
103 | tf.flags.DEFINE_string(
104 | "tpu_name", None,
105 | "The Cloud TPU to use for training. This should be either the name "
106 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
107 | "url.")
108 |
109 | tf.flags.DEFINE_string(
110 | "tpu_zone", None,
111 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
112 | "specified, we will attempt to automatically detect the GCE project from "
113 | "metadata.")
114 |
115 | tf.flags.DEFINE_string(
116 | "gcp_project", None,
117 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
118 | "specified, we will attempt to automatically detect the GCE project from "
119 | "metadata.")
120 |
121 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
122 |
123 | flags.DEFINE_integer(
124 | "num_tpu_cores", 8,
125 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
126 |
127 |
128 | class InputExample(object):
129 | """A single training/test example for simple sequence classification."""
130 |
131 | def __init__(self, guid, text_a, text_b=None, label=None):
132 |     """Constructs an InputExample.
133 | Args:
134 | guid: Unique id for the example.
135 | text_a: string. The untokenized text of the first sequence. For single
136 | sequence tasks, only this sequence must be specified.
137 | text_b: (Optional) string. The untokenized text of the second sequence.
138 | Only must be specified for sequence pair tasks.
139 | label: (Optional) string. The label of the example. This should be
140 | specified for train and dev examples, but not for test examples.
141 | """
142 | self.guid = guid
143 | self.text_a = text_a
144 | self.text_b = text_b
145 | self.label = label
146 |
147 |
148 | class PaddingInputExample(object):
149 | """Fake example so the num input examples is a multiple of the batch size.
150 | When running eval/predict on the TPU, we need to pad the number of examples
151 | to be a multiple of the batch size, because the TPU requires a fixed batch
152 | size. The alternative is to drop the last batch, which is bad because it means
153 | the entire output data won't be generated.
154 | We use this class instead of `None` because treating `None` as padding
155 |   batches could cause silent errors.
156 | """
157 |
158 |
159 | class InputFeatures(object):
160 | """A single set of features of data."""
161 |
162 | def __init__(self,
163 | input_ids,
164 | input_mask,
165 | segment_ids,
166 | label_id,
167 | is_real_example=True):
168 | self.input_ids = input_ids
169 | self.input_mask = input_mask
170 | self.segment_ids = segment_ids
171 | self.label_id = label_id
172 | self.is_real_example = is_real_example
173 |
174 |
175 | class DataProcessor(object):
176 | """Base class for data converters for sequence classification data sets."""
177 |
178 | def get_train_examples(self, data_dir):
179 | """Gets a collection of `InputExample`s for the train set."""
180 | raise NotImplementedError()
181 |
182 | def get_dev_examples(self, data_dir):
183 | """Gets a collection of `InputExample`s for the dev set."""
184 | raise NotImplementedError()
185 |
186 | def get_test_examples(self, data_dir):
187 | """Gets a collection of `InputExample`s for prediction."""
188 | raise NotImplementedError()
189 |
190 | def get_labels(self):
191 | """Gets the list of labels for this data set."""
192 | raise NotImplementedError()
193 |
194 | @classmethod
195 | def _read_tsv(cls, input_file, quotechar=None):
196 | """Reads a tab separated value file."""
197 | with tf.gfile.Open(input_file, "r") as f:
198 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
199 | lines = []
200 | for line in reader:
201 | lines.append(line)
202 | return lines
203 |
204 |
205 | class LCQMCPairClassificationProcessor(DataProcessor): # TODO NEED CHANGE2
206 |   """Processor for the internal sentence-pair classification data set."""
207 | def __init__(self):
208 | self.language = "zh"
209 |
210 | def get_train_examples(self, data_dir):
211 | """See base class."""
212 | return self._create_examples(
213 | self._read_tsv(os.path.join(data_dir, "train.txt")), "train")
214 | # dev_0827.tsv
215 |
216 | def get_dev_examples(self, data_dir):
217 | """See base class."""
218 | return self._create_examples(
219 | self._read_tsv(os.path.join(data_dir, "dev.txt")), "dev")
220 |
221 | def get_test_examples(self, data_dir):
222 | """See base class."""
223 | return self._create_examples(
224 | self._read_tsv(os.path.join(data_dir, "test.txt")), "test")
225 |
226 | def get_labels(self):
227 | """See base class."""
228 | return ["0", "1"]
229 |
230 | def _create_examples(self, lines, set_type):
231 | """Creates examples for the training and dev sets."""
232 | examples = []
233 | print("length of lines:",len(lines))
234 | for (i, line) in enumerate(lines):
235 | #print('#i:',i,line)
236 | if i == 0:
237 | continue
238 | guid = "%s-%s" % (set_type, i)
239 | try:
240 | label = tokenization.convert_to_unicode(line[2])
241 | text_a = tokenization.convert_to_unicode(line[0])
242 | text_b = tokenization.convert_to_unicode(line[1])
243 | examples.append(
244 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
245 | except Exception:
246 | print('###error.i:', i, line)
247 | return examples
248 |
249 |
250 | def convert_single_example(ex_index, example, label_list, max_seq_length,
251 | tokenizer):
252 | """Converts a single `InputExample` into a single `InputFeatures`."""
253 |
254 | if isinstance(example, PaddingInputExample):
255 | return InputFeatures(
256 | input_ids=[0] * max_seq_length,
257 | input_mask=[0] * max_seq_length,
258 | segment_ids=[0] * max_seq_length,
259 | label_id=0,
260 | is_real_example=False)
261 |
262 | label_map = {}
263 | for (i, label) in enumerate(label_list):
264 | label_map[label] = i
265 |
266 | tokens_a = tokenizer.tokenize(example.text_a)
267 | tokens_b = None
268 | if example.text_b:
269 | tokens_b = tokenizer.tokenize(example.text_b)
270 |
271 | if tokens_b:
272 | # Modifies `tokens_a` and `tokens_b` in place so that the total
273 | # length is less than the specified length.
274 | # Account for [CLS], [SEP], [SEP] with "- 3"
275 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
276 | else:
277 | # Account for [CLS] and [SEP] with "- 2"
278 | if len(tokens_a) > max_seq_length - 2:
279 | tokens_a = tokens_a[0:(max_seq_length - 2)]
280 |
281 | # The convention in BERT is:
282 | # (a) For sequence pairs:
283 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
284 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
285 | # (b) For single sequences:
286 | # tokens: [CLS] the dog is hairy . [SEP]
287 | # type_ids: 0 0 0 0 0 0 0
288 | #
289 | # Where "type_ids" are used to indicate whether this is the first
290 | # sequence or the second sequence. The embedding vectors for `type=0` and
291 | # `type=1` were learned during pre-training and are added to the wordpiece
292 | # embedding vector (and position vector). This is not *strictly* necessary
293 | # since the [SEP] token unambiguously separates the sequences, but it makes
294 | # it easier for the model to learn the concept of sequences.
295 | #
296 | # For classification tasks, the first vector (corresponding to [CLS]) is
297 | # used as the "sentence vector". Note that this only makes sense because
298 | # the entire model is fine-tuned.
299 | tokens = []
300 | segment_ids = []
301 | tokens.append("[CLS]")
302 | segment_ids.append(0)
303 | for token in tokens_a:
304 | tokens.append(token)
305 | segment_ids.append(0)
306 | tokens.append("[SEP]")
307 | segment_ids.append(0)
308 |
309 | if tokens_b:
310 | for token in tokens_b:
311 | tokens.append(token)
312 | segment_ids.append(1)
313 | tokens.append("[SEP]")
314 | segment_ids.append(1)
315 |
316 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
317 |
318 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
319 | # tokens are attended to.
320 | input_mask = [1] * len(input_ids)
321 |
322 | # Zero-pad up to the sequence length.
323 | while len(input_ids) < max_seq_length:
324 | input_ids.append(0)
325 | input_mask.append(0)
326 | segment_ids.append(0)
327 |
328 | assert len(input_ids) == max_seq_length
329 | assert len(input_mask) == max_seq_length
330 | assert len(segment_ids) == max_seq_length
331 |
332 | label_id = label_map[example.label]
333 | if ex_index < 5:
334 | tf.logging.info("*** Example ***")
335 | tf.logging.info("guid: %s" % (example.guid))
336 | tf.logging.info("tokens: %s" % " ".join(
337 | [tokenization.printable_text(x) for x in tokens]))
338 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
339 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
340 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
341 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id))
342 |
343 | feature = InputFeatures(
344 | input_ids=input_ids,
345 | input_mask=input_mask,
346 | segment_ids=segment_ids,
347 | label_id=label_id,
348 | is_real_example=True)
349 | return feature
350 |
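# --- Illustrative note (not part of the original source) ---
# The padded feature layout produced above for a toy pair with max_seq_length=10
# (the token strings and lengths are assumptions for illustration):
#   tokens      : [CLS] 我 喜 欢 [SEP] 不 错 [SEP] <pad> <pad>
#   input_mask  : 1 1 1 1 1 1 1 1 0 0   (1 = real token, 0 = zero padding)
#   segment_ids : 0 0 0 0 0 1 1 1 0 0   (0 = [CLS] + sentence A, 1 = sentence B)
#   input_ids   : vocabulary ids of the tokens, with 0s in the padded slots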
351 |
352 | def file_based_convert_examples_to_features(
353 | examples, label_list, max_seq_length, tokenizer, output_file):
354 | """Convert a set of `InputExample`s to a TFRecord file."""
355 |
356 | writer = tf.python_io.TFRecordWriter(output_file)
357 |
358 | for (ex_index, example) in enumerate(examples):
359 | if ex_index % 10000 == 0:
360 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
361 |
362 | feature = convert_single_example(ex_index, example, label_list,
363 | max_seq_length, tokenizer)
364 |
365 | def create_int_feature(values):
366 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
367 | return f
368 |
369 | features = collections.OrderedDict()
370 | features["input_ids"] = create_int_feature(feature.input_ids)
371 | features["input_mask"] = create_int_feature(feature.input_mask)
372 | features["segment_ids"] = create_int_feature(feature.segment_ids)
373 | features["label_ids"] = create_int_feature([feature.label_id])
374 | features["is_real_example"] = create_int_feature(
375 | [int(feature.is_real_example)])
376 |
377 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
378 | writer.write(tf_example.SerializeToString())
379 | writer.close()
380 |
381 |
382 | def file_based_input_fn_builder(input_file, seq_length, is_training,
383 | drop_remainder):
384 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
385 |
386 | name_to_features = {
387 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
388 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
389 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
390 | "label_ids": tf.FixedLenFeature([], tf.int64),
391 | "is_real_example": tf.FixedLenFeature([], tf.int64),
392 | }
393 |
394 | def _decode_record(record, name_to_features):
395 | """Decodes a record to a TensorFlow example."""
396 | example = tf.parse_single_example(record, name_to_features)
397 |
398 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
399 | # So cast all int64 to int32.
400 | for name in list(example.keys()):
401 | t = example[name]
402 | if t.dtype == tf.int64:
403 | t = tf.to_int32(t)
404 | example[name] = t
405 |
406 | return example
407 |
408 | def input_fn(params):
409 | """The actual input function."""
410 | batch_size = params["batch_size"]
411 |
412 | # For training, we want a lot of parallel reading and shuffling.
413 | # For eval, we want no shuffling and parallel reading doesn't matter.
414 | d = tf.data.TFRecordDataset(input_file)
415 | if is_training:
416 | d = d.repeat()
417 | d = d.shuffle(buffer_size=100)
418 |
419 | d = d.apply(
420 | tf.contrib.data.map_and_batch(
421 | lambda record: _decode_record(record, name_to_features),
422 | batch_size=batch_size,
423 | drop_remainder=drop_remainder))
424 |
425 | return d
426 |
427 | return input_fn
428 |
429 |
430 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
431 | """Truncates a sequence pair in place to the maximum length."""
432 |
433 | # This is a simple heuristic which will always truncate the longer sequence
434 | # one token at a time. This makes more sense than truncating an equal percent
435 | # of tokens from each, since if one sequence is very short then each token
436 |   # that's truncated likely contains more information than a token from the longer sequence.
437 | while True:
438 | total_length = len(tokens_a) + len(tokens_b)
439 | if total_length <= max_length:
440 | break
441 | if len(tokens_a) > len(tokens_b):
442 | tokens_a.pop()
443 | else:
444 | tokens_b.pop()
445 |
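# --- Illustrative sketch (not part of the original source) ---
# What the truncation heuristic above does to a toy token pair: the longer list
# loses tokens one at a time until the pair fits the budget (values are assumptions).
tokens_a = list("真的很喜欢这本书")   # 8 tokens
tokens_b = list("不喜欢")             # 3 tokens
_truncate_seq_pair(tokens_a, tokens_b, max_length=8)  # e.g. max_seq_length=11 minus [CLS]/[SEP]/[SEP]
print(len(tokens_a), len(tokens_b))  # 5 3 -> only tokens_a was trimmed
# --- end sketch ---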
446 |
447 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
448 | labels, num_labels, use_one_hot_embeddings):
449 | """Creates a classification model."""
450 | model = modeling.BertModel(
451 | config=bert_config,
452 | is_training=is_training,
453 | input_ids=input_ids,
454 | input_mask=input_mask,
455 | token_type_ids=segment_ids,
456 | use_one_hot_embeddings=use_one_hot_embeddings)
457 |
458 | # In the demo, we are doing a simple classification task on the entire
459 | # segment.
460 | #
461 | # If you want to use the token-level output, use model.get_sequence_output()
462 | # instead.
463 | output_layer = model.get_pooled_output()
464 |
465 | hidden_size = output_layer.shape[-1].value
466 |
467 | output_weights = tf.get_variable(
468 | "output_weights", [num_labels, hidden_size],
469 | initializer=tf.truncated_normal_initializer(stddev=0.02))
470 |
471 | output_bias = tf.get_variable(
472 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
473 |
474 | with tf.variable_scope("loss"):
475 | if is_training:
476 | # I.e., 0.1 dropout
477 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
478 |
479 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
480 | logits = tf.nn.bias_add(logits, output_bias)
481 | probabilities = tf.nn.softmax(logits, axis=-1)
482 | log_probs = tf.nn.log_softmax(logits, axis=-1)
483 |
484 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
485 |
486 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) # todo 08-29 try temp-loss
487 | ###############bi_tempered_logistic_loss############################################################################
488 | # print("##cross entropy loss is used...."); tf.logging.info("##cross entropy loss is used....")
489 | # t1=0.9 #t1=0.90
490 | # t2=1.05 #t2=1.05
491 | # per_example_loss=bi_tempered_logistic_loss(log_probs,one_hot_labels,t1,t2,label_smoothing=0.1,num_iters=5) # TODO label_smoothing=0.0
492 | #tf.logging.info("per_example_loss:"+str(per_example_loss.shape))
493 | ##############bi_tempered_logistic_loss#############################################################################
494 |
495 | loss = tf.reduce_mean(per_example_loss)
496 |
497 | return (loss, per_example_loss, logits, probabilities)
498 |
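# --- Illustrative sketch (not part of the original source) ---
# The classification loss above, written out in NumPy for a toy batch: the
# per-example loss is -sum(one_hot * log_softmax(logits)) and the batch loss is
# its mean. The logits and labels below are assumptions for illustration.
import numpy as np

logits = np.array([[2.0, 0.5], [0.1, 1.2]])  # [batch=2, num_labels=2]
labels = np.array([0, 1])

log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
one_hot_labels = np.eye(logits.shape[1])[labels]
per_example_loss = -(one_hot_labels * log_probs).sum(-1)
loss = per_example_loss.mean()
print(per_example_loss, loss)  # ~[0.201 0.287] 0.244
# --- end sketch ---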
499 |
500 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
501 | num_train_steps, num_warmup_steps, use_tpu,
502 | use_one_hot_embeddings):
503 | """Returns `model_fn` closure for TPUEstimator."""
504 |
505 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
506 | """The `model_fn` for TPUEstimator."""
507 |
508 | tf.logging.info("*** Features ***")
509 | for name in sorted(features.keys()):
510 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
511 |
512 | input_ids = features["input_ids"]
513 | input_mask = features["input_mask"]
514 | segment_ids = features["segment_ids"]
515 | label_ids = features["label_ids"]
516 | is_real_example = None
517 | if "is_real_example" in features:
518 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
519 | else:
520 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
521 |
522 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
523 |
524 | (total_loss, per_example_loss, logits, probabilities) = create_model(
525 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
526 | num_labels, use_one_hot_embeddings)
527 |
528 | tvars = tf.trainable_variables()
529 | initialized_variable_names = {}
530 | scaffold_fn = None
531 | if init_checkpoint:
532 | (assignment_map, initialized_variable_names
533 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
534 | if use_tpu:
535 |
536 | def tpu_scaffold():
537 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
538 | return tf.train.Scaffold()
539 |
540 | scaffold_fn = tpu_scaffold
541 | else:
542 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
543 |
544 | tf.logging.info("**** Trainable Variables ****")
545 | for var in tvars:
546 | init_string = ""
547 | if var.name in initialized_variable_names:
548 | init_string = ", *INIT_FROM_CKPT*"
549 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
550 | init_string)
551 |
552 | output_spec = None
553 | if mode == tf.estimator.ModeKeys.TRAIN:
554 |
555 | train_op = optimization.create_optimizer(
556 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
557 |
558 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
559 | mode=mode,
560 | loss=total_loss,
561 | train_op=train_op,
562 | scaffold_fn=scaffold_fn)
563 | elif mode == tf.estimator.ModeKeys.EVAL:
564 |
565 | def metric_fn(per_example_loss, label_ids, logits, is_real_example):
566 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
567 | accuracy = tf.metrics.accuracy(
568 | labels=label_ids, predictions=predictions, weights=is_real_example)
569 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
570 | return {
571 | "eval_accuracy": accuracy,
572 | "eval_loss": loss,
573 | }
574 |
575 | eval_metrics = (metric_fn,
576 | [per_example_loss, label_ids, logits, is_real_example])
577 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
578 | mode=mode,
579 | loss=total_loss,
580 | eval_metrics=eval_metrics,
581 | scaffold_fn=scaffold_fn)
582 | else:
583 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
584 | mode=mode,
585 | predictions={"probabilities": probabilities},
586 | scaffold_fn=scaffold_fn)
587 | return output_spec
588 |
589 | return model_fn
590 |
591 |
592 | # This function is not used by this file but is still used by the Colab and
593 | # people who depend on it.
594 | def input_fn_builder(features, seq_length, is_training, drop_remainder):
595 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
596 |
597 | all_input_ids = []
598 | all_input_mask = []
599 | all_segment_ids = []
600 | all_label_ids = []
601 |
602 | for feature in features:
603 | all_input_ids.append(feature.input_ids)
604 | all_input_mask.append(feature.input_mask)
605 | all_segment_ids.append(feature.segment_ids)
606 | all_label_ids.append(feature.label_id)
607 |
608 | def input_fn(params):
609 | """The actual input function."""
610 | batch_size = params["batch_size"]
611 |
612 | num_examples = len(features)
613 |
614 | # This is for demo purposes and does NOT scale to large data sets. We do
615 | # not use Dataset.from_generator() because that uses tf.py_func which is
616 | # not TPU compatible. The right way to load data is with TFRecordReader.
617 | d = tf.data.Dataset.from_tensor_slices({
618 | "input_ids":
619 | tf.constant(
620 | all_input_ids, shape=[num_examples, seq_length],
621 | dtype=tf.int32),
622 | "input_mask":
623 | tf.constant(
624 | all_input_mask,
625 | shape=[num_examples, seq_length],
626 | dtype=tf.int32),
627 | "segment_ids":
628 | tf.constant(
629 | all_segment_ids,
630 | shape=[num_examples, seq_length],
631 | dtype=tf.int32),
632 | "label_ids":
633 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32),
634 | })
635 |
636 | if is_training:
637 | d = d.repeat()
638 | d = d.shuffle(buffer_size=100)
639 |
640 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
641 | return d
642 |
643 | return input_fn
644 |
645 | class LCQMCPairClassificationProcessor(DataProcessor): # TODO NEED CHANGE2
646 |   """Processor for the internal sentence-pair classification data set."""
647 | def __init__(self):
648 | self.language = "zh"
649 |
650 | def get_train_examples(self, data_dir):
651 | """See base class."""
652 | return self._create_examples(
653 | self._read_tsv(os.path.join(data_dir, "train.txt")), "train")
654 | # dev_0827.tsv
655 |
656 | def get_dev_examples(self, data_dir):
657 | """See base class."""
658 | return self._create_examples(
659 |         self._read_tsv(os.path.join(data_dir, "test.txt")), "dev")  # TODO: temporarily uses test.txt as the dev set
660 |
661 | def get_test_examples(self, data_dir):
662 | """See base class."""
663 | return self._create_examples(
664 | self._read_tsv(os.path.join(data_dir, "test.txt")), "test")
665 |
666 | def get_labels(self):
667 | """See base class."""
668 | return ["0", "1"]
669 | #return ["-1","0", "1"]
670 |
671 | def _create_examples(self, lines, set_type):
672 | """Creates examples for the training and dev sets."""
673 | examples = []
674 | print("length of lines:",len(lines))
675 | for (i, line) in enumerate(lines):
676 | #print('#i:',i,line)
677 | if i == 0:
678 | continue
679 | guid = "%s-%s" % (set_type, i)
680 | try:
681 | label = tokenization.convert_to_unicode(line[2])
682 | text_a = tokenization.convert_to_unicode(line[0])
683 | text_b = tokenization.convert_to_unicode(line[1])
684 | examples.append(
685 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
686 | except Exception:
687 | print('###error.i:', i, line)
688 | return examples
689 |
690 | class SentencePairClassificationProcessor(DataProcessor):
691 | """Processor for the internal data set. sentence pair classification"""
692 | def __init__(self):
693 | self.language = "zh"
694 |
695 | def get_train_examples(self, data_dir):
696 | """See base class."""
697 | return self._create_examples(
698 | self._read_tsv(os.path.join(data_dir, "train_0827.tsv")), "train")
699 | # dev_0827.tsv
700 |
701 | def get_dev_examples(self, data_dir):
702 | """See base class."""
703 | return self._create_examples(
704 | self._read_tsv(os.path.join(data_dir, "dev_0827.tsv")), "dev")
705 |
706 | def get_test_examples(self, data_dir):
707 | """See base class."""
708 | return self._create_examples(
709 | self._read_tsv(os.path.join(data_dir, "test_0827.tsv")), "test")
710 |
711 | def get_labels(self):
712 | """See base class."""
713 | return ["0", "1"]
714 | #return ["-1","0", "1"]
715 |
716 | def _create_examples(self, lines, set_type):
717 | """Creates examples for the training and dev sets."""
718 | examples = []
719 | print("length of lines:",len(lines))
720 | for (i, line) in enumerate(lines):
721 | #print('#i:',i,line)
722 | if i == 0:
723 | continue
724 | guid = "%s-%s" % (set_type, i)
725 | try:
726 | label = tokenization.convert_to_unicode(line[0])
727 | text_a = tokenization.convert_to_unicode(line[1])
728 | text_b = tokenization.convert_to_unicode(line[2])
729 | examples.append(
730 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
731 | except Exception:
732 | print('###error.i:', i, line)
733 | return examples
734 |
735 | # This function is not used by this file but is still used by the Colab and
736 | # people who depend on it.
737 | def convert_examples_to_features(examples, label_list, max_seq_length,
738 | tokenizer):
739 | """Convert a set of `InputExample`s to a list of `InputFeatures`."""
740 |
741 | features = []
742 | for (ex_index, example) in enumerate(examples):
743 | if ex_index % 10000 == 0:
744 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
745 |
746 | feature = convert_single_example(ex_index, example, label_list,
747 | max_seq_length, tokenizer)
748 |
749 | features.append(feature)
750 | return features
751 |
752 |
753 | def main(_):
754 | tf.logging.set_verbosity(tf.logging.INFO)
755 |
756 | processors = {
757 | "sentence_pair": SentencePairClassificationProcessor,
758 | "lcqmc_pair":LCQMCPairClassificationProcessor
759 |
760 |
761 | }
762 |
763 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
764 | FLAGS.init_checkpoint)
765 |
766 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
767 | raise ValueError(
768 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.")
769 |
770 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
771 |
772 | if FLAGS.max_seq_length > bert_config.max_position_embeddings:
773 | raise ValueError(
774 | "Cannot use sequence length %d because the BERT model "
775 | "was only trained up to sequence length %d" %
776 | (FLAGS.max_seq_length, bert_config.max_position_embeddings))
777 |
778 | tf.gfile.MakeDirs(FLAGS.output_dir)
779 |
780 | task_name = FLAGS.task_name.lower()
781 |
782 | if task_name not in processors:
783 | raise ValueError("Task not found: %s" % (task_name))
784 |
785 | processor = processors[task_name]()
786 |
787 | label_list = processor.get_labels()
788 |
789 | tokenizer = tokenization.FullTokenizer(
790 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
791 |
792 | tpu_cluster_resolver = None
793 | if FLAGS.use_tpu and FLAGS.tpu_name:
794 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
795 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
796 |
797 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
798 | # Cloud TPU: Invalid TPU configuration, ensure ClusterResolver is passed to tpu.
799 | print("###tpu_cluster_resolver:",tpu_cluster_resolver)
800 | run_config = tf.contrib.tpu.RunConfig(
801 | cluster=tpu_cluster_resolver,
802 | master=FLAGS.master,
803 | model_dir=FLAGS.output_dir,
804 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
805 | tpu_config=tf.contrib.tpu.TPUConfig(
806 | iterations_per_loop=FLAGS.iterations_per_loop,
807 | num_shards=FLAGS.num_tpu_cores,
808 | per_host_input_for_training=is_per_host))
809 |
810 | train_examples = None
811 | num_train_steps = None
812 | num_warmup_steps = None
813 | if FLAGS.do_train:
814 |     train_examples = processor.get_train_examples(FLAGS.data_dir)  # TODO
815 | print("###length of total train_examples:",len(train_examples))
816 | num_train_steps = int(len(train_examples)/ FLAGS.train_batch_size * FLAGS.num_train_epochs)
817 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
818 |
819 | model_fn = model_fn_builder(
820 | bert_config=bert_config,
821 | num_labels=len(label_list),
822 | init_checkpoint=FLAGS.init_checkpoint,
823 | learning_rate=FLAGS.learning_rate,
824 | num_train_steps=num_train_steps,
825 | num_warmup_steps=num_warmup_steps,
826 | use_tpu=FLAGS.use_tpu,
827 | use_one_hot_embeddings=FLAGS.use_tpu)
828 |
829 | # If TPU is not available, this will fall back to normal Estimator on CPU
830 | # or GPU.
831 | estimator = tf.contrib.tpu.TPUEstimator(
832 | use_tpu=FLAGS.use_tpu,
833 | model_fn=model_fn,
834 | config=run_config,
835 | train_batch_size=FLAGS.train_batch_size,
836 | eval_batch_size=FLAGS.eval_batch_size,
837 | predict_batch_size=FLAGS.predict_batch_size)
838 |
839 | if FLAGS.do_train:
840 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
841 | train_file_exists=os.path.exists(train_file)
842 | print("###train_file_exists:", train_file_exists," ;train_file:",train_file)
843 |     if not train_file_exists:  # if the tf_record file does not exist, build it from the raw text file.  # TODO
844 | file_based_convert_examples_to_features(train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
845 | tf.logging.info("***** Running training *****")
846 | tf.logging.info(" Num examples = %d", len(train_examples))
847 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
848 | tf.logging.info(" Num steps = %d", num_train_steps)
849 | train_input_fn = file_based_input_fn_builder(
850 | input_file=train_file,
851 | seq_length=FLAGS.max_seq_length,
852 | is_training=True,
853 | drop_remainder=True)
854 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
855 |
856 | if FLAGS.do_eval:
857 | eval_examples = processor.get_dev_examples(FLAGS.data_dir)
858 | num_actual_eval_examples = len(eval_examples)
859 | if FLAGS.use_tpu:
860 | # TPU requires a fixed batch size for all batches, therefore the number
861 | # of examples must be a multiple of the batch size, or else examples
862 | # will get dropped. So we pad with fake examples which are ignored
863 | # later on. These do NOT count towards the metric (all tf.metrics
864 | # support a per-instance weight, and these get a weight of 0.0).
865 | while len(eval_examples) % FLAGS.eval_batch_size != 0:
866 | eval_examples.append(PaddingInputExample())
867 |
868 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
869 | file_based_convert_examples_to_features(
870 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)
871 |
872 | tf.logging.info("***** Running evaluation *****")
873 | tf.logging.info(" Num examples = %d (%d actual, %d padding)",
874 | len(eval_examples), num_actual_eval_examples,
875 | len(eval_examples) - num_actual_eval_examples)
876 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
877 |
878 | # This tells the estimator to run through the entire set.
879 | eval_steps = None
880 | # However, if running eval on the TPU, you will need to specify the
881 | # number of steps.
882 | if FLAGS.use_tpu:
883 | assert len(eval_examples) % FLAGS.eval_batch_size == 0
884 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
885 |
886 | eval_drop_remainder = True if FLAGS.use_tpu else False
887 | eval_input_fn = file_based_input_fn_builder(
888 | input_file=eval_file,
889 | seq_length=FLAGS.max_seq_length,
890 | is_training=False,
891 | drop_remainder=eval_drop_remainder)
892 |
893 | #######################################################################################################################
894 |     # evaluate all checkpoints in the output directory
895 | steps_and_files = []
896 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
897 | for filename in filenames:
898 | if filename.endswith(".index"):
899 | ckpt_name = filename[:-6]
900 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
901 | global_step = int(cur_filename.split("-")[-1])
902 | tf.logging.info("Add {} to eval list.".format(cur_filename))
903 | steps_and_files.append([global_step, cur_filename])
904 | steps_and_files = sorted(steps_and_files, key=lambda x: x[0])
905 |
906 | output_eval_file = os.path.join(FLAGS.data_dir, "eval_results16-layer24-4million-2.txt") # finetuning-layer24-4million
907 | print("output_eval_file:",output_eval_file)
908 | tf.logging.info("output_eval_file:"+output_eval_file)
909 | with tf.gfile.GFile(output_eval_file, "w") as writer:
910 | for global_step, filename in sorted(steps_and_files, key=lambda x: x[0]):
911 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=filename)
912 |
913 | tf.logging.info("***** Eval results %s *****" % (filename))
914 | writer.write("***** Eval results %s *****\n" % (filename))
915 | for key in sorted(result.keys()):
916 | tf.logging.info(" %s = %s", key, str(result[key]))
917 | writer.write("%s = %s\n" % (key, str(result[key])))
918 | #######################################################################################################################
919 |
920 | #result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
921 | #
922 | #output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
923 | #with tf.gfile.GFile(output_eval_file, "w") as writer:
924 | # tf.logging.info("***** Eval results *****")
925 | # for key in sorted(result.keys()):
926 | # tf.logging.info(" %s = %s", key, str(result[key]))
927 | # writer.write("%s = %s\n" % (key, str(result[key])))
928 |
929 | if FLAGS.do_predict:
930 | predict_examples = processor.get_test_examples(FLAGS.data_dir)
931 | num_actual_predict_examples = len(predict_examples)
932 | if FLAGS.use_tpu:
933 | # TPU requires a fixed batch size for all batches, therefore the number
934 | # of examples must be a multiple of the batch size, or else examples
935 | # will get dropped. So we pad with fake examples which are ignored
936 | # later on.
937 | while len(predict_examples) % FLAGS.predict_batch_size != 0:
938 | predict_examples.append(PaddingInputExample())
939 |
940 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
941 | file_based_convert_examples_to_features(predict_examples, label_list,
942 | FLAGS.max_seq_length, tokenizer,
943 | predict_file)
944 |
945 | tf.logging.info("***** Running prediction*****")
946 | tf.logging.info(" Num examples = %d (%d actual, %d padding)",
947 | len(predict_examples), num_actual_predict_examples,
948 | len(predict_examples) - num_actual_predict_examples)
949 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
950 |
951 | predict_drop_remainder = True if FLAGS.use_tpu else False
952 | predict_input_fn = file_based_input_fn_builder(
953 | input_file=predict_file,
954 | seq_length=FLAGS.max_seq_length,
955 | is_training=False,
956 | drop_remainder=predict_drop_remainder)
957 |
958 | result = estimator.predict(input_fn=predict_input_fn)
959 |
960 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
961 | with tf.gfile.GFile(output_predict_file, "w") as writer:
962 | num_written_lines = 0
963 | tf.logging.info("***** Predict results *****")
964 | for (i, prediction) in enumerate(result):
965 | probabilities = prediction["probabilities"]
966 | if i >= num_actual_predict_examples:
967 | break
968 | output_line = "\t".join(
969 | str(class_probability)
970 | for class_probability in probabilities) + "\n"
971 | writer.write(output_line)
972 | num_written_lines += 1
973 | assert num_written_lines == num_actual_predict_examples
974 |
975 |
976 | if __name__ == "__main__":
977 | flags.mark_flag_as_required("data_dir")
978 | flags.mark_flag_as_required("task_name")
979 | flags.mark_flag_as_required("vocab_file")
980 | flags.mark_flag_as_required("bert_config_file")
981 | flags.mark_flag_as_required("output_dir")
982 | tf.app.run()
--------------------------------------------------------------------------------
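
A minimal sketch, not part of the repository, of the row layout that `LCQMCPairClassificationProcessor._create_examples` in run_classifier.py expects: a header row followed by tab-separated columns, with `text_a` in column 0, `text_b` in column 1 and the label in column 2. The sample rows below are illustrative only.

    # Illustrative only: mimic LCQMCPairClassificationProcessor._create_examples,
    # which skips the header (i == 0) and reads text_a, text_b, label per row.
    rows = [
        "text_a\ttext_b\tlabel",            # header row, skipped
        "今天天气怎么样\t今天天气如何\t1",    # paraphrase pair -> label "1"
        "他喜欢打篮球\t她在学习英语\t0",      # unrelated pair  -> label "0"
    ]

    examples = []
    for i, line in enumerate(rows):
        if i == 0:
            continue
        text_a, text_b, label = line.split("\t")
        examples.append(("train-%d" % i, text_a, text_b, label))

    print(examples)
    # [('train-1', '今天天气怎么样', '今天天气如何', '1'),
    #  ('train-2', '他喜欢打篮球', '她在学习英语', '0')]
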
/run_pretraining.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Run masked LM/next sentence masked_lm pre-training for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import modeling
23 | import optimization
24 | import tensorflow as tf
25 |
26 | flags = tf.flags
27 |
28 | FLAGS = flags.FLAGS
29 |
30 | ## Required parameters
31 | flags.DEFINE_string(
32 | "bert_config_file", None,
33 | "The config json file corresponding to the pre-trained BERT model. "
34 | "This specifies the model architecture.")
35 |
36 | flags.DEFINE_string(
37 | "input_file", None,
38 | "Input TF example files (can be a glob or comma separated).")
39 |
40 | flags.DEFINE_string(
41 | "output_dir", None,
42 | "The output directory where the model checkpoints will be written.")
43 |
44 | ## Other parameters
45 | flags.DEFINE_string(
46 | "init_checkpoint", None,
47 | "Initial checkpoint (usually from a pre-trained BERT model).")
48 |
49 | flags.DEFINE_integer(
50 | "max_seq_length", 128,
51 | "The maximum total input sequence length after WordPiece tokenization. "
52 | "Sequences longer than this will be truncated, and sequences shorter "
53 | "than this will be padded. Must match data generation.")
54 |
55 | flags.DEFINE_integer(
56 | "max_predictions_per_seq", 20,
57 | "Maximum number of masked LM predictions per sequence. "
58 | "Must match data generation.")
59 |
60 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
61 |
62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
63 |
64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
65 |
66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
67 |
68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
69 |
70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
71 |
72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
73 |
74 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
75 | "How often to save the model checkpoint.")
76 |
77 | flags.DEFINE_integer("iterations_per_loop", 1000,
78 | "How many steps to make in each estimator call.")
79 |
80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
81 |
82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
83 |
84 | tf.flags.DEFINE_string(
85 | "tpu_name", None,
86 | "The Cloud TPU to use for training. This should be either the name "
87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
88 | "url.")
89 |
90 | tf.flags.DEFINE_string(
91 | "tpu_zone", None,
92 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
93 | "specified, we will attempt to automatically detect the GCE project from "
94 | "metadata.")
95 |
96 | tf.flags.DEFINE_string(
97 | "gcp_project", None,
98 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
99 | "specified, we will attempt to automatically detect the GCE project from "
100 | "metadata.")
101 |
102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
103 |
104 | flags.DEFINE_integer(
105 | "num_tpu_cores", 8,
106 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
107 |
108 |
109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate,
110 | num_train_steps, num_warmup_steps, use_tpu,
111 | use_one_hot_embeddings):
112 | """Returns `model_fn` closure for TPUEstimator."""
113 |
114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
115 | """The `model_fn` for TPUEstimator."""
116 |
117 | tf.logging.info("*** Features ***")
118 | for name in sorted(features.keys()):
119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
120 |
121 | input_ids = features["input_ids"]
122 | input_mask = features["input_mask"]
123 | segment_ids = features["segment_ids"]
124 | masked_lm_positions = features["masked_lm_positions"]
125 | masked_lm_ids = features["masked_lm_ids"]
126 | masked_lm_weights = features["masked_lm_weights"]
127 | next_sentence_labels = features["next_sentence_labels"]
128 |
129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
130 |
131 | model = modeling.BertModel(
132 | config=bert_config,
133 | is_training=is_training,
134 | input_ids=input_ids,
135 | input_mask=input_mask,
136 | token_type_ids=segment_ids,
137 | use_one_hot_embeddings=use_one_hot_embeddings)
138 |
139 | (masked_lm_loss,
140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
141 | bert_config, model.get_sequence_output(), model.get_embedding_table(),
142 | masked_lm_positions, masked_lm_ids, masked_lm_weights)
143 |
144 |     (next_sentence_loss, next_sentence_example_loss,  # TODO: computed here but not added to the total loss
145 | next_sentence_log_probs) = get_next_sentence_output(
146 | bert_config, model.get_pooled_output(), next_sentence_labels)
147 | # batch_size=masked_lm_log_probs.shape[0]
148 | # next_sentence_example_loss=tf.zeros((batch_size)) #tf.constant(0.0,dtype=tf.float32)
149 | # next_sentence_log_probs=tf.zeros((batch_size,2))
150 | total_loss = masked_lm_loss # TODO remove next sentence loss 2019-08-08, + next_sentence_loss
151 |
152 | tvars = tf.trainable_variables()
153 |
154 | initialized_variable_names = {}
155 | print("init_checkpoint:",init_checkpoint)
156 | scaffold_fn = None
157 | if init_checkpoint:
158 | (assignment_map, initialized_variable_names
159 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
160 | if use_tpu:
161 |
162 | def tpu_scaffold():
163 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
164 | return tf.train.Scaffold()
165 |
166 | scaffold_fn = tpu_scaffold
167 | else:
168 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
169 |
170 | tf.logging.info("**** Trainable Variables ****")
171 | for var in tvars:
172 | init_string = ""
173 | if var.name in initialized_variable_names:
174 | init_string = ", *INIT_FROM_CKPT*"
175 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
176 | init_string)
177 |
178 | output_spec = None
179 | if mode == tf.estimator.ModeKeys.TRAIN:
180 | train_op = optimization.create_optimizer(
181 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
182 |
183 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
184 | mode=mode,
185 | loss=total_loss,
186 | train_op=train_op,
187 | scaffold_fn=scaffold_fn)
188 | elif mode == tf.estimator.ModeKeys.EVAL:
189 |
190 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
191 | masked_lm_weights, next_sentence_example_loss,
192 | next_sentence_log_probs, next_sentence_labels):
193 | """Computes the loss and accuracy of the model."""
194 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs,[-1, masked_lm_log_probs.shape[-1]])
195 | masked_lm_predictions = tf.argmax(masked_lm_log_probs, axis=-1, output_type=tf.int32)
196 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
197 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
198 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
199 | masked_lm_accuracy = tf.metrics.accuracy(
200 | labels=masked_lm_ids,
201 | predictions=masked_lm_predictions,
202 | weights=masked_lm_weights)
203 | masked_lm_mean_loss = tf.metrics.mean(
204 | values=masked_lm_example_loss, weights=masked_lm_weights)
205 |
206 | next_sentence_log_probs = tf.reshape(
207 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
208 | next_sentence_predictions = tf.argmax(
209 | next_sentence_log_probs, axis=-1, output_type=tf.int32)
210 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
211 | next_sentence_accuracy = tf.metrics.accuracy(
212 | labels=next_sentence_labels, predictions=next_sentence_predictions)
213 | next_sentence_mean_loss = tf.metrics.mean(
214 | values=next_sentence_example_loss)
215 |
216 | return {
217 | "masked_lm_accuracy": masked_lm_accuracy,
218 | "masked_lm_loss": masked_lm_mean_loss,
219 | "next_sentence_accuracy": next_sentence_accuracy,
220 | "next_sentence_loss": next_sentence_mean_loss,
221 | }
222 |
223 | # next_sentence_example_loss=0.0 TODO
224 | # next_sentence_log_probs=0.0 # TODO
225 | eval_metrics = (metric_fn, [
226 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
227 | masked_lm_weights, next_sentence_example_loss,
228 | next_sentence_log_probs, next_sentence_labels
229 | ])
230 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
231 | mode=mode,
232 | loss=total_loss,
233 | eval_metrics=eval_metrics,
234 | scaffold_fn=scaffold_fn)
235 | else:
236 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
237 |
238 | return output_spec
239 |
240 | return model_fn
241 |
242 |
243 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
244 | label_ids, label_weights):
245 | """Get loss and log probs for the masked LM."""
246 | input_tensor = gather_indexes(input_tensor, positions)
247 |
248 | with tf.variable_scope("cls/predictions"):
249 | # We apply one more non-linear transformation before the output layer.
250 | # This matrix is not used after pre-training.
251 | with tf.variable_scope("transform"):
252 | input_tensor = tf.layers.dense(
253 | input_tensor,
254 | units=bert_config.hidden_size,
255 | activation=modeling.get_activation(bert_config.hidden_act),
256 | kernel_initializer=modeling.create_initializer(
257 | bert_config.initializer_range))
258 | input_tensor = modeling.layer_norm(input_tensor)
259 |
260 | # The output weights are the same as the input embeddings, but there is
261 | # an output-only bias for each token.
262 | output_bias = tf.get_variable(
263 | "output_bias",
264 | shape=[bert_config.vocab_size],
265 | initializer=tf.zeros_initializer())
266 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
267 | logits = tf.nn.bias_add(logits, output_bias)
268 | log_probs = tf.nn.log_softmax(logits, axis=-1)
269 |
270 | label_ids = tf.reshape(label_ids, [-1])
271 | label_weights = tf.reshape(label_weights, [-1])
272 |
273 | one_hot_labels = tf.one_hot(label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
274 |
275 | # The `positions` tensor might be zero-padded (if the sequence is too
276 | # short to have the maximum number of predictions). The `label_weights`
277 | # tensor has a value of 1.0 for every real prediction and 0.0 for the
278 | # padding predictions.
279 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
280 | numerator = tf.reduce_sum(label_weights * per_example_loss)
281 | denominator = tf.reduce_sum(label_weights) + 1e-5
282 | loss = numerator / denominator
283 |
284 | return (loss, per_example_loss, log_probs)
285 |
286 |
287 | def get_next_sentence_output(bert_config, input_tensor, labels):
288 | """Get loss and log probs for the next sentence prediction."""
289 |
290 | # Simple binary classification. Note that 0 is "next sentence" and 1 is
291 | # "random sentence". This weight matrix is not used after pre-training.
292 | with tf.variable_scope("cls/seq_relationship"):
293 | output_weights = tf.get_variable(
294 | "output_weights",
295 | shape=[2, bert_config.hidden_size],
296 | initializer=modeling.create_initializer(bert_config.initializer_range))
297 | output_bias = tf.get_variable(
298 | "output_bias", shape=[2], initializer=tf.zeros_initializer())
299 |
300 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
301 | logits = tf.nn.bias_add(logits, output_bias)
302 | log_probs = tf.nn.log_softmax(logits, axis=-1)
303 | labels = tf.reshape(labels, [-1])
304 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
305 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
306 | loss = tf.reduce_mean(per_example_loss)
307 | return (loss, per_example_loss, log_probs)
308 |
309 |
310 | def gather_indexes(sequence_tensor, positions):
311 | """Gathers the vectors at the specific positions over a minibatch."""
312 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
313 | batch_size = sequence_shape[0]
314 | seq_length = sequence_shape[1]
315 | width = sequence_shape[2]
316 |
317 | flat_offsets = tf.reshape(
318 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
319 | flat_positions = tf.reshape(positions + flat_offsets, [-1])
320 | flat_sequence_tensor = tf.reshape(sequence_tensor,
321 | [batch_size * seq_length, width])
322 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
323 | return output_tensor
324 |
325 |
326 | def input_fn_builder(input_files,
327 | max_seq_length,
328 | max_predictions_per_seq,
329 | is_training,
330 | num_cpu_threads=4):
331 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
332 |
333 | def input_fn(params):
334 | """The actual input function."""
335 | batch_size = params["batch_size"]
336 |
337 | name_to_features = {
338 | "input_ids":
339 | tf.FixedLenFeature([max_seq_length], tf.int64),
340 | "input_mask":
341 | tf.FixedLenFeature([max_seq_length], tf.int64),
342 | "segment_ids":
343 | tf.FixedLenFeature([max_seq_length], tf.int64),
344 | "masked_lm_positions":
345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
346 | "masked_lm_ids":
347 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
348 | "masked_lm_weights":
349 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
350 | "next_sentence_labels":
351 | tf.FixedLenFeature([1], tf.int64),
352 | }
353 |
354 | # For training, we want a lot of parallel reading and shuffling.
355 | # For eval, we want no shuffling and parallel reading doesn't matter.
356 | if is_training:
357 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
358 | d = d.repeat()
359 | d = d.shuffle(buffer_size=len(input_files))
360 |
361 | # `cycle_length` is the number of parallel files that get read.
362 | cycle_length = min(num_cpu_threads, len(input_files))
363 |
364 | # `sloppy` mode means that the interleaving is not exact. This adds
365 | # even more randomness to the training pipeline.
366 | d = d.apply(
367 | tf.contrib.data.parallel_interleave(
368 | tf.data.TFRecordDataset,
369 | sloppy=is_training,
370 | cycle_length=cycle_length))
371 | d = d.shuffle(buffer_size=100)
372 | else:
373 | d = tf.data.TFRecordDataset(input_files)
374 | # Since we evaluate for a fixed number of steps we don't want to encounter
375 | # out-of-range exceptions.
376 | d = d.repeat()
377 |
378 | # We must `drop_remainder` on training because the TPU requires fixed
379 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
380 | # and we *don't* want to drop the remainder, otherwise we wont cover
381 | # every sample.
382 | d = d.apply(
383 | tf.contrib.data.map_and_batch(
384 | lambda record: _decode_record(record, name_to_features),
385 | batch_size=batch_size,
386 | num_parallel_batches=num_cpu_threads,
387 | drop_remainder=True))
388 | return d
389 |
390 | return input_fn
391 |
392 |
393 | def _decode_record(record, name_to_features):
394 | """Decodes a record to a TensorFlow example."""
395 | example = tf.parse_single_example(record, name_to_features)
396 |
397 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
398 | # So cast all int64 to int32.
399 | for name in list(example.keys()):
400 | t = example[name]
401 | if t.dtype == tf.int64:
402 | t = tf.to_int32(t)
403 | example[name] = t
404 |
405 | return example
406 |
407 |
408 | def main(_):
409 | tf.logging.set_verbosity(tf.logging.INFO)
410 |
411 |   if not FLAGS.do_train and not FLAGS.do_eval:  # must run training and/or evaluation
412 | raise ValueError("At least one of `do_train` or `do_eval` must be True.")
413 |
414 |   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)  # load the model configuration from the json file
415 |
416 | tf.gfile.MakeDirs(FLAGS.output_dir)
417 |
418 |   input_files = []  # input can be several comma-separated files, or a glob pattern such as "input_x*"
419 | for input_pattern in FLAGS.input_file.split(","):
420 | input_files.extend(tf.gfile.Glob(input_pattern))
421 |
422 | tf.logging.info("*** Input Files ***")
423 | for input_file in input_files:
424 | tf.logging.info(" %s" % input_file)
425 |
426 | tpu_cluster_resolver = None
427 | #if FLAGS.use_tpu and FLAGS.tpu_name:
428 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( # TODO
429 | tpu=FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
430 |
431 | print("###tpu_cluster_resolver:",tpu_cluster_resolver,";FLAGS.use_tpu:",FLAGS.use_tpu,";FLAGS.tpu_name:",FLAGS.tpu_name,";FLAGS.tpu_zone:",FLAGS.tpu_zone)
432 | # ###tpu_cluster_resolver: ;FLAGS.use_tpu: True ;FLAGS.tpu_name: grpc://10.240.1.83:8470
433 |
434 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
435 | run_config = tf.contrib.tpu.RunConfig(
436 | keep_checkpoint_max=20, # 10
437 | cluster=tpu_cluster_resolver,
438 | master=FLAGS.master,
439 | model_dir=FLAGS.output_dir,
440 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
441 | tpu_config=tf.contrib.tpu.TPUConfig(
442 | iterations_per_loop=FLAGS.iterations_per_loop,
443 | num_shards=FLAGS.num_tpu_cores,
444 | per_host_input_for_training=is_per_host))
445 |
446 | model_fn = model_fn_builder(
447 | bert_config=bert_config,
448 | init_checkpoint=FLAGS.init_checkpoint,
449 | learning_rate=FLAGS.learning_rate,
450 | num_train_steps=FLAGS.num_train_steps,
451 | num_warmup_steps=FLAGS.num_warmup_steps,
452 | use_tpu=FLAGS.use_tpu,
453 | use_one_hot_embeddings=FLAGS.use_tpu)
454 |
455 | # If TPU is not available, this will fall back to normal Estimator on CPU
456 | # or GPU.
457 | estimator = tf.contrib.tpu.TPUEstimator(
458 | use_tpu=FLAGS.use_tpu,
459 | model_fn=model_fn,
460 | config=run_config,
461 | train_batch_size=FLAGS.train_batch_size,
462 | eval_batch_size=FLAGS.eval_batch_size)
463 |
464 | if FLAGS.do_train:
465 | tf.logging.info("***** Running training *****")
466 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
467 | train_input_fn = input_fn_builder(
468 | input_files=input_files,
469 | max_seq_length=FLAGS.max_seq_length,
470 | max_predictions_per_seq=FLAGS.max_predictions_per_seq,
471 | is_training=True)
472 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
473 |
474 | if FLAGS.do_eval:
475 | tf.logging.info("***** Running evaluation *****")
476 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
477 |
478 | eval_input_fn = input_fn_builder(
479 | input_files=input_files,
480 | max_seq_length=FLAGS.max_seq_length,
481 | max_predictions_per_seq=FLAGS.max_predictions_per_seq,
482 | is_training=False)
483 |
484 | result = estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
485 |
486 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
487 | with tf.gfile.GFile(output_eval_file, "w") as writer:
488 | tf.logging.info("***** Eval results *****")
489 | for key in sorted(result.keys()):
490 | tf.logging.info(" %s = %s", key, str(result[key]))
491 | writer.write("%s = %s\n" % (key, str(result[key])))
492 |
493 |
494 | if __name__ == "__main__":
495 | flags.mark_flag_as_required("input_file")
496 | flags.mark_flag_as_required("bert_config_file")
497 | flags.mark_flag_as_required("output_dir")
498 | tf.app.run()
499 |
--------------------------------------------------------------------------------
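
A minimal numpy sketch, using toy shapes and values that are not from the repository, of two pieces of run_pretraining.py above: the flat-offset gather in `gather_indexes`, and the weighted masked-LM loss in `get_masked_lm_output`, where padding prediction slots carry a `label_weights` value of 0.0 and therefore drop out of the average.

    import numpy as np

    # gather_indexes: pick the hidden vectors at the masked positions of every
    # example by flattening [batch, seq_len, width] and adding per-row offsets.
    batch_size, seq_length, width = 2, 4, 3
    sequence = np.arange(batch_size * seq_length * width, dtype=np.float32)
    sequence = sequence.reshape(batch_size, seq_length, width)
    positions = np.array([[1, 3], [0, 2]])             # masked positions per example

    flat_offsets = (np.arange(batch_size) * seq_length).reshape(-1, 1)
    flat_positions = (positions + flat_offsets).reshape(-1)
    flat_sequence = sequence.reshape(batch_size * seq_length, width)
    gathered = flat_sequence[flat_positions]           # shape [num_predictions, width]

    # get_masked_lm_output: weighted average of per-prediction losses, where
    # label_weights is 1.0 for real predictions and 0.0 for padding slots.
    per_example_loss = np.array([2.0, 1.0, 4.0, 3.0])  # toy -log p(label) values
    label_weights = np.array([1.0, 1.0, 1.0, 0.0])     # last slot is padding
    loss = np.sum(label_weights * per_example_loss) / (np.sum(label_weights) + 1e-5)
    print(gathered.shape, round(float(loss), 4))       # (4, 3) 2.3333
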
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | #print("items:",items) #['[CLS]', '日', '##期', ',', '但', '被', '##告', '金', '##东', '##福', '载', '##明', '[MASK]', 'U', '##N', '##K', ']', '保', '##证', '本', '##月', '1', '##4', '[MASK]', '到', '##位', ',', '2', '##0', '##1', '##5', '年', '6', '[MASK]', '1', '##1', '日', '[', 'U', '##N', '##K', ']', ',', '原', '##告', '[MASK]', '认', '##可', '于', '2', '##0', '##1', '##5', '[MASK]', '6', '月', '[MASK]', '[MASK]', '日', '##向', '被', '##告', '主', '##张', '权', '##利', '。', '而', '[MASK]', '[MASK]', '自', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '年', '6', '月', '1', '##1', '日', '[SEP]', '原', '##告', '于', '2', '##0', '##1', '##6', '[MASK]', '6', '[MASK]', '2', '##4', '日', '起', '##诉', ',', '主', '##张', '保', '##证', '责', '##任', ',', '已', '超', '##过', '保', '##证', '期', '##限', '[MASK]', '保', '##证', '人', '依', '##法', '不', '##再', '承', '##担', '保', '##证', '[MASK]', '[MASK]', '[MASK]', '[SEP]']
140 | for i,item in enumerate(items):
141 | #print(i,"item:",item) # ##期
142 | output.append(vocab[item])
143 | return output
144 |
145 |
146 | def convert_tokens_to_ids(vocab, tokens):
147 | return convert_by_vocab(vocab, tokens)
148 |
149 |
150 | def convert_ids_to_tokens(inv_vocab, ids):
151 | return convert_by_vocab(inv_vocab, ids)
152 |
153 |
154 | def whitespace_tokenize(text):
155 | """Runs basic whitespace cleaning and splitting on a piece of text."""
156 | text = text.strip()
157 | if not text:
158 | return []
159 | tokens = text.split()
160 | return tokens
161 |
162 |
163 | class FullTokenizer(object):
164 | """Runs end-to-end tokenziation."""
165 |
166 | def __init__(self, vocab_file, do_lower_case=True):
167 | self.vocab = load_vocab(vocab_file)
168 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
169 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
170 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
171 |
172 | def tokenize(self, text):
173 | split_tokens = []
174 | for token in self.basic_tokenizer.tokenize(text):
175 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
176 | split_tokens.append(sub_token)
177 |
178 | return split_tokens
179 |
180 | def convert_tokens_to_ids(self, tokens):
181 | return convert_by_vocab(self.vocab, tokens)
182 |
183 | def convert_ids_to_tokens(self, ids):
184 | return convert_by_vocab(self.inv_vocab, ids)
185 |
186 |
187 | class BasicTokenizer(object):
188 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
189 |
190 | def __init__(self, do_lower_case=True):
191 | """Constructs a BasicTokenizer.
192 |
193 | Args:
194 | do_lower_case: Whether to lower case the input.
195 | """
196 | self.do_lower_case = do_lower_case
197 |
198 | def tokenize(self, text):
199 | """Tokenizes a piece of text."""
200 | text = convert_to_unicode(text)
201 | text = self._clean_text(text)
202 |
203 | # This was added on November 1st, 2018 for the multilingual and Chinese
204 | # models. This is also applied to the English models now, but it doesn't
205 | # matter since the English models were not trained on any Chinese data
206 | # and generally don't have any Chinese data in them (there are Chinese
207 | # characters in the vocabulary because Wikipedia does have some Chinese
208 |     # words in the English Wikipedia).
209 | text = self._tokenize_chinese_chars(text)
210 |
211 | orig_tokens = whitespace_tokenize(text)
212 | split_tokens = []
213 | for token in orig_tokens:
214 | if self.do_lower_case:
215 | token = token.lower()
216 | token = self._run_strip_accents(token)
217 | split_tokens.extend(self._run_split_on_punc(token))
218 |
219 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
220 | return output_tokens
221 |
222 | def _run_strip_accents(self, text):
223 | """Strips accents from a piece of text."""
224 | text = unicodedata.normalize("NFD", text)
225 | output = []
226 | for char in text:
227 | cat = unicodedata.category(char)
228 | if cat == "Mn":
229 | continue
230 | output.append(char)
231 | return "".join(output)
232 |
233 | def _run_split_on_punc(self, text):
234 | """Splits punctuation on a piece of text."""
235 | chars = list(text)
236 | i = 0
237 | start_new_word = True
238 | output = []
239 | while i < len(chars):
240 | char = chars[i]
241 | if _is_punctuation(char):
242 | output.append([char])
243 | start_new_word = True
244 | else:
245 | if start_new_word:
246 | output.append([])
247 | start_new_word = False
248 | output[-1].append(char)
249 | i += 1
250 |
251 | return ["".join(x) for x in output]
252 |
253 | def _tokenize_chinese_chars(self, text):
254 | """Adds whitespace around any CJK character."""
255 | output = []
256 | for char in text:
257 | cp = ord(char)
258 | if self._is_chinese_char(cp):
259 | output.append(" ")
260 | output.append(char)
261 | output.append(" ")
262 | else:
263 | output.append(char)
264 | return "".join(output)
265 |
266 | def _is_chinese_char(self, cp):
267 | """Checks whether CP is the codepoint of a CJK character."""
268 | # This defines a "chinese character" as anything in the CJK Unicode block:
269 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
270 | #
271 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
272 | # despite its name. The modern Korean Hangul alphabet is a different block,
273 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
274 | # space-separated words, so they are not treated specially and handled
275 |     # like all of the other languages.
276 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
277 | (cp >= 0x3400 and cp <= 0x4DBF) or #
278 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
279 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
280 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
281 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
282 | (cp >= 0xF900 and cp <= 0xFAFF) or #
283 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
284 | return True
285 |
286 | return False
287 |
288 | def _clean_text(self, text):
289 | """Performs invalid character removal and whitespace cleanup on text."""
290 | output = []
291 | for char in text:
292 | cp = ord(char)
293 | if cp == 0 or cp == 0xfffd or _is_control(char):
294 | continue
295 | if _is_whitespace(char):
296 | output.append(" ")
297 | else:
298 | output.append(char)
299 | return "".join(output)
300 |
301 |
302 | class WordpieceTokenizer(object):
303 | """Runs WordPiece tokenziation."""
304 |
305 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
306 | self.vocab = vocab
307 | self.unk_token = unk_token
308 | self.max_input_chars_per_word = max_input_chars_per_word
309 |
310 | def tokenize(self, text):
311 | """Tokenizes a piece of text into its word pieces.
312 |
313 | This uses a greedy longest-match-first algorithm to perform tokenization
314 | using the given vocabulary.
315 |
316 | For example:
317 | input = "unaffable"
318 | output = ["un", "##aff", "##able"]
319 |
320 | Args:
321 | text: A single token or whitespace separated tokens. This should have
322 |         already been passed through `BasicTokenizer`.
323 |
324 | Returns:
325 | A list of wordpiece tokens.
326 | """
327 |
328 | text = convert_to_unicode(text)
329 |
330 | output_tokens = []
331 | for token in whitespace_tokenize(text):
332 | chars = list(token)
333 | if len(chars) > self.max_input_chars_per_word:
334 | output_tokens.append(self.unk_token)
335 | continue
336 |
337 | is_bad = False
338 | start = 0
339 | sub_tokens = []
340 | while start < len(chars):
341 | end = len(chars)
342 | cur_substr = None
343 | while start < end:
344 | substr = "".join(chars[start:end])
345 | if start > 0:
346 | substr = "##" + substr
347 | if substr in self.vocab:
348 | cur_substr = substr
349 | break
350 | end -= 1
351 | if cur_substr is None:
352 | is_bad = True
353 | break
354 | sub_tokens.append(cur_substr)
355 | start = end
356 |
357 | if is_bad:
358 | output_tokens.append(self.unk_token)
359 | else:
360 | output_tokens.extend(sub_tokens)
361 | return output_tokens
362 |
363 |
364 | def _is_whitespace(char):
365 | """Checks whether `chars` is a whitespace character."""
366 |   # \t, \n, and \r are technically control characters but we treat them
367 | # as whitespace since they are generally considered as such.
368 | if char == " " or char == "\t" or char == "\n" or char == "\r":
369 | return True
370 | cat = unicodedata.category(char)
371 | if cat == "Zs":
372 | return True
373 | return False
374 |
375 |
376 | def _is_control(char):
377 | """Checks whether `chars` is a control character."""
378 | # These are technically control characters but we count them as whitespace
379 | # characters.
380 | if char == "\t" or char == "\n" or char == "\r":
381 | return False
382 | cat = unicodedata.category(char)
383 | if cat in ("Cc", "Cf"):
384 | return True
385 | return False
386 |
387 |
388 | def _is_punctuation(char):
389 | """Checks whether `chars` is a punctuation character."""
390 | cp = ord(char)
391 | # We treat all non-letter/number ASCII as punctuation.
392 | # Characters such as "^", "$", and "`" are not in the Unicode
393 | # Punctuation class but we treat them as punctuation anyways, for
394 | # consistency.
395 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
396 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
397 | return True
398 | cat = unicodedata.category(char)
399 | if cat.startswith("P"):
400 | return True
401 | return False
402 |
--------------------------------------------------------------------------------
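
A minimal sketch with a toy vocabulary, illustrative only, of the greedy longest-match-first loop in `WordpieceTokenizer.tokenize` above: at each step the longest prefix found in the vocabulary (prefixed with "##" after the first piece) is taken, and the whole token falls back to `[UNK]` if no prefix matches.

    def wordpiece(token, vocab, unk_token="[UNK]"):
        # Greedy longest-match-first, as in WordpieceTokenizer.tokenize.
        chars = list(token)
        sub_tokens, start = [], 0
        while start < len(chars):
            end, cur_substr = len(chars), None
            while start < end:
                substr = "".join(chars[start:end])
                if start > 0:
                    substr = "##" + substr
                if substr in vocab:
                    cur_substr = substr
                    break
                end -= 1
            if cur_substr is None:
                return [unk_token]      # no prefix matched: whole token is unknown
            sub_tokens.append(cur_substr)
            start = end
        return sub_tokens

    toy_vocab = {"un", "##aff", "##able", "aff"}
    print(wordpiece("unaffable", toy_vocab))   # ['un', '##aff', '##able']
    print(wordpiece("xyz", toy_vocab))         # ['[UNK]']
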