├── README.md ├── albert_config ├── albert_config_base.json ├── albert_config_base_google_fast.json ├── albert_config_large.json ├── albert_config_small_google.json ├── albert_config_tiny.json ├── albert_config_tiny_google.json ├── albert_config_tiny_google_fast.json ├── albert_config_xlarge.json ├── albert_config_xxlarge.json ├── bert_config.json └── vocab.txt ├── args.py ├── bert_utils.py ├── classifier_utils.py ├── create_pretrain_data.sh ├── create_pretraining_data.py ├── create_pretraining_data_google.py ├── data └── news_zh_1.txt ├── lamb_optimizer_google.py ├── modeling.py ├── modeling_google.py ├── modeling_google_fast.py ├── optimization.py ├── optimization_finetuning.py ├── optimization_google.py ├── resources ├── add_data_removing_dropout.jpg ├── albert_configuration.jpg ├── albert_large_zh_parameters.jpg ├── albert_performance.jpg ├── albert_tiny_compare_s.jpg ├── albert_tiny_compare_s_old.jpg ├── create_pretraining_data_roberta.py ├── crmc2018_compare_s.jpg ├── shell_scripts │ └── create_pretrain_data_batch_webtext.sh ├── state_of_the_art.jpg └── xlarge_loss.jpg ├── run_classifier.py ├── run_classifier_clue.py ├── run_classifier_clue.sh ├── run_classifier_lcqmc.sh ├── run_classifier_sp_google.py ├── run_pretraining.py ├── run_pretraining_google.py ├── run_pretraining_google_fast.py ├── similarity.py ├── test_changes.py ├── tokenization.py └── tokenization_google.py /README.md: -------------------------------------------------------------------------------- 1 | # albert_zh 2 | 3 | An Implementation of A Lite Bert For Self-Supervised Learning Language Representations with TensorFlow 4 | 5 | ALBert is based on Bert, but with some improvements. It achieves state of the art performance on main benchmarks with 30% parameters less. 6 | 7 | For albert_base_zh it only has ten percentage parameters compare of original bert model, and main accuracy is retained. 8 | 9 | 10 | Different version of ALBERT pre-trained model for Chinese, including TensorFlow, PyTorch and Keras, is available now. 11 | 12 | 海量中文语料上预训练ALBERT模型:参数更少,效果更好。预训练小模型也能拿下13项NLP任务,ALBERT三大改造登顶GLUE基准 13 | 14 | clueai工具包: 三行代码,三分钟定制一个NLP的API(零样本学习) 15 | 16 | 17 | 18 | 一键运行10个数据集、9个基线模型、不同任务上模型效果的详细对比,见CLUE benchmark 19 | 20 | 一键运行CLUE中文任务:6个中文分类或句子对任务(新) 21 | --------------------------------------------------------------------- 22 | 使用方式: 23 | 1、克隆项目 24 | git clone https://github.com/brightmart/albert_zh.git 25 | 2、运行一键运行脚本(GPU方式): 会自动下载模型和所有任务数据并开始运行。 26 | bash run_classifier_clue.sh 27 | 执行该一键运行脚本将会自动下载所有任务数据,并为所有任务找到最优模型,然后测试得到提交结果 28 | 29 | 30 | 模型下载 Download Pre-trained Models of Chinese 31 | ----------------------------------------------- 32 | 1、albert_tiny_zh, albert_tiny_zh(训练更久,累积学习20亿个样本),文件大小16M、参数为4M 33 | 34 | 训练和推理预测速度提升约10倍,精度基本保留,模型大小为bert的1/25;语义相似度数据集LCQMC测试集上达到85.4%,相比bert_base仅下降1.5个点。 35 | 36 | lcqmc训练使用如下参数: --max_seq_length=128 --train_batch_size=64 --learning_rate=1e-4 --num_train_epochs=5 37 | 38 | albert_tiny使用同样的大规模中文语料数据,层数仅为4层、hidden size等向量维度大幅减少; 尝试使用如下学习率来获得更好效果:{2e-5, 6e-5, 1e-4} 39 | 40 | 【使用场景】任务相对比较简单一些或实时性要求高的任务,如语义相似度等句子对任务、分类任务;比较难的任务如阅读理解等,可以使用其他大模型。 41 | 42 | 例如,可以使用[Tensorflow Lite](https://www.tensorflow.org/lite)在移动端进行部署,本文[随后](#use_tflite)针对这一点进行了介绍,包括如何把模型转换成Tensorflow Lite格式和对其进行性能测试等。 43 | 44 | 一键运行albert_tiny_zh(linux,lcqmc任务): 45 | 1) git clone https://github.com/brightmart/albert_zh 46 | 2) cd albert_zh 47 | 3) bash run_classifier_lcqmc.sh 48 | 1.1、albert_tiny_google_zh(累积学习10亿个样本,google版本),模型大小16M、性能与albert_tiny_zh一致 49 | 50 | 1.2、albert_small_google_zh(累积学习10亿个样本,google版本), 51 | 52 | 速度比bert_base快4倍;LCQMC测试集上比Bert下降仅0.9个点;去掉adam后模型大小18.5M;使用方法,见 #下游任务 Fine-tuning on Downstream Task 53 | 54 | 2、albert_large_zh,参数量,层数24,文件大小为64M 55 | 56 | 参数量和模型大小为bert_base的六分之一;在口语化描述相似性数据集LCQMC的测试集上相比bert_base上升0.2个点 57 | 58 | 3、albert_base_zh(额外训练了1.5亿个实例即 36k steps * batch_size 4096); albert_base_zh(小模型体验版), 参数量12M, 层数12,大小为40M 59 | 60 | 参数量为bert_base的十分之一,模型大小也十分之一;在口语化描述相似性数据集LCQMC的测试集上相比bert_base下降约0.6~1个点; 61 | 相比未预训练,albert_base提升14个点 62 | 63 | 4、albert_xlarge_zh_177k ; 64 | albert_xlarge_zh_183k(优先尝试)参数量,层数24,文件大小为230M 65 | 66 | 参数量和模型大小为bert_base的二分之一;需要一张大的显卡;完整测试对比将后续添加;batch_size不能太小,否则可能影响精度 67 | 68 | ### 快速加载 69 | 依托于[Huggingface-Transformers 2.2.2](https://github.com/huggingface/transformers),可轻松调用以上模型。 70 | ``` 71 | tokenizer = AutoTokenizer.from_pretrained("MODEL_NAME") 72 | model = AutoModel.from_pretrained("MODEL_NAME") 73 | ``` 74 | 75 | 其中`MODEL_NAME`对应列表如下: 76 | 77 | | 模型名 | MODEL_NAME | 78 | | - | - | 79 | | albert_tiny_google_zh | voidful/albert_chinese_tiny | 80 | | albert_small_google_zh | voidful/albert_chinese_small | 81 | | albert_base_zh (from google) | voidful/albert_chinese_base | 82 | | albert_large_zh (from google) | voidful/albert_chinese_large | 83 | | albert_xlarge_zh (from google) | voidful/albert_chinese_xlarge | 84 | | albert_xxlarge_zh (from google) | voidful/albert_chinese_xxlarge | 85 | 86 | 更多通过transformers使用albert的示例 87 | 88 | 预训练 Pre-training 89 | ----------------------------------------------- 90 | 91 | #### 生成特定格式的文件(tfrecords) Generate tfrecords Files 92 | 93 | Run following command 运行以下命令即可。项目自动了一个示例的文本文件(data/news_zh_1.txt) 94 | 95 | bash create_pretrain_data.sh 96 | 97 | 如果你有很多文本文件,可以通过传入参数的方式,生成多个特定格式的文件(tfrecords) 98 | 99 | ###### Support English and Other Non-Chinese Language: 100 | If you are doing pre-train for english or other language,which is not chinese, 101 | you should set hyperparameter of non_chinese to True on create_pretraining_data.py; 102 | otherwise, by default it is doing chinese pre-train using whole word mask of chinese. 103 | 104 | #### 执行预训练 pre-training on GPU/TPU using the command 105 | GPU(brightmart版, tiny模型): 106 | export BERT_BASE_DIR=./albert_tiny_zh 107 | nohup python3 run_pretraining.py --input_file=./data/tf*.tfrecord \ 108 | --output_dir=./my_new_model_path --do_train=True --do_eval=True --bert_config_file=$BERT_BASE_DIR/albert_config_tiny.json \ 109 | --train_batch_size=4096 --max_seq_length=512 --max_predictions_per_seq=51 \ 110 | --num_train_steps=125000 --num_warmup_steps=12500 --learning_rate=0.00176 \ 111 | --save_checkpoints_steps=2000 --init_checkpoint=$BERT_BASE_DIR/albert_model.ckpt & 112 | 113 | GPU(Google版本, small模型): 114 | export BERT_BASE_DIR=./albert_small_zh_google 115 | nohup python3 run_pretraining_google.py --input_file=./data/tf*.tfrecord --eval_batch_size=64 \ 116 | --output_dir=./my_new_model_path --do_train=True --do_eval=True --albert_config_file=$BERT_BASE_DIR/albert_config_small_google.json --export_dir=./my_new_model_path_export \ 117 | --train_batch_size=4096 --max_seq_length=512 --max_predictions_per_seq=20 \ 118 | --num_train_steps=125000 --num_warmup_steps=12500 --learning_rate=0.00176 \ 119 | --save_checkpoints_steps=2000 --init_checkpoint=$BERT_BASE_DIR/albert_model.ckpt 120 | 121 | TPU, add something like this: 122 | --use_tpu=True --tpu_name=grpc://10.240.1.66:8470 --tpu_zone=us-central1-a 123 | 124 | 注:如果你重头开始训练,可以不指定init_checkpoint; 125 | 如果你从现有的模型基础上训练,指定一下BERT_BASE_DIR的路径,并确保bert_config_file和init_checkpoint两个参数的值能对应到相应的文件上; 126 | 领域上的预训练,根据数据的大小,可以不用训练特别久。 127 | 128 | 环境 Environment 129 | ----------------------------------------------- 130 | Use Python3 + Tensorflow 1.x 131 | 132 | e.g. Tensorflow 1.4 or 1.5 133 | 134 | 135 | 下游任务 Fine-tuning on Downstream Task 136 | ----------------------------------------------- 137 | ##### 使用TensorFlow: 138 | 139 | 以使用albert_base做LCQMC任务为例。LCQMC任务是在口语化描述的数据集上做文本的相似性预测。 140 | 141 | We will use LCQMC dataset for fine-tuning, it is oral language corpus, it is used to train and predict semantic similarity of a pair of sentences. 142 | 143 | 下载LCQMC数据集,包含训练、验证和测试集,训练集包含24万口语化描述的中文句子对,标签为1或0。1为句子语义相似,0为语义不相似。 144 | 145 | 通过运行下列命令做LCQMC数据集上的fine-tuning: 146 | 147 | 1. Clone this project: 148 | 149 | git clone https://github.com/brightmart/albert_zh.git 150 | 151 | 2. Fine-tuning by running the following command. 152 | brightmart版本的tiny模型 153 | export BERT_BASE_DIR=./albert_tiny_zh 154 | export TEXT_DIR=./lcqmc 155 | nohup python3 run_classifier.py --task_name=lcqmc_pair --do_train=true --do_eval=true --data_dir=$TEXT_DIR --vocab_file=./albert_config/vocab.txt \ 156 | --bert_config_file=./albert_config/albert_config_tiny.json --max_seq_length=128 --train_batch_size=64 --learning_rate=1e-4 --num_train_epochs=5 \ 157 | --output_dir=./albert_lcqmc_checkpoints --init_checkpoint=$BERT_BASE_DIR/albert_model.ckpt & 158 | 159 | google版本的small模型 160 | export BERT_BASE_DIR=./albert_small_zh 161 | export TEXT_DIR=./lcqmc 162 | nohup python3 run_classifier_sp_google.py --task_name=lcqmc_pair --do_train=true --do_eval=true --data_dir=$TEXT_DIR --vocab_file=./albert_config/vocab.txt \ 163 | --albert_config_file=./$BERT_BASE_DIR/albert_config_small_google.json --max_seq_length=128 --train_batch_size=64 --learning_rate=1e-4 --num_train_epochs=5 \ 164 | --output_dir=./albert_lcqmc_checkpoints --init_checkpoint=$BERT_BASE_DIR/albert_model.ckpt & 165 | 166 | Notice/注: 167 | 1) you need to download pre-trained chinese albert model, and also download LCQMC dataset 168 | 你需要下载预训练的模型,并放入到项目当前项目,假设目录名称为albert_tiny_zh; 需要下载LCQMC数据集,并放入到当前项目, 169 | 假设数据集目录名称为lcqmc 170 | 171 | 2) for Fine-tuning, you can try to add small percentage of dropout(e.g. 0.1) by changing parameters of 172 | attention_probs_dropout_prob & hidden_dropout_prob on albert_config_xxx.json. By default, we set dropout as zero. 173 | 174 | 3) you can try different learning rate {2e-5, 6e-5, 1e-4} for better performance 175 | 176 | 177 | Updates 178 | ----------------------------------------------- 179 | **\*\*\*\*\* 2019-11-03: add google version of albert_small, albert_tiny; 180 | 181 | add method to deploy ablert_tiny to mobile devices with only 0.1 second inference time for sequence length 128, 60M memory \*\*\*\*\*** 182 | 183 | **\*\*\*\*\* 2019-10-30: add a simple guide about converting the model to Tensorflow Lite for edge deployment \*\*\*\*\*** 184 | 185 | **\*\*\*\*\* 2019-10-15: albert_tiny_zh, 10 times fast than bert base for training and inference, accuracy remains \*\*\*\*\*** 186 | 187 | **\*\*\*\*\* 2019-10-07: more models of albert \*\*\*\*\*** 188 | 189 | add albert_xlarge_zh; albert_base_zh_additional_steps, training with more instances 190 | 191 | **\*\*\*\*\* 2019-10-04: PyTorch and Keras versions of albert were supported \*\*\*\*\*** 192 | 193 | a.Convert to PyTorch version and do your tasks through albert_pytorch 194 | 195 | b.Load pre-trained model with keras using one line of codes through bert4keras 196 | 197 | c.Use albert with TensorFlow 2.0: Use or load pre-trained model with tf2.0 through bert-for-tf2 198 | 199 | Releasing albert_xlarge on 6th Oct 200 | 201 | **\*\*\*\*\* 2019-10-02: albert_large_zh,albert_base_zh \*\*\*\*\*** 202 | 203 | Relesed albert_base_zh with only 10% parameters of bert_base, a small model(40M) & training can be very fast. 204 | 205 | Relased albert_large_zh with only 16% parameters of bert_base(64M) 206 | 207 | **\*\*\*\*\* 2019-09-28: codes and test functions \*\*\*\*\*** 208 | 209 | Add codes and test functions for three main changes of albert from bert 210 | 211 | ALBERT模型介绍 Introduction of ALBERT 212 | ----------------------------------------------- 213 | ALBERT模型是BERT的改进版,与最近其他State of the art的模型不同的是,这次是预训练小模型,效果更好、参数更少。 214 | 215 | 它对BERT进行了三个改造 Three main changes of ALBert from Bert: 216 | 217 | 1)词嵌入向量参数的因式分解 Factorized embedding parameterization 218 | 219 | O(V * H) to O(V * E + E * H) 220 | 221 | 如以ALBert_xxlarge为例,V=30000, H=4096, E=128 222 | 223 | 那么原先参数为V * H= 30000 * 4096 = 1.23亿个参数,现在则为V * E + E * H = 30000*128+128*4096 = 384万 + 52万 = 436万, 224 | 225 | 词嵌入相关的参数变化前是变换后的28倍。 226 | 227 | 228 | 2)跨层参数共享 Cross-Layer Parameter Sharing 229 | 230 | 参数共享能显著减少参数。共享可以分为全连接层、注意力层的参数共享;注意力层的参数对效果的减弱影响小一点。 231 | 232 | 3)段落连续性任务 Inter-sentence coherence loss. 233 | 234 | 使用段落连续性任务。正例,使用从一个文档中连续的两个文本段落;负例,使用从一个文档中连续的两个文本段落,但位置调换了。 235 | 236 | 避免使用原有的NSP任务,原有的任务包含隐含了预测主题这类过于简单的任务。 237 | 238 | We maintain that inter-sentence modeling is an important aspect of language understanding, but we propose a loss 239 | based primarily on coherence. That is, for ALBERT, we use a sentence-order prediction (SOP) loss, which avoids topic 240 | prediction and instead focuses on modeling inter-sentence coherence. The SOP loss uses as positive examples the 241 | same technique as BERT (two consecutive segments from the same document), and as negative examples the same two 242 | consecutive segments but with their order swapped. This forces the model to learn finer-grained distinctions about 243 | discourse-level coherence properties. 244 | 245 | 其他变化,还有 Other changes: 246 | 247 | 1)去掉了dropout Remove dropout to enlarge capacity of model. 248 | 最大的模型,训练了1百万步后,还是没有过拟合训练数据。说明模型的容量还可以更大,就移除了dropout 249 | (dropout可以认为是随机的去掉网络中的一部分,同时使网络变小一些) 250 | We also note that, even after training for 1M steps, our largest models still do not overfit to their training data. 251 | As a result, we decide to remove dropout to further increase our model capacity. 252 | 其他型号的模型,在我们的实现中我们还是会保留原始的dropout的比例,防止模型对训练数据的过拟合。 253 | 254 | 2)为加快训练速度,使用LAMB做为优化器 Use LAMB as optimizer, to train with big batch size 255 | 使用了大的batch_size来训练(4096)。 LAMB优化器使得我们可以训练,特别大的批次batch_size,如高达6万。 256 | 257 | 3)使用n-gram(uni-gram,bi-gram, tri-gram)来做遮蔽语言模型 Use n-gram as make language model 258 | 即以不同的概率使用n-gram,uni-gram的概率最大,bi-gram其次,tri-gram概率最小。 259 | 本项目中目前使用的是在中文上做whole word mask,稍后会更新一下与n-gram mask的效果对比。n-gram从spanBERT中来。 260 | 261 | 262 | 训练语料/训练配置 Training Data & Configuration 263 | ----------------------------------------------- 264 | 30g中文语料,超过100亿汉字,包括多个百科、新闻、互动社区。 265 | 266 | 预训练序列长度sequence_length设置为512,批次batch_size为4096,训练产生了3.5亿个训练数据(instance);每一个模型默认会训练125k步,albert_xxlarge将训练更久。 267 | 268 | 作为比较,roberta_zh预训练产生了2.5亿个训练数据、序列长度为256。由于albert_zh预训练生成的训练数据更多、使用的序列长度更长, 269 | 270 | 我们预计albert_zh会有比roberta_zh更好的性能表现,并且能更好处理较长的文本。 271 | 272 | 训练使用TPU v3 Pod,我们使用的是v3-256,它包含32个v3-8。每个v3-8机器,含有128G的显存。 273 | 274 | 275 | 模型性能与对比(英文) Performance and Comparision 276 | ----------------------------------------------- 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 中文任务集上效果对比测试 Performance on Chinese datasets 287 | ----------------------------------------------- 288 | 289 | ### 问题匹配语任务:LCQMC(Sentence Pair Matching) 290 | 291 | | 模型 | 开发集(Dev) | 测试集(Test) | 292 | | :------- | :---------: | :---------: | 293 | | BERT | 89.4(88.4) | 86.9(86.4) | 294 | | ERNIE | 89.8 (89.6) | 87.2 (87.0) | 295 | | BERT-wwm |89.4 (89.2) | 87.0 (86.8) | 296 | | BERT-wwm-ext | - |- | 297 | | RoBERTa-zh-base | 88.7 | 87.0 | 298 | | RoBERTa-zh-Large | ***89.9(89.6)*** | 87.2(86.7) | 299 | | RoBERTa-zh-Large(20w_steps) | 89.7| 87.0 | 300 | | ALBERT-zh-tiny | -- | 85.4 | 301 | | ALBERT-zh-small | -- | 86.0 | 302 | | ALBERT-zh-small(Pytorch) | -- | 86.8 | 303 | | ALBERT-zh-base-additional-36k-steps | 87.8 | 86.3 | 304 | | ALBERT-zh-base | 87.2 | 86.3 | 305 | | ALBERT-large | 88.7 | 87.1 | 306 | | ALBERT-xlarge | 87.3 | ***87.7*** | 307 | 308 | 注:只跑了一次ALBERT-xlarge,效果还可能提升 309 | 310 | ### 自然语言推断:XNLI of Chinese Version 311 | 312 | | 模型 | 开发集 | 测试集 | 313 | | :------- | :---------: | :---------: | 314 | | BERT | 77.8 (77.4) | 77.8 (77.5) | 315 | | ERNIE | 79.7 (79.4) | 78.6 (78.2) | 316 | | BERT-wwm | 79.0 (78.4) | 78.2 (78.0) | 317 | | BERT-wwm-ext | 79.4 (78.6) | 78.7 (78.3) | 318 | | XLNet | 79.2 | 78.7 | 319 | | RoBERTa-zh-base | 79.8 |78.8 | 320 | | RoBERTa-zh-Large | 80.2 (80.0) | 79.9 (79.5) | 321 | | ALBERT-base | 77.0 | 77.1 | 322 | | ALBERT-large | 78.0 | 77.5 | 323 | | ALBERT-xlarge | ? | ? | 324 | 325 | 注:BERT-wwm-ext来自于这里;XLNet来自于这里; RoBERTa-zh-base,指12层RoBERTa中文模型 326 | 327 | 328 | ### 阅读理解任务:CRMC2018 329 | 330 | 331 | 332 | 333 | ### 语言模型、文本段预测准确性、训练时间 Mask Language Model Accuarcy & Training Time 334 | 335 | | Model | MLM eval acc | SOP eval acc | Training(Hours) | Loss eval | 336 | | :------- | :---------: | :---------: | :---------: |:---------: | 337 | | albert_zh_base | 79.1% | 99.0% | 6h | 1.01| 338 | | albert_zh_large | 80.9% | 98.6% | 22.5h | 0.93| 339 | | albert_zh_xlarge | ? | ? | 53h(预估) | ? | 340 | | albert_zh_xxlarge | ? | ? | 106h(预估) | ? | 341 | 342 | 注:? 将很快替换 343 | 344 | 模型参数和配置 Configuration of Models 345 | ----------------------------------------------- 346 | 347 | 348 | 代码实现和测试 Implementation and Code Testing 349 | ----------------------------------------------- 350 | 通过运行以下命令测试主要的改进点,包括但不限于词嵌入向量参数的因式分解、跨层参数共享、段落连续性任务等。 351 | 352 | python test_changes.py 353 | 354 | ##### 使用TensorFlow Lite(TFLite)在移动端进行部署: 355 | 这里我们主要介绍TFLite模型格式转换和性能测试。转换成TFLite模型后,对于如何在移 356 | 动端使用该模型,可以参考TFLite提供的[Android/iOS应用完整开发案例教程页面](https://www.tensorflow.org/lite/examples)。 357 | 该页面目前已经包含了[文本分类](https://github.com/tensorflow/examples/blob/master/lite/examples/text_classification/android), 358 | [文本问答](https://github.com/tensorflow/examples/blob/master/lite/examples/bert_qa/android)两个Android案例。 359 | 360 | 下面以albert_tiny_zh 361 | 为例来介绍TFLite模型格式转换和性能测试: 362 | 363 | 1. Freeze graph from the checkpoint 364 | 365 | Ensure to have >=1.14 1.x installed to use the freeze_graph tool as it is removed from 2.x distribution 366 | 367 | pip install tensorflow==1.15 368 | 369 | freeze_graph --input_checkpoint=./albert_model.ckpt \ 370 | --output_graph=/tmp/albert_tiny_zh.pb \ 371 | --output_node_names=cls/predictions/truediv \ 372 | --checkpoint_version=1 --input_meta_graph=./albert_model.ckpt.meta --input_binary=true 373 | 374 | 2. Convert to TFLite format 375 | 376 | We are going to use the new experimental tf->tflite converter that's distributed with the Tensorflow nightly build. 377 | 378 | pip install tf-nightly 379 | 380 | tflite_convert --graph_def_file=/tmp/albert_tiny_zh.pb \ 381 | --input_arrays='input_ids,input_mask,segment_ids,masked_lm_positions,masked_lm_ids,masked_lm_weights' \ 382 | --output_arrays='cls/predictions/truediv' \ 383 | --input_shapes=1,128:1,128:128:1,128:1,128:1,128 \ 384 | --output_file=/tmp/albert_tiny_zh.tflite \ 385 | --enable_v1_converter --experimental_new_converter 386 | 387 | 3. Benchmark the performance of the TFLite model 388 | 389 | See [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark) 390 | for details about the performance benchmark tools in TFLite. For example: after 391 | building the benchmark tool binary for an Android phone, do the following to 392 | get an idea of how the TFLite model performs on the phone 393 | 394 | adb push /tmp/albert_tiny_zh.tflite /data/local/tmp/ 395 | adb shell /data/local/tmp/benchmark_model_performance_options --graph=/data/local/tmp/albert_tiny_zh.tflite --perf_options_list=cpu 396 | 397 | On an Android phone w/ Qualcomm's SD845 SoC, via the above benchmark tool, as 398 | of 2019/11/01, the inference latency is ~120ms w/ this converted TFLite model 399 | using 4 threads on CPU, and the memory usage is ~60MB for the model during 400 | inference. Note the performance will improve further with future TFLite 401 | implementation optimizations. 402 | 403 | ##### 使用PyTorch版本: 404 | 405 | download pre-trained model, and convert to PyTorch using: 406 | 407 | python convert_albert_tf_checkpoint_to_pytorch.py 408 | 409 | using albert_pytorch 410 | 411 | ##### 使用Keras加载: 412 | 413 | bert4keras 适配albert,能成功加载albert_zh的权重,只需要在load_pretrained_model函数里加上albert=True 414 | 415 | load pre-trained model with bert4keras 416 | 417 | ##### 使用tf2.0加载: 418 | 419 | bert-for-tf2 420 | 421 | 422 | 使用案例-基于用户输入预测文本相似性 Use Case-Text Similarity Based on User Input 423 | ------------------------------------------------- 424 | 425 | 功能说明:用户可以通过本例了解如何加载训训练集实现基于用户输入的短文本相似度判断。可以基于该代码将程序灵活地拓展为后台服务或增加文本分类等示例。 426 | 427 | 涉及代码:similarity.py、args.py 428 | 429 | 步骤: 430 | 431 | 1、使用本模型进行文本相似性训练,保存模型文件至相应目录下 432 | 433 | 2、根据实际情况,修改args.py中的参数,参数说明如下: 434 | 435 | ```python 436 | #模型目录,存放ckpt文件 437 | model_dir = os.path.join(file_path, 'albert_lcqmc_checkpoints/') 438 | 439 | #config文件,存放模型的json文件 440 | config_name = os.path.join(file_path, 'albert_config/albert_config_tiny.json') 441 | 442 | #ckpt文件名称 443 | ckpt_name = os.path.join(model_dir, 'model.ckpt') 444 | 445 | #输出文件目录,训练时的模型输出目录 446 | output_dir = os.path.join(file_path, 'albert_lcqmc_checkpoints/') 447 | 448 | #vocab文件目录 449 | vocab_file = os.path.join(file_path, 'albert_config/vocab.txt') 450 | 451 | #数据目录,训练使用的数据集存放目录 452 | data_dir = os.path.join(file_path, 'data/') 453 | ``` 454 | 455 | 本例中的文件结构为: 456 | 457 | |__args.py 458 | 459 | |__similarity.py 460 | 461 | |__data 462 | 463 | |__albert_config 464 | 465 | |__albert_lcqmc_checkpoints 466 | 467 | |__lcqmc 468 | 469 | 3、修改用户输入单词 470 | 471 | 打开similarity.py,最底部如下代码: 472 | 473 | ```python 474 | if __name__ == '__main__': 475 | sim = BertSim() 476 | sim.start_model() 477 | sim.predict_sentences([("我喜欢妈妈做的汤", "妈妈做的汤我很喜欢喝")]) 478 | ``` 479 | 480 | 其中sim.start_model()表示加载模型,sim.predict_sentences的输入为一个元组数组,元组中包含两个元素分别为需要判定相似的句子。 481 | 482 | 4、运行python文件:similarity.py 483 | 484 | 485 | 支持的序列长度与批次大小的关系,12G显存 Trade off between batch Size and sequence length 486 | ------------------------------------------------- 487 | 488 | System | Seq Length | Max Batch Size 489 | ------------ | ---------- | -------------- 490 | `albert-base` | 64 | 64 491 | ... | 128 | 32 492 | ... | 256 | 16 493 | ... | 320 | 14 494 | ... | 384 | 12 495 | ... | 512 | 6 496 | `albert-large` | 64 | 12 497 | ... | 128 | 6 498 | ... | 256 | 2 499 | ... | 320 | 1 500 | ... | 384 | 0 501 | ... | 512 | 0 502 | `albert-xlarge` | - | - 503 | 504 | 学习曲线 Training Loss of xlarge of albert_zh 505 | ------------------------------------------------- 506 | 507 | 508 | 所有的参数 Parameters of albert_xlarge 509 | ------------------------------------------------- 510 | 511 | 512 | 513 | #### 技术交流与问题讨论QQ群: 836811304 Join us on QQ group 514 | 515 | If you have any question, you can raise an issue, or send me an email: brightmart@hotmail.com; 516 | 517 | Currently how to use PyTorch version of albert is not clear yet, if you know how to do that, just email us or open an issue. 518 | 519 | You can also send pull request to report you performance on your task or add methods on how to load models for PyTorch and so on. 520 | 521 | If you have ideas for generate best performance pre-training Chinese model, please also let me know. 522 | 523 | ##### Research supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC) 524 | 525 | Cite Us 526 | ----------------------------------------------- 527 | Bright Liang Xu, albert_zh, (2019), GitHub repository, https://github.com/brightmart/albert_zh 528 | 529 | Reference 530 | ----------------------------------------------- 531 | 1、ALBERT: A Lite BERT For Self-Supervised Learning Of Language Representations 532 | 533 | 2、BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 534 | 535 | 3、SpanBERT: Improving Pre-training by Representing and Predicting Spans 536 | 537 | 4、RoBERTa: A Robustly Optimized BERT Pretraining Approach 538 | 539 | 5、Large Batch Optimization for Deep Learning: Training BERT in 76 minutes(LAMB) 540 | 541 | 6、LAMB Optimizer,TensorFlow version 542 | 543 | 7、预训练小模型也能拿下13项NLP任务,ALBERT三大改造登顶GLUE基准 544 | 545 | 8、 albert_pytorch 546 | 547 | 9、load albert with keras 548 | 549 | 10、load albert with tf2.0 550 | 551 | 11、repo of albert from google 552 | 553 | 12、chineseGLUE-中文任务基准测评:公开可用多个任务、基线模型、广泛测评与效果对比 554 | 555 | 556 | 557 | 558 | -------------------------------------------------------------------------------- /albert_config/albert_config_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.0, 6 | "hidden_size": 768, 7 | "embedding_size": 128, 8 | "initializer_range": 0.02, 9 | "intermediate_size": 3072 , 10 | "max_position_embeddings": 512, 11 | "num_attention_heads": 12, 12 | "num_hidden_layers": 12, 13 | 14 | "pooler_fc_size": 768, 15 | "pooler_num_attention_heads": 12, 16 | "pooler_num_fc_layers": 3, 17 | "pooler_size_per_head": 128, 18 | "pooler_type": "first_token_transform", 19 | "type_vocab_size": 2, 20 | "vocab_size": 21128, 21 | "ln_type":"postln" 22 | 23 | } 24 | -------------------------------------------------------------------------------- /albert_config/albert_config_base_google_fast.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "embedding_size": 128, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "num_hidden_groups": 12, 13 | "net_structure_type": 0, 14 | "gap_size": 0, 15 | "num_memory_blocks": 0, 16 | "inner_group_num": 1, 17 | "down_scale_factor": 1, 18 | "type_vocab_size": 2, 19 | "vocab_size": 21128 20 | } -------------------------------------------------------------------------------- /albert_config/albert_config_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.0, 6 | "hidden_size": 1024, 7 | "embedding_size": 128, 8 | "initializer_range": 0.02, 9 | "intermediate_size": 4096, 10 | "max_position_embeddings": 512, 11 | "num_attention_heads": 16, 12 | "num_hidden_layers": 24, 13 | 14 | "pooler_fc_size": 768, 15 | "pooler_num_attention_heads": 12, 16 | "pooler_num_fc_layers": 3, 17 | "pooler_size_per_head": 128, 18 | "pooler_type": "first_token_transform", 19 | "type_vocab_size": 2, 20 | "vocab_size": 21128, 21 | "ln_type":"postln" 22 | 23 | } 24 | -------------------------------------------------------------------------------- /albert_config/albert_config_small_google.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.0, 5 | "embedding_size": 128, 6 | "hidden_size": 384, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 1536, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 6, 12 | "num_hidden_groups": 1, 13 | "net_structure_type": 0, 14 | "gap_size": 0, 15 | "num_memory_blocks": 0, 16 | "inner_group_num": 1, 17 | "down_scale_factor": 1, 18 | "type_vocab_size": 2, 19 | "vocab_size": 21128 20 | } -------------------------------------------------------------------------------- /albert_config/albert_config_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.0, 6 | "hidden_size": 312, 7 | "embedding_size": 128, 8 | "initializer_range": 0.02, 9 | "intermediate_size": 1248 , 10 | "max_position_embeddings": 512, 11 | "num_attention_heads": 12, 12 | "num_hidden_layers": 4, 13 | 14 | "pooler_fc_size": 768, 15 | "pooler_num_attention_heads": 12, 16 | "pooler_num_fc_layers": 3, 17 | "pooler_size_per_head": 128, 18 | "pooler_type": "first_token_transform", 19 | "type_vocab_size": 2, 20 | "vocab_size": 21128, 21 | "ln_type":"postln" 22 | 23 | } 24 | -------------------------------------------------------------------------------- /albert_config/albert_config_tiny_google.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.0, 5 | "embedding_size": 128, 6 | "hidden_size": 312, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 1248, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 4, 12 | "num_hidden_groups": 1, 13 | "net_structure_type": 0, 14 | "gap_size": 0, 15 | "num_memory_blocks": 0, 16 | "inner_group_num": 1, 17 | "down_scale_factor": 1, 18 | "type_vocab_size": 2, 19 | "vocab_size": 21128 20 | } 21 | -------------------------------------------------------------------------------- /albert_config/albert_config_tiny_google_fast.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "embedding_size": 128, 6 | "hidden_size": 336, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 1344, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 4, 12 | "num_hidden_groups": 12, 13 | "net_structure_type": 0, 14 | "gap_size": 0, 15 | "num_memory_blocks": 0, 16 | "inner_group_num": 1, 17 | "down_scale_factor": 1, 18 | "type_vocab_size": 2, 19 | "vocab_size": 21128 20 | } -------------------------------------------------------------------------------- /albert_config/albert_config_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.0, 6 | "hidden_size": 2048, 7 | "embedding_size": 128, 8 | "initializer_range": 0.02, 9 | "intermediate_size": 8192, 10 | "max_position_embeddings": 512, 11 | "num_attention_heads": 32, 12 | "num_hidden_layers": 24, 13 | 14 | "pooler_fc_size": 1024, 15 | "pooler_num_attention_heads": 64, 16 | "pooler_num_fc_layers": 3, 17 | "pooler_size_per_head": 128, 18 | "pooler_type": "first_token_transform", 19 | "type_vocab_size": 2, 20 | "vocab_size": 21128, 21 | "ln_type":"postln" 22 | 23 | } 24 | -------------------------------------------------------------------------------- /albert_config/albert_config_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.0, 6 | "hidden_size": 4096, 7 | "embedding_size": 128, 8 | "initializer_range": 0.02, 9 | "intermediate_size": 16384, 10 | "max_position_embeddings": 512, 11 | "num_attention_heads": 64, 12 | "num_hidden_layers": 12, 13 | 14 | "pooler_fc_size": 1024, 15 | "pooler_num_attention_heads": 64, 16 | "pooler_num_fc_layers": 3, 17 | "pooler_size_per_head": 128, 18 | "pooler_type": "first_token_transform", 19 | "type_vocab_size": 2, 20 | "vocab_size": 21128, 21 | "ln_type":"preln" 22 | 23 | } 24 | -------------------------------------------------------------------------------- /albert_config/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.0, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128 19 | } 20 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | tf.logging.set_verbosity(tf.logging.INFO) 5 | 6 | file_path = os.path.dirname(__file__) 7 | 8 | 9 | #模型目录 10 | model_dir = os.path.join(file_path, 'albert_lcqmc_checkpoints/') 11 | 12 | #config文件 13 | config_name = os.path.join(file_path, 'albert_config/albert_config_tiny.json') 14 | #ckpt文件名称 15 | ckpt_name = os.path.join(model_dir, 'model.ckpt') 16 | #输出文件目录 17 | output_dir = os.path.join(file_path, 'albert_lcqmc_checkpoints/') 18 | #vocab文件目录 19 | vocab_file = os.path.join(file_path, 'albert_config/vocab.txt') 20 | #数据目录 21 | data_dir = os.path.join(file_path, 'data/') 22 | 23 | num_train_epochs = 10 24 | batch_size = 128 25 | learning_rate = 0.00005 26 | 27 | # gpu使用率 28 | gpu_memory_fraction = 0.8 29 | 30 | # 默认取倒数第二层的输出值作为句向量 31 | layer_indexes = [-2] 32 | 33 | # 序列的最大程度,单文本建议把该值调小 34 | max_seq_len = 128 35 | 36 | # graph名字 37 | graph_file = os.path.join(file_path, 'albert_lcqmc_checkpoints/graph') -------------------------------------------------------------------------------- /bert_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import copy 7 | import json 8 | import math 9 | import re 10 | import six 11 | import tensorflow as tf 12 | 13 | def get_shape_list(tensor, expected_rank=None, name=None): 14 | """Returns a list of the shape of tensor, preferring static dimensions. 15 | 16 | Args: 17 | tensor: A tf.Tensor object to find the shape of. 18 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 19 | specified and the `tensor` has a different rank, and exception will be 20 | thrown. 21 | name: Optional name of the tensor for the error message. 22 | 23 | Returns: 24 | A list of dimensions of the shape of tensor. All static dimensions will 25 | be returned as python integers, and dynamic dimensions will be returned 26 | as tf.Tensor scalars. 27 | """ 28 | if name is None: 29 | name = tensor.name 30 | 31 | if expected_rank is not None: 32 | assert_rank(tensor, expected_rank, name) 33 | 34 | shape = tensor.shape.as_list() 35 | 36 | non_static_indexes = [] 37 | for (index, dim) in enumerate(shape): 38 | if dim is None: 39 | non_static_indexes.append(index) 40 | 41 | if not non_static_indexes: 42 | return shape 43 | 44 | dyn_shape = tf.shape(tensor) 45 | for index in non_static_indexes: 46 | shape[index] = dyn_shape[index] 47 | return shape 48 | 49 | def reshape_to_matrix(input_tensor): 50 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 51 | ndims = input_tensor.shape.ndims 52 | if ndims < 2: 53 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 54 | (input_tensor.shape)) 55 | if ndims == 2: 56 | return input_tensor 57 | 58 | width = input_tensor.shape[-1] 59 | output_tensor = tf.reshape(input_tensor, [-1, width]) 60 | return output_tensor 61 | 62 | def reshape_from_matrix(output_tensor, orig_shape_list): 63 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 64 | if len(orig_shape_list) == 2: 65 | return output_tensor 66 | 67 | output_shape = get_shape_list(output_tensor) 68 | 69 | orig_dims = orig_shape_list[0:-1] 70 | width = output_shape[-1] 71 | 72 | return tf.reshape(output_tensor, orig_dims + [width]) 73 | 74 | def assert_rank(tensor, expected_rank, name=None): 75 | """Raises an exception if the tensor rank is not of the expected rank. 76 | 77 | Args: 78 | tensor: A tf.Tensor to check the rank of. 79 | expected_rank: Python integer or list of integers, expected rank. 80 | name: Optional name of the tensor for the error message. 81 | 82 | Raises: 83 | ValueError: If the expected shape doesn't match the actual shape. 84 | """ 85 | if name is None: 86 | name = tensor.name 87 | 88 | expected_rank_dict = {} 89 | if isinstance(expected_rank, six.integer_types): 90 | expected_rank_dict[expected_rank] = True 91 | else: 92 | for x in expected_rank: 93 | expected_rank_dict[x] = True 94 | 95 | actual_rank = tensor.shape.ndims 96 | if actual_rank not in expected_rank_dict: 97 | scope_name = tf.get_variable_scope().name 98 | raise ValueError( 99 | "For the tensor `%s` in scope `%s`, the actual rank " 100 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 101 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 102 | 103 | def gather_indexes(sequence_tensor, positions): 104 | """Gathers the vectors at the specific positions over a minibatch.""" 105 | sequence_shape = get_shape_list(sequence_tensor, expected_rank=3) 106 | batch_size = sequence_shape[0] 107 | seq_length = sequence_shape[1] 108 | width = sequence_shape[2] 109 | 110 | flat_offsets = tf.reshape( 111 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 112 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 113 | flat_sequence_tensor = tf.reshape(sequence_tensor, 114 | [batch_size * seq_length, width]) 115 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 116 | return output_tensor 117 | 118 | # add sequence mask for: 119 | # 1. random shuffle lm modeling---xlnet with random shuffled input 120 | # 2. left2right and right2left language modeling 121 | # 3. conditional generation 122 | def generate_seq2seq_mask(attention_mask, mask_sequence, seq_type, **kargs): 123 | if seq_type == 'seq2seq': 124 | if mask_sequence is not None: 125 | seq_shape = get_shape_list(mask_sequence, expected_rank=2) 126 | seq_len = seq_shape[1] 127 | ones = tf.ones((1, seq_len, seq_len)) 128 | a_mask = tf.matrix_band_part(ones, -1, 0) 129 | s_ex12 = tf.expand_dims(tf.expand_dims(mask_sequence, 1), 2) 130 | s_ex13 = tf.expand_dims(tf.expand_dims(mask_sequence, 1), 3) 131 | a_mask = (1 - s_ex13) * (1 - s_ex12) + s_ex13 * a_mask 132 | # generate mask of batch x seq_len x seq_len 133 | a_mask = tf.reshape(a_mask, (-1, seq_len, seq_len)) 134 | out_mask = attention_mask * a_mask 135 | else: 136 | ones = tf.ones_like(attention_mask[:1]) 137 | mask = (tf.matrix_band_part(ones, -1, 0)) 138 | out_mask = attention_mask * mask 139 | else: 140 | out_mask = attention_mask 141 | 142 | return out_mask 143 | 144 | -------------------------------------------------------------------------------- /create_pretrain_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | BERT_BASE_DIR=./albert_config 4 | python3 create_pretraining_data.py --do_whole_word_mask=True --input_file=data/news_zh_1.txt \ 5 | --output_file=data/tf_news_2016_zh_raw_news2016zh_1.tfrecord --vocab_file=$BERT_BASE_DIR/vocab.txt --do_lower_case=True \ 6 | --max_seq_length=512 --max_predictions_per_seq=51 --masked_lm_prob=0.10 -------------------------------------------------------------------------------- /data/news_zh_1.txt: -------------------------------------------------------------------------------- 1 | 最后的南京老城该往何处去 城市化时代呼唤文化自觉 2 | 【概要】80后学者姚远出版《城市的自觉》一书 姚远出版《城市的自觉》 作者简介姚远,政治学博士,1981年出生于南京,1999年从金陵中学毕业后考入北京大学国际关系学院,负笈燕园十二载,获政治学博士学位。 3 | 现任教于南京大学政府管理学院。 4 | 在关系古都北京、南京等历史文化名城存废的历史关头,他锲而不舍地为抢救中华文明奔走呐喊。 5 | 2010年,他被中国文物保护基金会评为“中国文化遗产保护年度十大杰出人物”,当时的获奖评语是:一支?土耳其诗人纳齐姆·希克梅特曾深情地说:“人的一生有两样东西不会忘记,那就是母亲的面孔和城市的面貌。 6 | ”然而,前不久南京再次发生颜料坊地块市级文保单位两进建筑被毁的事件。 7 | 故宫博物院院长、原国家文物局局长单霁翔近日在宁直言,南京城南再遭损毁令他心痛。 8 | 南京老城“路在何方”? 9 | 2010年被中国文物保护基金会评为“中国文化遗产保护年度十大杰出人物”的80后学者、南京大学姚远老师所著的《城市的自觉》近日正式出版。 10 | 书中探索古城保护与复兴的建设性路径,值得南京的决策者们在颜料坊事件后再次深思。 11 | 江南时报记者黄勇疑问:城市化,是否迷失了文化自觉“目睹一座座古建筑的消失,行走在古城的废墟,想到梁思成说过的‘拆掉北京的一座城楼,就像割掉我的一块肉;扒掉北京的一段城墙,就像扒掉我的一层皮’,真是感同身受,我流泪了。 12 | ”这是姚远最让记者为之动容的一句话,也是《城市的自觉》一书中的“魂”。 13 | 包括南京在内,中国大多数城市正处于大拆除的时代,成片的历史街区在“旧城改造”的大旗下被不断夷为平地。 14 | 有专家称,这场“休克疗法式”的“改造”,对中华文脉的影响之深、之巨、之不可逆,堪称中国城市史上“三千年未有之大变局”。 15 | 《城市的自觉》正是在这种背景下,由北京大学出版社于近日出版的。 16 | 书中,姚远以情理交融的文字,辅之以背景、南京古城珍贵的最后影像,如实记录了在北京梁思成故居和宣南、东四八条、钟鼓楼等历史街区,南京颜料坊、南捕厅、门东、门西等历史街区的最后时刻,为阻挡推土机而屡败屡战的历程。 17 | 同时,又理性剖析了与存续城市记忆密切相关的文化自觉、物权保护、民生改善、公众参与等议题,探索古城保护与复兴的建设性路径。 18 | 为何要保老城? 19 | 很多人认为陈旧的老街区、老房子应该为摩天大楼让位,造高速路、摩天楼是现代化,“保护老古董”是抱残守缺,姚远却不是这种看法:“一些决策者并不知城市遗产保护恰恰是‘后工业’、‘后现代’的思想,比前者的理念差不多领先了一个世纪。 20 | ” 在他眼里,南京这座千年古城曾是“活”着的,老城里有最纯正的方言、最鲜活的民俗、最地道的小吃,简直是一座巨大的民俗博物馆。 21 | “你可以在同老者的交谈中,听到一个个家族或老宅的兴衰故事。 22 | 这里的城与人,就是一本厚重的大书,它们用最生动的语言向你讲述不一样的‘城南旧事’。 23 | ”面对许多古城不断遭到大拆大建、拆真建假、拆旧建新的厄运,姚远痛心地说,“我们的城市化,是否迷失了自我认同,是否失去了文化自觉的能力? 24 | 在城市化的文化自觉重建之前,我们还将继续付出多少代价? 25 | ”现状:老城南仅剩不到1平方公里南京城曾有十九个别称,如秦淮、白下、建邺、江宁等,建城史更是长达两千五百年。 26 | 但如今,除去明城墙以及一些重点文物以及七零八落的民国建筑之外,这个城市跟中国其他的城市看上去并无太多区别,鳞次栉比的高楼大厦,车水马龙的宽阔街道,川流不息的红男绿女……持续多年的旧城改造,已经让南京老城日益失去古朴的历史风貌。 27 | 秦淮河畔的老城南,是南京文化的发源地,是南京的根。 28 | 在2006年前,尽管南京诸多的“殿、庙、塔、桥”已在兵火和变乱中消失,但秦淮河畔的老城南依然保存了文物丰富、风貌完整的历史街区。 29 | 然而,2006年,南京风云突起,突击对颜料坊、安品街等历史街区实施“危旧房改造”,拆毁大量文物建筑。 30 | 2009年又是一轮“危改”,大大的“拆”字,再次涂上了门东、门西、南捕厅等多片老街区。 31 | 2010年至今,南京先后出台了《南京市历史文化名城保护条例》《南京历史文化名城保护规划》《南京老城南历史城区保护规划与城市设计》,以法规的高度,回应了社会各界的诉求,明确要求对老城的整体保护。 32 | 姚远和其他学者联名提出的建议,有40处被采纳进了最后的《条例》中。 33 | 姚远告诉江南时报记者,南京的传统旧城区——老城南仅剩不到1平方公里,尚不及50平方公里老城总面积的2%,整体保护势在必行。 34 | 但他并不认为整体保护意味着“冻结不动”,而是强调古民居、古街巷和宏伟的古建筑一样重要,它们是古都特有的城市肌理,低矮的民居衬托高大的城阙,形成轮廓丰富的城市格局。 35 | 如果消灭了它们,名胜古迹就变成无法交融联络的“孤岛”,古都的整体风貌则无从谈起。 36 | “对于金陵古城濒危的最后这点种子,实行‘整体保护’已经没有任何讨价还价的余地。 37 | ”《城市的自觉》一书中,姚远的声音振聋发聩。 38 | 方案:探索保护与整治的最大合力可惜的是,在专家学者与推土机的拉锯战中,前者基本还是处于下风的,即便是中央领导的几次批示,旧城改造的推土机依然我行我素,将一面面古墙碾在轮下。 39 | 颜料坊、牛市、门东等被“肢解”的老城南片区,如今多已竖起或正在建设房地产开发、商业项目。 40 | 2002年8月,姚远在南京颜料坊开始了古城保护的第一次拍摄。 41 | 如今牛市64号-颜料坊49号这座百年清代建筑却再遭破坏。 42 | 单霁翔近日在南大演讲中也表示,颜料坊再遭损毁令人心痛。 43 | “我不认同南京老城南成片拆除,搬迁当地住户的改造方式。 44 | 简单地认为它的居住形式落后了,这种态度是消极的,没有给予作为代表地域特色的传统建筑的居住形式有尊严的呵护。 45 | ”《城市的自觉》一书中也多次提及南京老城不能“只见物,不见人”。 46 | 姚远强调,南京历史文化名城的保护,离不开对传统社区的活态保护。 47 | 老城南有丰富的民俗和古老的街区,是唇齿相依的一个整体。 48 | 拆去了老宅,迁走了居民,文化自然就成了无源之水、无本之木。 49 | “国际上的成功经验表明,保护从来不是发展、民生、现代化的反义词。 50 | ”姚远建议,老城区的整治,可以在政府的指导和协助下,以居民为主体,通过社区互助的“自我修缮”的方式来实施,将“旧城区改建”从拆迁模式下的行政关系转变为修缮模式下的民事关系,最大限度地调动各方面的积极性,形成保护与整治的最大合力。 51 | 措施:用行动让法律“站起来”经历了两次保卫战,姚远对于文物保护方面的法律条文早已如数家珍。 52 | 在他看来,“法治”和“参与”这两个关键词尤为重要。 53 | 姚远认为,政府的很多失误是因为政策制定的封闭性,推土机开到门口时才告知公众。 54 | 公民参与,就要求行政更加透明、公开。 55 | “几次保护后制定的政策或者法律法规,也很重要。 56 | 因为未来只要有人参与去触动,政策或者法律法规就能‘站起来’,变成一套强有力的程序,约束政府行为。 57 | ”“这些年古城保护的每一点进步,都离不开广泛的公众参与,都凝结着社会各界共同的努力。 58 | ”姚远认为,在北京、南京等许多古城,一批志愿者、社会人士和民间团体,在古城命运的危急关头,已经显示出日益崛起的公众参与的巨大力量。 59 | “关键要有人能够站出来。 60 | 第一个人站出来,就会有第二个人跟上,专家和媒体也会介入,事情就能在公开博弈中得到较为合理的解决。 61 | 我国目前民间的文保力量正在逐渐成长,公民参与将成为构建良性社会机制的重要力量。 62 | ”姚远强调。 63 | 单霁翔对文化遗产保护中的公众参与也做出了高度评价。 64 | 他在《城市的自觉》的序中写道:“保护文化遗产绝不仅仅是各级政府和文物工作者的专利,只有广大民众真心地、持久地参与文化遗产保护,文化遗产才能得到最可靠的保障。 65 | 以姚远博士为代表的一批志愿者和社会人士,在我国文化遗产保护事业中已经显示出不可低估、无可替代的力量。 66 | 67 | 不是每一块石头,都能叫珠宝 68 | 对于很多人来说,矿石是长成这样的石头: 上图:铁矿石 上图:石 上图:煤矿石 上图:锡矿石如你所想象的那样,很多矿石都是又黑又丑,即使在野外遇到,也不会多看一眼的那种石头。 69 | 当然,也不是所有矿石都这么丑。 70 | 我们再看看下面这些矿石: 上图:赤铜 上图:钼铅矿 上图:方硼石 上图:自然硫 上图:云母这些矿石,能否让你感慨大自然的造化神奇?小伙伴们可能会想,这些漂亮的矿石,打磨以后就是漂亮的宝石啊,为什么我们不把他们加工成珠宝呢?这个是个好问题。 71 | 人类自古以来就没有停止过对美好事物的追求,凡漂亮的东西都可能被人们看上,成为制作饰品原料。 72 | 珠宝就是大自然赐予的美好的东西中的一种。 73 | 珠宝如果不美就不能成为珠宝,这种美或表现为绚丽的颜色,或表现为透明而洁净。 74 | 物以稀为贵,鸽血红级别的红宝石、矢车菊蓝级别的蓝宝石,每克拉价值上万美元,而某些颇美丽又可耐久的宝石(如白水晶),由于产量较多,开采较容易,其价格一直较低。 75 | so,大家能明白了吧,不是每一块石头都能成为珠宝。 76 | 如果拥有珠宝,请务必珍惜。 77 | 目前1000+人已关注加入我们您看此文用· 秒,转发只需1秒呦~ 78 | 79 | 北京市黄埔同学会接待“踏寻中山足迹学习之旅”台湾参访团 80 | 光明网讯(通讯员苏民军记者任生心)日前,由台湾中国统一联盟桃竹分会成员组成的“踏寻中山足迹学习之旅”参访团一行21人来到北京参观访问。 81 | 在北京市黄埔同学会的精心安排下,在京期间,参访团拜谒了中山先生衣冠冢,参观了卢沟桥、抗战纪念馆、抗战名将纪念馆和宋庆龄故居等;“踏寻中山足迹学习之旅”参访团还将赴南京中山堂等地参访。 82 | 在抗战纪念馆,参访团成员们认真聆听讲解员的介绍,仔细观看每张图片资料,回顾国共两党团结抗战的往事,缅怀那些为民族独立而壮烈牺牲的英雄。 83 | 而后,参访团一行来到位于京西香山深处的孙中山先生衣冠冢拜谒,参访团团长李尚贤(台湾中国统一联盟总会第一副主席兼秘书长)发表了简短的感言后,全体成员在孙中山雕像前三鞠躬,向孙中山先生致敬,缅怀孙中山先生以“三民主义”为宗旨的革命的一生。 84 | 随后,参访团一行又来到2009年建成的北京香麓园抗战名将纪念馆,瞻仰了佟麟阁将军墓,他们还参观了宋庆龄故居。 85 | 86 | 鼎丰(08056.HK)向客户借出5000万人币 月息1.75厘 为期一年 87 | 鼎丰集团控股(08056.HK)+0.030(+1.345%)公布,同意将一笔5000万元人民币的款项委托予贷款银行,以供转借予客户,贷款期为十二个月,月息1.75厘。 88 | (报价延迟最少十五分钟。 89 | 90 | 在青岛不买房,居然能拥有这么多东西! 91 | 这段时间青岛房价扶摇直上闹得人心惶惶这不,青岛房市,又在国庆节火了一把 国庆5天内16城启动楼市限购一时之间楼市风云大转纵观9月份青岛一手房均价怎么也有一万三四了看完十三哥默默地回去工作了 按照一套房子100平米计算购买一套房子大概需要130万在青岛,买一套房子怎么也得需要130万如果这些钱不买房能在全世界各地买什么呢? 92 | 今天,小编就带大家(bai)感(ri)受(meng)一下在西班牙能买3.4个村庄 一位英国人,名叫尼尔·克里斯蒂,在西班牙农村西北部一个田园地区买下了一处村庄(阿鲁纳达),只花费了4.5万欧元(约合35.6万人民币)。 93 | 简直便宜到吐血,这点钱要是在青岛的豪宅区,恐怕厕所都买不了。 94 | 如果选的地方靠近旅游景区,稍微装修一下,变成一个度假村……妥妥的壕啊,画面太美,不敢想象……在爱尔兰差不多能买个小岛 Inishdooney岛,位于北爱尔兰西北部,售价14万英镑(约合139万人民币)。 95 | 约38万平方米的无人居住地有淡水池塘、天然溶洞和鹅卵石海滩,美翻了有木有! 96 | 一个小岛的钱,和青岛一个水泥格子的价格差不多。 97 | 不要拦着最懂妹,我要去爱尔兰做岛主! 98 | 在巴厘岛能买2座别墅 巴厘岛,蓝天、碧水、白云,美的像梦一样,而你知道吗,这座世界著名旅游岛一个小镇的别墅只要10.7万美元,也就是不到70万人民币,青岛买房那点钱都够买两栋别墅了。 99 | 在巴厘岛拥有两座别墅是什么概念? 100 | 发完文章小编就去买机票! 101 | 在美国能买1驾小飞机 美国塞斯纳C172R型,最大航程可达1270公里,飞机上具备GPS导航定位系统、自动驾驶、盲降设备等,价格大概在17万美元左右,也就是104万人民币。 102 | 在青岛买房的钱妥妥的够买一架飞机了。 103 | 直接移民去西班牙 一个以阳光和沙滩吸引着无数游客的国家,有着激情的足球和斗牛文化、独特的海鲜美食、发达的时装行业、热情火辣的西班牙女郎...... 直接去西班牙? 104 | 你以为我在搞笑? 105 | 西班牙有个买房移民的政策,在西班牙的指定区域购买当地售价在170万人民币以上的房产就可以办理多次往返签证了,然后你待够10年,就可以入西班牙国籍了。 106 | 买一大堆LV手袋 十三哥相信很多女孩应该都很喜欢LV手袋。 107 | 这款极具魅力的CHAIN LOUISE手袋价格为2.04万人民币。 108 | 随随便便买一堆! 109 | 带着爱人环游世界 微博上那对香港80后小夫妻历时308天花费16万人民币走遍了37国,你们还记得吗? 110 | 按照他们的行程,你几乎就能去环游世界了。 111 | 什么也不用想,痛痛快快环游地球一圈! 112 | 在澳大利亚当农场主 五卧室、三浴室的大房子,还有德尼利昆镇附近一块27英亩的农场。 113 | 只需要美元价格14.4万美元(≈96万人民币),是不是惊呆了! 114 | 哦,对了,澳大利亚还提供住房贷款业务哟! 115 | 十三哥要挣钱去澳大利亚买牧场! 116 | 在莫斯科买下1座别墅 莫斯科市中心双卧室、双浴室的豪华大别墅,你觉得多少钱? 117 | 千万别吃惊,美元价格在15.2万美元左右(≈100.1万人民币)。 118 | 虽然在这个城市生活总会有各种各样的压力我们必须十分努力才能看起来毫不费力但是我们永远保持一颗向上的心不气馁,好好加油! 119 | [海尔地产世纪公馆]新都心2期升级新品9月底推出 海尔地产世纪公馆二期规划8栋高层住宅,预计9月底推出,认筹中,交2.5万享99折优惠,预计均价17000-18000元/平。 120 | 户型面积区间89-162平,主力120-140平品质改善产品。 121 | 125-126平为套三,142-162平为套四。 122 | 海尔地产世纪公馆一户一价,以上价格仅供参考,所有在售户型价格以售楼处公布为准。 123 | 咨询电话:400-099-0099 转 27724[金隅和府]3大商圈环绕地铁房18000元 金隅和府一户一价,以下价格仅供参考,所有在售户型价格以售楼处公布为准。 124 | 金隅和府预计9月20日加推6#楼(24F)楼王,3个单元,1梯2户,户型面积为90平套二,122平、138平套三,团购交1万团购金、10万认筹金可以享受97折优惠,预计均价18000-26000元/平。 125 | 金隅和府位于镇江路12号,近邻山东路、延吉路、东西快速路等三横三纵交通网、未来享地铁M5之便利;CBD商圈、香港路商圈、台东商圈3大商圈环绕,居住生活便利。 126 | 127 | 直播拐点来临:未来直播APP开发还有哪些趋势? 128 | 趋势一:巨头收割直播价值,依赖巨头扶持的直播平台存活几率更高尽管一线垂直领域已经被巨头的直播平台占领,但创业者依然还有机会。 129 | 未来在泛娱乐社交、游戏、美妆电商等核心领域必然会有几家直播平台具有突出优势,而这些具备突出优势的直播平台很可能会被BAT入股收购或者收编,因此如果能够获得巨头的资本输血与流量扶持,往往存活的几率会更大。 130 | 趋势二:直播平台从争抢网红到争抢明星资源明星+粉丝经济+直播平台,很可能会衍生出新型的整合营销方式。 131 | 即怎样通过可购买价值的内容设定,运营好与粉丝之间的感情沟通,让粉丝群体进行持续性参与并进行情感消费投入,直播平台与明星组合叠加的人气效应与非理性消费的频次也非常契合品牌商的需求。 132 | 因此,直播的未来趋势将从争抢网红资源到争抢明星资源。 133 | 这是直播平台孕育粉丝经济进而带来新型的情感消费与商业模式的要走的一条必要的路径。 134 | 而未来可能会有越来越多的品牌商更愿意尝试这种直播互动带来的品牌曝光机会与商业变现模式。 135 | 趋势三:从泛娱乐明星网红直播转入到二级垂直细分市场的专业直播泛娱乐直播内容属性上由于其单一、无聊的直播内容无法构成平台的核心竞争力,直播平台未来大趋势是从泛娱乐直播转入到内涵直播。 136 | 目前部分视频直播平台已针对财经、育儿、时尚、体育、美食等垂直领域的自频道开放直播权限,内容的差异化与垂直化可以为直播平台带来新的商业模式,平台也可以通过优质的直播内容,产生付费、会员、打赏以及直播购物等盈利模式。 137 | 因为目前缺乏真正有价值的直播,多数直播平台在内容供给侧是存在问题的,网红要提升自身与粉丝之间的黏性,显然需要差异化的内容,而从目前的欧美网红与直播内容的发展规律来看,更健康、更有价值与内涵的直播内容成为未来的发展趋势之一。 138 | 趋势四:网红孵化器批量生产网红 将走向专业化由于在网红包装、传播、变现等方面具备专业的运营能力,网红孵化器未来须具备 “经纪人+代运营+供应链+网红星探”等多重角色,向专业网红群聚捆绑者向提供专业化的服务与垂直领域专家型、特长型、个性型网红培养者与发现者这一定位转型。 139 | 借助在用户洞察、网红运营、电商管理方面的精良团队,需要打通粉丝营销和电商运营,并将网红、粉丝,平台、内容,品牌、供应链,进行有效链接及整合。 140 | 趋势五:C端直播洗牌 B端企业直播崛起带动专业的商务直播需求目前,各种企业的商务发布会、沙龙、座谈、讲座、渠道大会、教育培训等方面直播需求强烈,在企业进行移动视频直播的需求推动下,它们开始寻求低成本、快速的搭建属于自己的高清视频直播平台的模式,而企业搭建视频直播平台需要专业的技术能力的服务商来应对这种需求。 141 | 用户可以通过微信直接观看企业直播参与互动,让直播突破空间场地的限制,某种程度也代表直播产业链的一个接入的发展方向。 142 | 趋势六:解决直播用户体验与新媒体营销,移动直播服务商将迎来新的机会直播行业进入了各行各业均可参与,并将直播作为企业服务工具的直播+时代,而玩转直播+,从技术、营销、服务、内容,进而可以衍生出更多的直播服务盈利。 143 | 而对于解决直播体验背后的移动直播服务商,也将迎来新的机会。 144 | 趋势七:直播或成为企业的标配,可能为企业带来更多转化率当直播火爆起来的时候,人们要关注的不仅仅是行业能火爆多久,它的商业模式是否成熟,在洗牌节点来临与巨头羽翼覆盖下,自身还有没有机会,创业者与企业都应该从中寻找自己的机会与跨界领域的嫁接。 145 | 它不仅仅是内容和流量的变现工具,更应该是一种营销与商业理念的转变。 146 | 不久前,马化腾向青年创业者建议,要关注两个产业跨界的部分,因为将新技术用在两个产业跨界部分往往最有可能诞生创新的机会。 147 | 而企业营销如果能从垂直细分领域的切入并借助直播技术与趋势为已所用,往往也能获得新的机会,尽管任何基于行业趋势的预测都意味着不确定性,但抓住不确定性的机会,才能最终在新一轮风口下,把握企业转型与商业、营销模式创新的机会,迎来属于自己的时代。 148 | 欢迎互联网创业者加入杭州互联网创业QQ群:157936473直接加QQ或pc上点击加群项目开发咨询:0571-28030088 149 | 150 | 邓伟根北美硅谷行“捎回”一个MBA授课点 151 | 南都讯记者郭伟豪通讯员伍新宇6月7日至16日,佛山市委常委、南海区委书记、佛山高新区党工委书记兼管委会主任邓伟根率领由南海区和佛山高新区相关人员组成的经贸洽谈和友好交流代表团,对新加坡、美国和加拿大进行友好访问。 152 | 由于新加坡裕廊、美国硅谷与有“加拿大高科技之都”美誉的万锦市均以发达的高科技产业著名,皆是所在国的硅谷,邓伟根更称此行为“三谷”之行。 153 | 在新加坡,邓伟根一行与新加坡淡马锡控股公司相关负责人就双方进一步深化合作进行了深入的探讨。 154 | 交流中,新加坡国立大学(N U S)商学院杨贤院长表示有意在南海设立N U S的海外M B A授课点,双方拟于6月下旬就有关意向在南海签订合作协议。 155 | 6月9日,邓伟根一行前往硅谷拜会了硅谷美华科技商会(S V C A C A )和华美半导体协会(C A SPA )。 156 | SV C A C A和CA SPA将通过其广泛的会员和在硅谷等地的影响力,为佛高区、南高区在硅谷进行宣传推介,并积极把有意拓展中国市场的高科技项目推荐到南高区。 157 | 代表团一行还到访了南海区政府与万锦市政府联合举办了“南海区与万锦市经贸交流会”。 158 | 2012年12月,万锦市市长薛家平先生率团访问南海后,万锦市议会正式通过了为当地一道路命名“南海街”的议案,并于2013年9月举行道路命名仪式。 159 | 在本次交流中,邓伟根提议未来也在南海选址命名一条“万锦路”,此举也立即得到薛家平市长的认同。 160 | 对于“三谷”之行,邓伟根表示,南海将利用现有的南海乡亲和关系密切的协会等有利资源,计划在“三谷”建立南海和佛高区的海外联络处,学习和吸收海外高科技之都的先进经验,努力将已定位为“中国制造金谷”的佛高区南海核心园打造成为下一个“硅谷”,并争取早日实现佛高区挺进全国国家高新区20强的目标。 161 | 162 | 内地高中生将通篇学习《道德经》 163 | 摘要国内第一套自主研发的高中传统文化通识教材预计将于今年9月出版,四册分别为《论语》《孟子》《大学·中庸》和《道德经》。 164 | 2016年高考改革方案中,全国25个省高考要统一命题,并且增加分数后的语文考试,正在研究增加“中华优秀传统文化”之相关内容。 165 | 《道德经》成为高中传统文化教材。 166 | 法制晚报讯(记者 李文姬 )今天上午,记者从“十二五”教育部规划课题《传统文化与中小学生人格培养研究》总课题组了解到,国内第一套自主研发的高中传统文化通识教材预计将于今年9月出版,四册分别为《论语》《孟子》《大学·中庸》和《道德经》。 167 | 至此,课题组已完成了幼儿园、小学、初中、高中各阶段标准化传统文化教材的研发工作,高中国学教材将在各地开展成规模的教材试用工作。 168 | 中国国学文化艺术中心秘书长张健表示,目前各地高考改革的几个信号均指向国学,但考什么、怎么考又是一个难题。 169 | 专家建议,不应以文言文字词解释等传统形式考查,应关注考生如何消化吸收传统文化中的哲学素养和思想韬略。 170 | 教材各年级国学内容全覆盖据 “十二五”教育部规划课题《传统文化与中小学生人格培养研究》总课题组介绍,高中传统文化通识系列教材作为“十一五”、“十二五”两个阶段十年课题研究的重要成果之一,由中国国学文化艺术中心承担资源整合和编著。 171 | 去年,教育部印发了《完善中华优秀传统文化教育指导纲要》,要求在课程建设和课程标准修订中强化中华优秀传统文化内容。 172 | 在中小学德育、语文、历史等课程标准修订中,增加中华优秀传统文化的比重。 173 | 课题组秘书长张健表示,幼儿园、小学、初中、高中各阶段标准化传统文化教材的均已研发完成,明确提出以“青少年完美人格”为传统文化教育目标,教材知识相互关联,自成体系,并通过高中教材实现最终教学评价。 174 | 这是“十一五”“十二五”两个阶段十年课题研究的重要成果之一。 175 | 今年5月份之前,《高等教育传统文化教材》(12册)《全国行政领导干部国学教材》(10册)两套教材也将研发完毕。 176 | 内容高中教材含《论语》《道德经》此次即将出版的高中阶段传统文化通识教材共有4册,供高中一、二年级使用。 177 | 高一学习《论语》《孟子》,高二学习《大学·中庸》和《道德经》。 178 | 其中《道德经》为原文全本讲解,另外三册则是按主题归类讲解。 179 | 如《大学·中庸》一册,分为“慎独”“齐家”“格物致知”“中和”“为政”等章节。 180 | 据课题组专家介绍,这4册书并非孤立的高中教材,而是《中华优秀传统文化教育全国中小学实验教材》的高中部分。 181 | 全套教材包含小学、初中和高中三个阶段,经专家组反复研讨、论证,制定了“儒学养正、兵学相佑、道法自然、文化浸润”的课程结构,各阶段教学内容和深度循序渐进、系统科学。 182 | 事实上,小学高年级段已开始涉及《论语》《孟子》等儒学典籍,但仅以诵读和简单理解为主,到高中阶段,学生可在已有基础上更为深刻地领悟儒道经典的思想内涵,以达到融会贯通的程度。 183 | 此外,每一章节在讲解儒道核心精神的同时,还为学生提供了大量中西文化比较等拓展阅读素材。 184 | 针对公众关注的一个话题,即传统文化有望成为高考的新考点,课题组表示目前在研发高中传统文化教材的同时,就已开展了另一个重点子课题研究,即传统文化教学评价与考试模式研究。 185 | 张健强调高考改革的几个信号均指向国学,例如北京、上海等地公布的高考改革方案中,英语降分后其所降分数分给了语文,而且还更进一步明确指出了就是将分数转移给所增加的“传统文化考试内容”部分。 186 | 又如今年清华北大自主招生均招收国学特长生。 187 | 此外,近期公布的2016年高考改革方案中,全国25个省高考要统一命题,并且增加分数后的语文考试,正在研究增加“中华优秀传统文化”之相关内容。 188 | 张健表示,传统文化成为高考的又一创新考点指日可待,但考什么、怎么考又是一个重大难题。 189 | 由于相关子课题研究还没有结束,课题组非行政机构只承担建议义务。 190 | 张健坦言,能否在高考语文中出现一个新的形式——政论或申论形式的传统文化论述题,这一方向应该是研究和创新的改革方向之一。 191 | 若2016年传统文化进入高考,最大的问题是很多高中生没有接触过传统文化课程,不具备相关知识储备和素养,国学文化是通过长期熏陶和涵养才能显现的,不是靠一朝一夕突击补课就能拥有的。 192 | 193 | 悬灸技术培训专家教你艾灸降血糖,为爸妈收好了! 194 | 近年来随着我国经济条件的改善和人们生活水平的提高,我国糖尿病的患病率也在逐年上升。 195 | 悬灸技术培训的创始人艾灸专家刘全军先生对糖尿病深有研究,接下来,学一学他是怎么用艾灸降血压的吧! 196 | 中医认为,糖尿病是气血、阴阳失调等多种原因引起的一种慢性疾病。 197 | 虽然分为上消、中消、下消,但是无论何种糖尿病 ,治疗的原则都是荣养阴液,清热润燥。 198 | 艾灸对控制血糖效果不错。 199 | 艾灸功效:调升元阳降血糖艾灸可以修复受损胰岛细胞,激活再生,逐步实现胰岛素的自给自足。 200 | 服药一天比一天少,身体一天比一天好,彻底摆脱终生服药! 201 | 还可以双向调节血糖,使血糖老老实实地锁定在正常的恒定值范围。 202 | 也可以改善组织供氧,对微血管病变导致的视物不清、眼底出血等视网膜病变及早期肾病病变及早期肾病病变有明显治疗与改善作用,改善病人消瘦无力、免疫力低下、低蛋白质血证及伤口不愈等现象。 203 | 艾灸取穴糖尿病艾灸过的穴位有,承浆中脘足三里关元曲骨三阴交、期门太冲下脘天枢气海膈俞膻中、胃俞,这么多穴位可根据患者当时的症状进行选取。 204 | 选取后艾灸,每10天为一个疗程,疗程间休息3-5天后继续第二轮的治疗,三个疗程基本可见到理想疗效。 205 | 这几个穴位都是具有补充人体元阳功能的大穴和调节脏腑功能的腧穴,从根上调节人体的元阳使阴阳达到新的平衡,五脏六腑尤其是肺、脾肾的功能恢复正常,糖尿病自然也就不药而愈了。 206 | 艾灸可以有效控制糖尿病 ,这在很多资料都有报导。 207 | 艾灸使病人的营养能得到有效的吸收和利用,从而提高人体的自身免疫功能和抗病防病能力,防止了系列并发症的发生,真正做到综合治疗,标本兼治。 208 | 艾灸对于常见病是具有广泛的适应性的。 209 | 希望大家把艾灸推广出去,让艾灸这个疗法能够更完善,造福更多的人。 210 | 211 | 熟食放在垃圾旁无照窝点被取缔 212 | 本报讯(记者李涛)又黑又脏的墙面、随意堆放的加工原料、处处弥漫的刺鼻味道。 213 | 昨天上午,东小口镇政府与城管、食药、公安等部门开展联合执法行动时,依法取缔了一个位于昌平区东小口镇半截塔村的非法熟食加工窝点。 214 | 昨天上午,执法人员对东小口镇半截塔村进行环境整治时,一家挂着“久久鸭”招牌的小店的店主显得有点紧张,还“顺手”把通向后院的门关上了。 215 | 执法人员觉得有些蹊跷,便要求到后院进行检查。 216 | 一进院子,执法人员就发现大量的熟食加工原料被随意摆放在地上,旁边就堆放着垃圾。 217 | 院内煤炉上的一口锅内正煮着的食物,发出刺鼻的味道。 218 | 执法队员介绍,在炉子一旁的笸箩里盛着制作好的熟食制品,但却没有任何遮盖,一阵风起,煤灰混着尘土就落在上面。 219 | 执法队员说:“走进院旁的小屋内,地上和墙上满是油污,脏乎乎的冰柜上堆放着一袋一袋的半成品,一个个用来盛放熟食制品的笸箩摞在生锈的铁架子上。 220 | ”随后,执法人员仔细查找,没有发现任何消毒设施,调查得知从事加工的人员也没有取得加工熟食应需的健康证。 221 | 执法人员随后对店主进行询问,当执法人员要求出示营业执照及卫生许可证时,店主嘟囔了半天才坦白自己不具备任何手续。 222 | 执法人员当即对该非法生产窝点进行了取缔,对现场工作人员进行了宣传与教育,并依法没收了加工工具及食品。 -------------------------------------------------------------------------------- /lamb_optimizer_google.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Functions and classes related to optimization (weight updates).""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import re 24 | import six 25 | import tensorflow as tf 26 | 27 | # pylint: disable=g-direct-tensorflow-import 28 | from tensorflow.python.ops import array_ops 29 | from tensorflow.python.ops import linalg_ops 30 | from tensorflow.python.ops import math_ops 31 | # pylint: enable=g-direct-tensorflow-import 32 | 33 | 34 | class LAMBOptimizer(tf.train.Optimizer): 35 | """LAMB (Layer-wise Adaptive Moments optimizer for Batch training).""" 36 | # A new optimizer that includes correct L2 weight decay, adaptive 37 | # element-wise updating, and layer-wise justification. The LAMB optimizer 38 | # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song, 39 | # James Demmel, and Cho-Jui Hsieh in a paper titled as Reducing BERT 40 | # Pre-Training Time from 3 Days to 76 Minutes (arxiv.org/abs/1904.00962) 41 | 42 | def __init__(self, 43 | learning_rate, 44 | weight_decay_rate=0.0, 45 | beta_1=0.9, 46 | beta_2=0.999, 47 | epsilon=1e-6, 48 | exclude_from_weight_decay=None, 49 | exclude_from_layer_adaptation=None, 50 | name="LAMBOptimizer"): 51 | """Constructs a LAMBOptimizer.""" 52 | super(LAMBOptimizer, self).__init__(False, name) 53 | 54 | self.learning_rate = learning_rate 55 | self.weight_decay_rate = weight_decay_rate 56 | self.beta_1 = beta_1 57 | self.beta_2 = beta_2 58 | self.epsilon = epsilon 59 | self.exclude_from_weight_decay = exclude_from_weight_decay 60 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the 61 | # arg is None. 62 | # TODO(jingli): validate if exclude_from_layer_adaptation is necessary. 63 | if exclude_from_layer_adaptation: 64 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation 65 | else: 66 | self.exclude_from_layer_adaptation = exclude_from_weight_decay 67 | 68 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 69 | """See base class.""" 70 | assignments = [] 71 | for (grad, param) in grads_and_vars: 72 | if grad is None or param is None: 73 | continue 74 | 75 | param_name = self._get_variable_name(param.name) 76 | 77 | m = tf.get_variable( 78 | name=six.ensure_str(param_name) + "/adam_m", 79 | shape=param.shape.as_list(), 80 | dtype=tf.float32, 81 | trainable=False, 82 | initializer=tf.zeros_initializer()) 83 | v = tf.get_variable( 84 | name=six.ensure_str(param_name) + "/adam_v", 85 | shape=param.shape.as_list(), 86 | dtype=tf.float32, 87 | trainable=False, 88 | initializer=tf.zeros_initializer()) 89 | 90 | # Standard Adam update. 91 | next_m = ( 92 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 93 | next_v = ( 94 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 95 | tf.square(grad))) 96 | 97 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 98 | 99 | # Just adding the square of the weights to the loss function is *not* 100 | # the correct way of using L2 regularization/weight decay with Adam, 101 | # since that will interact with the m and v parameters in strange ways. 102 | # 103 | # Instead we want ot decay the weights in a manner that doesn't interact 104 | # with the m/v parameters. This is equivalent to adding the square 105 | # of the weights to the loss with plain (non-momentum) SGD. 106 | if self._do_use_weight_decay(param_name): 107 | update += self.weight_decay_rate * param 108 | 109 | ratio = 1.0 110 | if self._do_layer_adaptation(param_name): 111 | w_norm = linalg_ops.norm(param, ord=2) 112 | g_norm = linalg_ops.norm(update, ord=2) 113 | ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where( 114 | math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) 115 | 116 | update_with_lr = ratio * self.learning_rate * update 117 | 118 | next_param = param - update_with_lr 119 | 120 | assignments.extend( 121 | [param.assign(next_param), 122 | m.assign(next_m), 123 | v.assign(next_v)]) 124 | return tf.group(*assignments, name=name) 125 | 126 | def _do_use_weight_decay(self, param_name): 127 | """Whether to use L2 weight decay for `param_name`.""" 128 | if not self.weight_decay_rate: 129 | return False 130 | if self.exclude_from_weight_decay: 131 | for r in self.exclude_from_weight_decay: 132 | if re.search(r, param_name) is not None: 133 | return False 134 | return True 135 | 136 | def _do_layer_adaptation(self, param_name): 137 | """Whether to do layer-wise learning rate adaptation for `param_name`.""" 138 | if self.exclude_from_layer_adaptation: 139 | for r in self.exclude_from_layer_adaptation: 140 | if re.search(r, param_name) is not None: 141 | return False 142 | return True 143 | 144 | def _get_variable_name(self, param_name): 145 | """Get the variable name from the tensor name.""" 146 | m = re.match("^(.*):\\d+$", six.ensure_str(param_name)) 147 | if m is not None: 148 | param_name = m.group(1) 149 | return param_name 150 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = LAMBOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | 176 | 177 | # 178 | class LAMBOptimizer(tf.train.Optimizer): 179 | """ 180 | LAMBOptimizer optimizer. 181 | https://github.com/ymcui/LAMB_Optimizer_TF 182 | # IMPORTANT NOTE 183 | - This is NOT an official implementation. 184 | - LAMB optimizer is changed from arXiv v1 ~ v3. 185 | - We implement v3 version (which is the latest version on June, 2019.). 186 | - Our implementation is based on `AdamWeightDecayOptimizer` in BERT (provided by Google). 187 | 188 | # References 189 | - Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. https://arxiv.org/abs/1904.00962v3 190 | - BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://arxiv.org/abs/1810.04805 191 | # Parameters 192 | - There is nothing special, just the same as `AdamWeightDecayOptimizer`. 193 | """ 194 | 195 | def __init__(self, 196 | learning_rate, 197 | weight_decay_rate=0.01, 198 | beta_1=0.9, 199 | beta_2=0.999, 200 | epsilon=1e-6, 201 | exclude_from_weight_decay=None, 202 | name="LAMBOptimizer"): 203 | """Constructs a LAMBOptimizer.""" 204 | super(LAMBOptimizer, self).__init__(False, name) 205 | 206 | self.learning_rate = learning_rate 207 | self.weight_decay_rate = weight_decay_rate 208 | self.beta_1 = beta_1 209 | self.beta_2 = beta_2 210 | self.epsilon = epsilon 211 | self.exclude_from_weight_decay = exclude_from_weight_decay 212 | 213 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 214 | """See base class.""" 215 | assignments = [] 216 | for (grad, param) in grads_and_vars: 217 | if grad is None or param is None: 218 | continue 219 | 220 | param_name = self._get_variable_name(param.name) 221 | 222 | m = tf.get_variable( 223 | name=param_name + "/lamb_m", 224 | shape=param.shape.as_list(), 225 | dtype=tf.float32, 226 | trainable=False, 227 | initializer=tf.zeros_initializer()) 228 | v = tf.get_variable( 229 | name=param_name + "/lamb_v", 230 | shape=param.shape.as_list(), 231 | dtype=tf.float32, 232 | trainable=False, 233 | initializer=tf.zeros_initializer()) 234 | 235 | # Standard Adam update. 236 | next_m = ( 237 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 238 | next_v = ( 239 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 240 | tf.square(grad))) 241 | 242 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 243 | 244 | # Just adding the square of the weights to the loss function is *not* 245 | # the correct way of using L2 regularization/weight decay with Adam, 246 | # since that will interact with the m and v parameters in strange ways. 247 | # 248 | # Instead we want ot decay the weights in a manner that doesn't interact 249 | # with the m/v parameters. This is equivalent to adding the square 250 | # of the weights to the loss with plain (non-momentum) SGD. 251 | if self._do_use_weight_decay(param_name): 252 | update += self.weight_decay_rate * param 253 | 254 | ############## BELOW ARE THE SPECIFIC PARTS FOR LAMB ############## 255 | 256 | # Note: Here are two choices for scaling function \phi(z) 257 | # minmax: \phi(z) = min(max(z, \gamma_l), \gamma_u) 258 | # identity: \phi(z) = z 259 | # The authors does not mention what is \gamma_l and \gamma_u 260 | # UPDATE: after asking authors, they provide me the code below. 261 | # ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where( 262 | # math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) 263 | 264 | r1 = tf.sqrt(tf.reduce_sum(tf.square(param))) 265 | r2 = tf.sqrt(tf.reduce_sum(tf.square(update))) 266 | 267 | r = tf.where(tf.greater(r1, 0.0), 268 | tf.where(tf.greater(r2, 0.0), 269 | r1 / r2, 270 | 1.0), 271 | 1.0) 272 | 273 | eta = self.learning_rate * r 274 | 275 | update_with_lr = eta * update 276 | 277 | next_param = param - update_with_lr 278 | 279 | assignments.extend( 280 | [param.assign(next_param), 281 | m.assign(next_m), 282 | v.assign(next_v)]) 283 | return tf.group(*assignments, name=name) 284 | 285 | def _do_use_weight_decay(self, param_name): 286 | """Whether to use L2 weight decay for `param_name`.""" 287 | if not self.weight_decay_rate: 288 | return False 289 | if self.exclude_from_weight_decay: 290 | for r in self.exclude_from_weight_decay: 291 | if re.search(r, param_name) is not None: 292 | return False 293 | return True 294 | 295 | def _get_variable_name(self, param_name): 296 | """Get the variable name from the tensor name.""" 297 | m = re.match("^(.*):\\d+$", param_name) 298 | if m is not None: 299 | param_name = m.group(1) 300 | return param_name -------------------------------------------------------------------------------- /optimization_finetuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, # 0.98 ONLY USED FOR PRETRAIN. MUST CHANGE AT FINE-TUNING 0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /optimization_google.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Functions and classes related to optimization (weight updates).""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import re 24 | 25 | import six 26 | from six.moves import zip 27 | import tensorflow as tf 28 | 29 | import lamb_optimizer_google as lamb_optimizer 30 | 31 | 32 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, 33 | optimizer="adamw", poly_power=1.0, start_warmup_step=0): 34 | """Creates an optimizer training op.""" 35 | global_step = tf.train.get_or_create_global_step() 36 | 37 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 38 | 39 | # Implements linear decay of the learning rate. 40 | learning_rate = tf.train.polynomial_decay( 41 | learning_rate, 42 | global_step, 43 | num_train_steps, 44 | end_learning_rate=0.0, 45 | power=poly_power, 46 | cycle=False) 47 | 48 | # Implements linear warmup. I.e., if global_step - start_warmup_step < 49 | # num_warmup_steps, the learning rate will be 50 | # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`. 51 | if num_warmup_steps: 52 | tf.logging.info("++++++ warmup starts at step " + str(start_warmup_step) 53 | + ", for " + str(num_warmup_steps) + " steps ++++++") 54 | global_steps_int = tf.cast(global_step, tf.int32) 55 | start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32) 56 | global_steps_int = global_steps_int - start_warm_int 57 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 58 | 59 | global_steps_float = tf.cast(global_steps_int, tf.float32) 60 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 61 | 62 | warmup_percent_done = global_steps_float / warmup_steps_float 63 | warmup_learning_rate = init_lr * warmup_percent_done 64 | 65 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 66 | learning_rate = ( 67 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 68 | 69 | # It is OK that you use this optimizer for finetuning, since this 70 | # is how the model was trained (note that the Adam m/v variables are NOT 71 | # loaded from init_checkpoint.) 72 | # It is OK to use AdamW in the finetuning even the model is trained by LAMB. 73 | # As report in the Bert pulic github, the learning rate for SQuAD 1.1 finetune 74 | # is 3e-5, 4e-5 or 5e-5. For LAMB, the users can use 3e-4, 4e-4,or 5e-4 for a 75 | # batch size of 64 in the finetune. 76 | if optimizer == "adamw": 77 | tf.logging.info("using adamw") 78 | optimizer = AdamWeightDecayOptimizer( 79 | learning_rate=learning_rate, 80 | weight_decay_rate=0.01, 81 | beta_1=0.9, 82 | beta_2=0.999, 83 | epsilon=1e-6, 84 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 85 | elif optimizer == "lamb": 86 | tf.logging.info("using lamb") 87 | optimizer = lamb_optimizer.LAMBOptimizer( 88 | learning_rate=learning_rate, 89 | weight_decay_rate=0.01, 90 | beta_1=0.9, 91 | beta_2=0.999, 92 | epsilon=1e-6, 93 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 94 | else: 95 | raise ValueError("Not supported optimizer: ", optimizer) 96 | 97 | if use_tpu: 98 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 99 | 100 | tvars = tf.trainable_variables() 101 | grads = tf.gradients(loss, tvars) 102 | 103 | # This is how the model was pre-trained. 104 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 105 | 106 | train_op = optimizer.apply_gradients( 107 | list(zip(grads, tvars)), global_step=global_step) 108 | 109 | # Normally the global step update is done inside of `apply_gradients`. 110 | # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` do this. 111 | # But if you use a different optimizer, you should probably take this line 112 | # out. 113 | new_global_step = global_step + 1 114 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 115 | return train_op 116 | 117 | 118 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 119 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 120 | 121 | def __init__(self, 122 | learning_rate, 123 | weight_decay_rate=0.0, 124 | beta_1=0.9, 125 | beta_2=0.999, 126 | epsilon=1e-6, 127 | exclude_from_weight_decay=None, 128 | name="AdamWeightDecayOptimizer"): 129 | """Constructs a AdamWeightDecayOptimizer.""" 130 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 131 | 132 | self.learning_rate = learning_rate 133 | self.weight_decay_rate = weight_decay_rate 134 | self.beta_1 = beta_1 135 | self.beta_2 = beta_2 136 | self.epsilon = epsilon 137 | self.exclude_from_weight_decay = exclude_from_weight_decay 138 | 139 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 140 | """See base class.""" 141 | assignments = [] 142 | for (grad, param) in grads_and_vars: 143 | if grad is None or param is None: 144 | continue 145 | 146 | param_name = self._get_variable_name(param.name) 147 | 148 | m = tf.get_variable( 149 | name=six.ensure_str(param_name) + "/adam_m", 150 | shape=param.shape.as_list(), 151 | dtype=tf.float32, 152 | trainable=False, 153 | initializer=tf.zeros_initializer()) 154 | v = tf.get_variable( 155 | name=six.ensure_str(param_name) + "/adam_v", 156 | shape=param.shape.as_list(), 157 | dtype=tf.float32, 158 | trainable=False, 159 | initializer=tf.zeros_initializer()) 160 | 161 | # Standard Adam update. 162 | next_m = ( 163 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 164 | next_v = ( 165 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 166 | tf.square(grad))) 167 | 168 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 169 | 170 | # Just adding the square of the weights to the loss function is *not* 171 | # the correct way of using L2 regularization/weight decay with Adam, 172 | # since that will interact with the m and v parameters in strange ways. 173 | # 174 | # Instead we want ot decay the weights in a manner that doesn't interact 175 | # with the m/v parameters. This is equivalent to adding the square 176 | # of the weights to the loss with plain (non-momentum) SGD. 177 | if self._do_use_weight_decay(param_name): 178 | update += self.weight_decay_rate * param 179 | 180 | update_with_lr = self.learning_rate * update 181 | 182 | next_param = param - update_with_lr 183 | 184 | assignments.extend( 185 | [param.assign(next_param), 186 | m.assign(next_m), 187 | v.assign(next_v)]) 188 | return tf.group(*assignments, name=name) 189 | 190 | def _do_use_weight_decay(self, param_name): 191 | """Whether to use L2 weight decay for `param_name`.""" 192 | if not self.weight_decay_rate: 193 | return False 194 | if self.exclude_from_weight_decay: 195 | for r in self.exclude_from_weight_decay: 196 | if re.search(r, param_name) is not None: 197 | return False 198 | return True 199 | 200 | def _get_variable_name(self, param_name): 201 | """Get the variable name from the tensor name.""" 202 | m = re.match("^(.*):\\d+$", six.ensure_str(param_name)) 203 | if m is not None: 204 | param_name = m.group(1) 205 | return param_name 206 | -------------------------------------------------------------------------------- /resources/add_data_removing_dropout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/albert_zh/52149e82faf3eddd02f31dc2942a4d06bc78c247/resources/add_data_removing_dropout.jpg -------------------------------------------------------------------------------- /resources/albert_configuration.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/albert_zh/52149e82faf3eddd02f31dc2942a4d06bc78c247/resources/albert_configuration.jpg -------------------------------------------------------------------------------- /resources/albert_large_zh_parameters.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/albert_zh/52149e82faf3eddd02f31dc2942a4d06bc78c247/resources/albert_large_zh_parameters.jpg -------------------------------------------------------------------------------- /resources/albert_performance.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/albert_zh/52149e82faf3eddd02f31dc2942a4d06bc78c247/resources/albert_performance.jpg -------------------------------------------------------------------------------- /resources/albert_tiny_compare_s.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/albert_zh/52149e82faf3eddd02f31dc2942a4d06bc78c247/resources/albert_tiny_compare_s.jpg -------------------------------------------------------------------------------- /resources/albert_tiny_compare_s_old.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/albert_zh/52149e82faf3eddd02f31dc2942a4d06bc78c247/resources/albert_tiny_compare_s_old.jpg -------------------------------------------------------------------------------- /resources/crmc2018_compare_s.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/albert_zh/52149e82faf3eddd02f31dc2942a4d06bc78c247/resources/crmc2018_compare_s.jpg -------------------------------------------------------------------------------- /resources/shell_scripts/create_pretrain_data_batch_webtext.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo $1,$2 3 | 4 | BERT_BASE_DIR=./bert_config 5 | for((i=$1;i<=$2;i++)); 6 | do 7 | python3 create_pretraining_data.py --do_whole_word_mask=True --input_file=gs://raw_text/web_text_zh_raw/web_text_zh_$i.txt \ 8 | --output_file=gs://albert_zh/tf_records/tf_web_text_zh_$i.tfrecord --vocab_file=$BERT_BASE_DIR/vocab.txt --do_lower_case=True \ 9 | --max_seq_length=512 --max_predictions_per_seq=76 --masked_lm_prob=0.15 10 | done 11 | -------------------------------------------------------------------------------- /resources/state_of_the_art.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/albert_zh/52149e82faf3eddd02f31dc2942a4d06bc78c247/resources/state_of_the_art.jpg -------------------------------------------------------------------------------- /resources/xlarge_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/albert_zh/52149e82faf3eddd02f31dc2942a4d06bc78c247/resources/xlarge_loss.jpg -------------------------------------------------------------------------------- /run_classifier_clue.sh: -------------------------------------------------------------------------------- 1 | # @Author: bo.shi 2 | # @Date: 2020-03-15 16:11:00 3 | # @Last Modified by: bo.shi 4 | # @Last Modified time: 2020-04-02 17:54:05 5 | #!/usr/bin/env bash 6 | 7 | export CUDA_VISIBLE_DEVICES="0" 8 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P) 9 | CLUE_DATA_DIR=$CURRENT_DIR/CLUEdataset 10 | ALBERT_TINY_DIR=$CURRENT_DIR/albert_tiny 11 | 12 | download_data(){ 13 | TASK_NAME=$1 14 | if [ ! -d $CLUE_DATA_DIR ]; then 15 | mkdir -p $CLUE_DATA_DIR 16 | echo "makedir $CLUE_DATA_DIR" 17 | fi 18 | cd $CLUE_DATA_DIR 19 | if [ ! -d ${TASK_NAME} ]; then 20 | mkdir $TASK_NAME 21 | echo "make dataset dir $CLUE_DATA_DIR/$TASK_NAME" 22 | fi 23 | cd $TASK_NAME 24 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then 25 | rm * 26 | wget https://storage.googleapis.com/cluebenchmark/tasks/${TASK_NAME}_public.zip 27 | unzip ${TASK_NAME}_public.zip 28 | rm ${TASK_NAME}_public.zip 29 | else 30 | echo "data exists" 31 | fi 32 | echo "Finish download dataset." 33 | } 34 | 35 | download_model(){ 36 | if [ ! -d $ALBERT_TINY_DIR ]; then 37 | mkdir -p $ALBERT_TINY_DIR 38 | echo "makedir $ALBERT_TINY_DIR" 39 | fi 40 | cd $ALBERT_TINY_DIR 41 | if [ ! -f "albert_config_tiny.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "checkpoint" ] || [ ! -f "albert_model.ckpt.index" ] || [ ! -f "albert_model.ckpt.meta" ] || [ ! -f "albert_model.ckpt.data-00000-of-00001" ]; then 42 | rm * 43 | wget -c https://storage.googleapis.com/albert_zh/albert_tiny_489k.zip 44 | unzip albert_tiny_489k.zip 45 | rm albert_tiny_489k.zip 46 | else 47 | echo "model exists" 48 | fi 49 | echo "Finish download model." 50 | } 51 | 52 | run_task() { 53 | TASK_NAME=$1 54 | download_data $TASK_NAME 55 | download_model $MODEL_NAME 56 | DATA_DIR=$CLUE_DATA_DIR/${TASK_NAME} 57 | PREV_TRAINED_MODEL_DIR=$ALBERT_TINY_DIR 58 | MAX_SEQ_LENGTH=$2 59 | TRAIN_BATCH_SIZE=$3 60 | LEARNING_RATE=$4 61 | NUM_TRAIN_EPOCHS=$5 62 | SAVE_CHECKPOINTS_STEPS=$6 63 | OUTPUT_DIR=$CURRENT_DIR/${TASK_NAME}_output/ 64 | COMMON_ARGS=" 65 | --task_name=$TASK_NAME \ 66 | --data_dir=$DATA_DIR \ 67 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \ 68 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/albert_config_tiny.json \ 69 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/albert_model.ckpt \ 70 | --max_seq_length=$MAX_SEQ_LENGTH \ 71 | --train_batch_size=$TRAIN_BATCH_SIZE \ 72 | --learning_rate=$LEARNING_RATE \ 73 | --num_train_epochs=$NUM_TRAIN_EPOCHS \ 74 | --save_checkpoints_steps=$SAVE_CHECKPOINTS_STEPS \ 75 | --output_dir=$OUTPUT_DIR \ 76 | --keep_checkpoint_max=0 \ 77 | " 78 | cd $CURRENT_DIR 79 | echo "Start running..." 80 | python run_classifier_clue.py \ 81 | $COMMON_ARGS \ 82 | --do_train=true \ 83 | --do_eval=false \ 84 | --do_predict=false 85 | 86 | echo "Start predict..." 87 | python run_classifier_clue.py \ 88 | $COMMON_ARGS \ 89 | --do_train=false \ 90 | --do_eval=true \ 91 | --do_predict=true 92 | } 93 | 94 | ##command##task_name##model_name##max_seq_length##train_batch_size##learning_rate##num_train_epochs##save_checkpoints_steps##tpu_ip 95 | run_task afqmc 128 16 2e-5 3 300 96 | run_task cmnli 128 64 3e-5 2 300 97 | run_task csl 128 16 1e-5 5 100 98 | run_task iflytek 128 32 2e-5 3 300 99 | run_task tnews 128 16 2e-5 3 300 100 | run_task wsc 128 16 1e-5 10 10 -------------------------------------------------------------------------------- /run_classifier_lcqmc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # @Author: bo.shi, https://github.com/chineseGLUE/chineseGLUE 3 | # @Date: 2019-11-04 09:56:36 4 | # @Last Modified by: bright 5 | # @Last Modified time: 2019-11-10 09:00:00 6 | 7 | TASK_NAME="lcqmc" 8 | MODEL_NAME="albert_tiny_zh" 9 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P) 10 | 11 | export CUDA_VISIBLE_DEVICES="0" 12 | export ALBERT_CONFIG_DIR=$CURRENT_DIR/albert_config 13 | export ALBERT_PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_model 14 | export ALBERT_TINY_DIR=$ALBERT_PRETRAINED_MODELS_DIR/$MODEL_NAME 15 | #mkdir chineseGLUEdatasets 16 | export GLUE_DATA_DIR=$CURRENT_DIR/chineseGLUEdatasets 17 | 18 | # download and unzip dataset 19 | if [ ! -d $GLUE_DATA_DIR ]; then 20 | mkdir -p $GLUE_DATA_DIR 21 | echo "makedir $GLUE_DATA_DIR" 22 | fi 23 | cd $GLUE_DATA_DIR 24 | if [ ! -d $TASK_NAME ]; then 25 | mkdir $TASK_NAME 26 | echo "makedir $GLUE_DATA_DIR/$TASK_NAME" 27 | fi 28 | cd $TASK_NAME 29 | echo "Please try again if the data is not downloaded successfully." 30 | wget -c https://raw.githubusercontent.com/pengming617/text_matching/master/data/train.txt 31 | wget -c https://raw.githubusercontent.com/pengming617/text_matching/master/data/dev.txt 32 | wget -c https://raw.githubusercontent.com/pengming617/text_matching/master/data/test.txt 33 | echo "Finish download dataset." 34 | 35 | # download model 36 | if [ ! -d $ALBERT_TINY_DIR ]; then 37 | mkdir -p $ALBERT_TINY_DIR 38 | echo "makedir $ALBERT_TINY_DIR" 39 | fi 40 | cd $ALBERT_TINY_DIR 41 | if [ ! -f "albert_config_tiny.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "checkpoint" ] || [ ! -f "albert_model.ckpt.index" ] || [ ! -f "albert_model.ckpt.meta" ] || [ ! -f "albert_model.ckpt.data-00000-of-00001" ]; then 42 | rm * 43 | wget https://storage.googleapis.com/albert_zh/albert_tiny_489k.zip 44 | unzip albert_tiny_489k.zip 45 | rm albert_tiny_489k.zip 46 | else 47 | echo "model exists" 48 | fi 49 | echo "Finish download model." 50 | 51 | # run task 52 | cd $CURRENT_DIR 53 | echo "Start running..." 54 | python run_classifier.py \ 55 | --task_name=$TASK_NAME \ 56 | --do_train=true \ 57 | --do_eval=true \ 58 | --data_dir=$GLUE_DATA_DIR/$TASK_NAME \ 59 | --vocab_file=$ALBERT_CONFIG_DIR/vocab.txt \ 60 | --bert_config_file=$ALBERT_CONFIG_DIR/albert_config_tiny.json \ 61 | --init_checkpoint=$ALBERT_TINY_DIR/albert_model.ckpt \ 62 | --max_seq_length=128 \ 63 | --train_batch_size=64 \ 64 | --learning_rate=1e-4 \ 65 | --num_train_epochs=5.0 \ 66 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output/ 67 | -------------------------------------------------------------------------------- /run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import modeling 23 | import optimization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | ## Required parameters 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model. " 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string( 37 | "input_file", None, 38 | "Input TF example files (can be a glob or comma separated).") 39 | 40 | flags.DEFINE_string( 41 | "output_dir", None, 42 | "The output directory where the model checkpoints will be written.") 43 | 44 | ## Other parameters 45 | flags.DEFINE_string( 46 | "init_checkpoint", None, 47 | "Initial checkpoint (usually from a pre-trained BERT model).") 48 | 49 | flags.DEFINE_integer( 50 | "max_seq_length", 128, 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded. Must match data generation.") 54 | 55 | flags.DEFINE_integer( 56 | "max_predictions_per_seq", 20, 57 | "Maximum number of masked LM predictions per sequence. " 58 | "Must match data generation.") 59 | 60 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 61 | 62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 63 | 64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 65 | 66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 67 | 68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 69 | 70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 71 | 72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 73 | 74 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 75 | "How often to save the model checkpoint.") 76 | 77 | flags.DEFINE_integer("iterations_per_loop", 1000, 78 | "How many steps to make in each estimator call.") 79 | 80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 81 | 82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 83 | 84 | tf.flags.DEFINE_string( 85 | "tpu_name", None, 86 | "The Cloud TPU to use for training. This should be either the name " 87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 88 | "url.") 89 | 90 | tf.flags.DEFINE_string( 91 | "tpu_zone", None, 92 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 93 | "specified, we will attempt to automatically detect the GCE project from " 94 | "metadata.") 95 | 96 | tf.flags.DEFINE_string( 97 | "gcp_project", None, 98 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 99 | "specified, we will attempt to automatically detect the GCE project from " 100 | "metadata.") 101 | 102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 103 | 104 | flags.DEFINE_integer( 105 | "num_tpu_cores", 8, 106 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 107 | 108 | 109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 110 | num_train_steps, num_warmup_steps, use_tpu, 111 | use_one_hot_embeddings): 112 | """Returns `model_fn` closure for TPUEstimator.""" 113 | 114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 115 | """The `model_fn` for TPUEstimator.""" 116 | 117 | tf.logging.info("*** Features ***") 118 | for name in sorted(features.keys()): 119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 120 | 121 | input_ids = features["input_ids"] 122 | input_mask = features["input_mask"] 123 | segment_ids = features["segment_ids"] 124 | masked_lm_positions = features["masked_lm_positions"] 125 | masked_lm_ids = features["masked_lm_ids"] 126 | masked_lm_weights = features["masked_lm_weights"] 127 | next_sentence_labels = features["next_sentence_labels"] 128 | 129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 130 | 131 | model = modeling.BertModel( 132 | config=bert_config, 133 | is_training=is_training, 134 | input_ids=input_ids, 135 | input_mask=input_mask, 136 | token_type_ids=segment_ids, 137 | use_one_hot_embeddings=use_one_hot_embeddings) 138 | 139 | (masked_lm_loss, 140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 141 | bert_config, model.get_sequence_output(), model.get_embedding_table(),model.get_embedding_table_2(), 142 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 143 | 144 | (next_sentence_loss, next_sentence_example_loss, 145 | next_sentence_log_probs) = get_next_sentence_output( 146 | bert_config, model.get_pooled_output(), next_sentence_labels) 147 | 148 | total_loss = masked_lm_loss + next_sentence_loss 149 | 150 | tvars = tf.trainable_variables() 151 | 152 | initialized_variable_names = {} 153 | print("init_checkpoint:",init_checkpoint) 154 | scaffold_fn = None 155 | if init_checkpoint: 156 | (assignment_map, initialized_variable_names 157 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 158 | if use_tpu: 159 | 160 | def tpu_scaffold(): 161 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 162 | return tf.train.Scaffold() 163 | 164 | scaffold_fn = tpu_scaffold 165 | else: 166 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 167 | 168 | tf.logging.info("**** Trainable Variables ****") 169 | for var in tvars: 170 | init_string = "" 171 | if var.name in initialized_variable_names: 172 | init_string = ", *INIT_FROM_CKPT*" 173 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 174 | init_string) 175 | 176 | output_spec = None 177 | if mode == tf.estimator.ModeKeys.TRAIN: 178 | train_op = optimization.create_optimizer( 179 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 180 | 181 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 182 | mode=mode, 183 | loss=total_loss, 184 | train_op=train_op, 185 | scaffold_fn=scaffold_fn) 186 | elif mode == tf.estimator.ModeKeys.EVAL: 187 | 188 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 189 | masked_lm_weights, next_sentence_example_loss, 190 | next_sentence_log_probs, next_sentence_labels): 191 | """Computes the loss and accuracy of the model.""" 192 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs,[-1, masked_lm_log_probs.shape[-1]]) 193 | masked_lm_predictions = tf.argmax(masked_lm_log_probs, axis=-1, output_type=tf.int32) 194 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 195 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 196 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) 197 | masked_lm_accuracy = tf.metrics.accuracy( 198 | labels=masked_lm_ids, 199 | predictions=masked_lm_predictions, 200 | weights=masked_lm_weights) 201 | masked_lm_mean_loss = tf.metrics.mean( 202 | values=masked_lm_example_loss, weights=masked_lm_weights) 203 | 204 | next_sentence_log_probs = tf.reshape( 205 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 206 | next_sentence_predictions = tf.argmax( 207 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 208 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 209 | next_sentence_accuracy = tf.metrics.accuracy( 210 | labels=next_sentence_labels, predictions=next_sentence_predictions) 211 | next_sentence_mean_loss = tf.metrics.mean( 212 | values=next_sentence_example_loss) 213 | 214 | return { 215 | "masked_lm_accuracy": masked_lm_accuracy, 216 | "masked_lm_loss": masked_lm_mean_loss, 217 | "next_sentence_accuracy": next_sentence_accuracy, 218 | "next_sentence_loss": next_sentence_mean_loss, 219 | } 220 | 221 | # next_sentence_example_loss=0.0 TODO 222 | # next_sentence_log_probs=0.0 # TODO 223 | eval_metrics = (metric_fn, [ 224 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 225 | masked_lm_weights, next_sentence_example_loss, 226 | next_sentence_log_probs, next_sentence_labels 227 | ]) 228 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 229 | mode=mode, 230 | loss=total_loss, 231 | eval_metrics=eval_metrics, 232 | scaffold_fn=scaffold_fn) 233 | else: 234 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 235 | 236 | return output_spec 237 | 238 | return model_fn 239 | 240 | 241 | def get_masked_lm_output(bert_config, input_tensor, output_weights,project_weights, positions, 242 | label_ids, label_weights): 243 | """Get loss and log probs for the masked LM.""" 244 | input_tensor = gather_indexes(input_tensor, positions) 245 | 246 | with tf.variable_scope("cls/predictions"): 247 | # We apply one more non-linear transformation before the output layer. 248 | # This matrix is not used after pre-training. 249 | with tf.variable_scope("transform"): 250 | input_tensor = tf.layers.dense( 251 | input_tensor, 252 | units=bert_config.hidden_size, 253 | activation=modeling.get_activation(bert_config.hidden_act), 254 | kernel_initializer=modeling.create_initializer( 255 | bert_config.initializer_range)) 256 | input_tensor = modeling.layer_norm(input_tensor) 257 | 258 | # The output weights are the same as the input embeddings, but there is 259 | # an output-only bias for each token. 260 | output_bias = tf.get_variable( 261 | "output_bias", 262 | shape=[bert_config.vocab_size], 263 | initializer=tf.zeros_initializer()) 264 | # logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 265 | # input_tensor=[-1,hidden_size], project_weights=[embedding_size, hidden_size], project_weights_transpose=[hidden_size, embedding_size]--->[-1, embedding_size] 266 | input_project = tf.matmul(input_tensor, project_weights, transpose_b=True) 267 | logits = tf.matmul(input_project, output_weights, transpose_b=True) 268 | # # input_project=[-1, embedding_size], output_weights=[vocab_size, embedding_size], output_weights_transpose=[embedding_size, vocab_size] ---> [-1, vocab_size] 269 | 270 | logits = tf.nn.bias_add(logits, output_bias) 271 | log_probs = tf.nn.log_softmax(logits, axis=-1) 272 | 273 | label_ids = tf.reshape(label_ids, [-1]) 274 | label_weights = tf.reshape(label_weights, [-1]) 275 | 276 | one_hot_labels = tf.one_hot(label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 277 | 278 | # The `positions` tensor might be zero-padded (if the sequence is too 279 | # short to have the maximum number of predictions). The `label_weights` 280 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 281 | # padding predictions. 282 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 283 | numerator = tf.reduce_sum(label_weights * per_example_loss) 284 | denominator = tf.reduce_sum(label_weights) + 1e-5 285 | loss = numerator / denominator 286 | 287 | return (loss, per_example_loss, log_probs) 288 | 289 | 290 | def get_next_sentence_output(bert_config, input_tensor, labels): 291 | """Get loss and log probs for the next sentence prediction.""" 292 | 293 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 294 | # "random sentence". This weight matrix is not used after pre-training. 295 | with tf.variable_scope("cls/seq_relationship"): 296 | output_weights = tf.get_variable( 297 | "output_weights", 298 | shape=[2, bert_config.hidden_size], 299 | initializer=modeling.create_initializer(bert_config.initializer_range)) 300 | output_bias = tf.get_variable( 301 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 302 | 303 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 304 | logits = tf.nn.bias_add(logits, output_bias) 305 | log_probs = tf.nn.log_softmax(logits, axis=-1) 306 | labels = tf.reshape(labels, [-1]) 307 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 308 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 309 | loss = tf.reduce_mean(per_example_loss) 310 | return (loss, per_example_loss, log_probs) 311 | 312 | 313 | def gather_indexes(sequence_tensor, positions): 314 | """Gathers the vectors at the specific positions over a minibatch.""" 315 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 316 | batch_size = sequence_shape[0] 317 | seq_length = sequence_shape[1] 318 | width = sequence_shape[2] 319 | 320 | flat_offsets = tf.reshape( 321 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 322 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 323 | flat_sequence_tensor = tf.reshape(sequence_tensor, 324 | [batch_size * seq_length, width]) 325 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 326 | return output_tensor 327 | 328 | 329 | def input_fn_builder(input_files, 330 | max_seq_length, 331 | max_predictions_per_seq, 332 | is_training, 333 | num_cpu_threads=4): 334 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 335 | 336 | def input_fn(params): 337 | """The actual input function.""" 338 | batch_size = params["batch_size"] 339 | 340 | name_to_features = { 341 | "input_ids": 342 | tf.FixedLenFeature([max_seq_length], tf.int64), 343 | "input_mask": 344 | tf.FixedLenFeature([max_seq_length], tf.int64), 345 | "segment_ids": 346 | tf.FixedLenFeature([max_seq_length], tf.int64), 347 | "masked_lm_positions": 348 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 349 | "masked_lm_ids": 350 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 351 | "masked_lm_weights": 352 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 353 | "next_sentence_labels": 354 | tf.FixedLenFeature([1], tf.int64), 355 | } 356 | 357 | # For training, we want a lot of parallel reading and shuffling. 358 | # For eval, we want no shuffling and parallel reading doesn't matter. 359 | if is_training: 360 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 361 | d = d.repeat() 362 | d = d.shuffle(buffer_size=len(input_files)) 363 | 364 | # `cycle_length` is the number of parallel files that get read. 365 | cycle_length = min(num_cpu_threads, len(input_files)) 366 | 367 | # `sloppy` mode means that the interleaving is not exact. This adds 368 | # even more randomness to the training pipeline. 369 | d = d.apply( 370 | tf.contrib.data.parallel_interleave( 371 | tf.data.TFRecordDataset, 372 | sloppy=is_training, 373 | cycle_length=cycle_length)) 374 | d = d.shuffle(buffer_size=100) 375 | else: 376 | d = tf.data.TFRecordDataset(input_files) 377 | # Since we evaluate for a fixed number of steps we don't want to encounter 378 | # out-of-range exceptions. 379 | d = d.repeat() 380 | 381 | # We must `drop_remainder` on training because the TPU requires fixed 382 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 383 | # and we *don't* want to drop the remainder, otherwise we wont cover 384 | # every sample. 385 | d = d.apply( 386 | tf.contrib.data.map_and_batch( 387 | lambda record: _decode_record(record, name_to_features), 388 | batch_size=batch_size, 389 | num_parallel_batches=num_cpu_threads, 390 | drop_remainder=True)) 391 | return d 392 | 393 | return input_fn 394 | 395 | 396 | def _decode_record(record, name_to_features): 397 | """Decodes a record to a TensorFlow example.""" 398 | example = tf.parse_single_example(record, name_to_features) 399 | 400 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 401 | # So cast all int64 to int32. 402 | for name in list(example.keys()): 403 | t = example[name] 404 | if t.dtype == tf.int64: 405 | t = tf.to_int32(t) 406 | example[name] = t 407 | 408 | return example 409 | 410 | 411 | def main(_): 412 | tf.logging.set_verbosity(tf.logging.INFO) 413 | 414 | if not FLAGS.do_train and not FLAGS.do_eval: # 必须是训练或验证的类型 415 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 416 | 417 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) # 从json文件中获得配置信息 418 | 419 | tf.gfile.MakeDirs(FLAGS.output_dir) 420 | 421 | input_files = [] # 输入可以是多个文件,以“逗号隔开”;可以是一个匹配形式的,如“input_x*” 422 | for input_pattern in FLAGS.input_file.split(","): 423 | input_files.extend(tf.gfile.Glob(input_pattern)) 424 | 425 | tf.logging.info("*** Input Files ***") 426 | for input_file in input_files: 427 | tf.logging.info(" %s" % input_file) 428 | 429 | tpu_cluster_resolver = None 430 | if FLAGS.use_tpu and FLAGS.tpu_name: 431 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( # TODO 432 | tpu=FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 433 | 434 | print("###tpu_cluster_resolver:",tpu_cluster_resolver,";FLAGS.use_tpu:",FLAGS.use_tpu,";FLAGS.tpu_name:",FLAGS.tpu_name,";FLAGS.tpu_zone:",FLAGS.tpu_zone) 435 | # ###tpu_cluster_resolver: ;FLAGS.use_tpu: True ;FLAGS.tpu_name: grpc://10.240.1.83:8470 436 | 437 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 438 | run_config = tf.contrib.tpu.RunConfig( 439 | keep_checkpoint_max=20, # 10 440 | cluster=tpu_cluster_resolver, 441 | master=FLAGS.master, 442 | model_dir=FLAGS.output_dir, 443 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 444 | tpu_config=tf.contrib.tpu.TPUConfig( 445 | iterations_per_loop=FLAGS.iterations_per_loop, 446 | num_shards=FLAGS.num_tpu_cores, 447 | per_host_input_for_training=is_per_host)) 448 | 449 | model_fn = model_fn_builder( 450 | bert_config=bert_config, 451 | init_checkpoint=FLAGS.init_checkpoint, 452 | learning_rate=FLAGS.learning_rate, 453 | num_train_steps=FLAGS.num_train_steps, 454 | num_warmup_steps=FLAGS.num_warmup_steps, 455 | use_tpu=FLAGS.use_tpu, 456 | use_one_hot_embeddings=FLAGS.use_tpu) 457 | 458 | # If TPU is not available, this will fall back to normal Estimator on CPU 459 | # or GPU. 460 | estimator = tf.contrib.tpu.TPUEstimator( 461 | use_tpu=FLAGS.use_tpu, 462 | model_fn=model_fn, 463 | config=run_config, 464 | train_batch_size=FLAGS.train_batch_size, 465 | eval_batch_size=FLAGS.eval_batch_size) 466 | 467 | if FLAGS.do_train: 468 | tf.logging.info("***** Running training *****") 469 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 470 | train_input_fn = input_fn_builder( 471 | input_files=input_files, 472 | max_seq_length=FLAGS.max_seq_length, 473 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 474 | is_training=True) 475 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 476 | 477 | if FLAGS.do_eval: 478 | tf.logging.info("***** Running evaluation *****") 479 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 480 | 481 | eval_input_fn = input_fn_builder( 482 | input_files=input_files, 483 | max_seq_length=FLAGS.max_seq_length, 484 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 485 | is_training=False) 486 | 487 | result = estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 488 | 489 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 490 | with tf.gfile.GFile(output_eval_file, "w") as writer: 491 | tf.logging.info("***** Eval results *****") 492 | for key in sorted(result.keys()): 493 | tf.logging.info(" %s = %s", key, str(result[key])) 494 | writer.write("%s = %s\n" % (key, str(result[key]))) 495 | 496 | 497 | if __name__ == "__main__": 498 | flags.mark_flag_as_required("input_file") 499 | flags.mark_flag_as_required("bert_config_file") 500 | flags.mark_flag_as_required("output_dir") 501 | tf.app.run() 502 | -------------------------------------------------------------------------------- /run_pretraining_google.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Run masked LM/next sentence masked_lm pre-training for ALBERT.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import os 24 | import time 25 | 26 | from six.moves import range 27 | import tensorflow as tf 28 | 29 | import modeling_google as modeling 30 | import optimization_google as optimization 31 | 32 | flags = tf.flags 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | ## Required parameters 37 | flags.DEFINE_string( 38 | "albert_config_file", None, 39 | "The config json file corresponding to the pre-trained ALBERT model. " 40 | "This specifies the model architecture.") 41 | 42 | flags.DEFINE_string( 43 | "input_file", None, 44 | "Input TF example files (can be a glob or comma separated).") 45 | 46 | flags.DEFINE_string( 47 | "output_dir", None, 48 | "The output directory where the model checkpoints will be written.") 49 | 50 | flags.DEFINE_string( 51 | "export_dir", None, 52 | "The output directory where the saved models will be written.") 53 | ## Other parameters 54 | flags.DEFINE_string( 55 | "init_checkpoint", None, 56 | "Initial checkpoint (usually from a pre-trained ALBERT model).") 57 | 58 | flags.DEFINE_integer( 59 | "max_seq_length", 512, 60 | "The maximum total input sequence length after WordPiece tokenization. " 61 | "Sequences longer than this will be truncated, and sequences shorter " 62 | "than this will be padded. Must match data generation.") 63 | 64 | flags.DEFINE_integer( 65 | "max_predictions_per_seq", 20, 66 | "Maximum number of masked LM predictions per sequence. " 67 | "Must match data generation.") 68 | 69 | flags.DEFINE_bool("do_train", True, "Whether to run training.") 70 | 71 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 72 | 73 | flags.DEFINE_integer("train_batch_size", 4096, "Total batch size for training.") 74 | 75 | flags.DEFINE_integer("eval_batch_size", 64, "Total batch size for eval.") 76 | 77 | flags.DEFINE_enum("optimizer", "lamb", ["adamw", "lamb"], 78 | "The optimizer for training.") 79 | 80 | flags.DEFINE_float("learning_rate", 0.00176, "The initial learning rate.") 81 | 82 | flags.DEFINE_float("poly_power", 1.0, "The power of poly decay.") 83 | 84 | flags.DEFINE_integer("num_train_steps", 125000, "Number of training steps.") 85 | 86 | flags.DEFINE_integer("num_warmup_steps", 3125, "Number of warmup steps.") 87 | 88 | flags.DEFINE_integer("start_warmup_step", 0, "The starting step of warmup.") 89 | 90 | flags.DEFINE_integer("save_checkpoints_steps", 5000, 91 | "How often to save the model checkpoint.") 92 | 93 | flags.DEFINE_integer("iterations_per_loop", 1000, 94 | "How many steps to make in each estimator call.") 95 | 96 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 97 | 98 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 99 | 100 | flags.DEFINE_bool("init_from_group0", False, "Whether to initialize" 101 | "parameters of other groups from group 0") 102 | 103 | tf.flags.DEFINE_string( 104 | "tpu_name", None, 105 | "The Cloud TPU to use for training. This should be either the name " 106 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 107 | "url.") 108 | 109 | tf.flags.DEFINE_string( 110 | "tpu_zone", None, 111 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 112 | "specified, we will attempt to automatically detect the GCE project from " 113 | "metadata.") 114 | 115 | tf.flags.DEFINE_string( 116 | "gcp_project", None, 117 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 118 | "specified, we will attempt to automatically detect the GCE project from " 119 | "metadata.") 120 | 121 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 122 | 123 | flags.DEFINE_integer( 124 | "num_tpu_cores", 8, 125 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 126 | 127 | flags.DEFINE_float( 128 | "masked_lm_budget", 0, 129 | "If >0, the ratio of masked ngrams to unmasked ngrams. Default 0," 130 | "for offline masking") 131 | 132 | 133 | def model_fn_builder(albert_config, init_checkpoint, learning_rate, 134 | num_train_steps, num_warmup_steps, use_tpu, 135 | use_one_hot_embeddings, optimizer, poly_power, 136 | start_warmup_step): 137 | """Returns `model_fn` closure for TPUEstimator.""" 138 | 139 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 140 | """The `model_fn` for TPUEstimator.""" 141 | 142 | tf.logging.info("*** Features ***") 143 | for name in sorted(features.keys()): 144 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 145 | 146 | input_ids = features["input_ids"] 147 | input_mask = features["input_mask"] 148 | segment_ids = features["segment_ids"] 149 | masked_lm_positions = features["masked_lm_positions"] 150 | masked_lm_ids = features["masked_lm_ids"] 151 | masked_lm_weights = features["masked_lm_weights"] 152 | # Note: We keep this feature name `next_sentence_labels` to be compatible 153 | # with the original data created by lanzhzh@. However, in the ALBERT case 154 | # it does represent sentence_order_labels. 155 | sentence_order_labels = features["next_sentence_labels"] 156 | 157 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 158 | 159 | model = modeling.AlbertModel( 160 | config=albert_config, 161 | is_training=is_training, 162 | input_ids=input_ids, 163 | input_mask=input_mask, 164 | token_type_ids=segment_ids, 165 | use_one_hot_embeddings=use_one_hot_embeddings) 166 | 167 | (masked_lm_loss, masked_lm_example_loss, 168 | masked_lm_log_probs) = get_masked_lm_output(albert_config, 169 | model.get_sequence_output(), 170 | model.get_embedding_table(), 171 | masked_lm_positions, 172 | masked_lm_ids, 173 | masked_lm_weights) 174 | 175 | (sentence_order_loss, sentence_order_example_loss, 176 | sentence_order_log_probs) = get_sentence_order_output( 177 | albert_config, model.get_pooled_output(), sentence_order_labels) 178 | 179 | total_loss = masked_lm_loss + sentence_order_loss 180 | 181 | tvars = tf.trainable_variables() 182 | 183 | initialized_variable_names = {} 184 | scaffold_fn = None 185 | if init_checkpoint: 186 | tf.logging.info("number of hidden group %d to initialize", 187 | albert_config.num_hidden_groups) 188 | num_of_initialize_group = 1 189 | if FLAGS.init_from_group0: 190 | num_of_initialize_group = albert_config.num_hidden_groups 191 | if albert_config.net_structure_type > 0: 192 | num_of_initialize_group = albert_config.num_hidden_layers 193 | (assignment_map, initialized_variable_names 194 | ) = modeling.get_assignment_map_from_checkpoint( 195 | tvars, init_checkpoint, num_of_initialize_group) 196 | if use_tpu: 197 | 198 | def tpu_scaffold(): 199 | for gid in range(num_of_initialize_group): 200 | tf.logging.info("initialize the %dth layer", gid) 201 | tf.logging.info(assignment_map[gid]) 202 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map[gid]) 203 | return tf.train.Scaffold() 204 | 205 | scaffold_fn = tpu_scaffold 206 | else: 207 | for gid in range(num_of_initialize_group): 208 | tf.logging.info("initialize the %dth layer", gid) 209 | tf.logging.info(assignment_map[gid]) 210 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map[gid]) 211 | 212 | tf.logging.info("**** Trainable Variables ****") 213 | for var in tvars: 214 | init_string = "" 215 | if var.name in initialized_variable_names: 216 | init_string = ", *INIT_FROM_CKPT*" 217 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 218 | init_string) 219 | 220 | output_spec = None 221 | if mode == tf.estimator.ModeKeys.TRAIN: 222 | train_op = optimization.create_optimizer( 223 | total_loss, learning_rate, num_train_steps, num_warmup_steps, 224 | use_tpu, optimizer, poly_power, start_warmup_step) 225 | 226 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 227 | mode=mode, 228 | loss=total_loss, 229 | train_op=train_op, 230 | scaffold_fn=scaffold_fn) 231 | elif mode == tf.estimator.ModeKeys.EVAL: 232 | 233 | def metric_fn(*args): 234 | """Computes the loss and accuracy of the model.""" 235 | (masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 236 | masked_lm_weights, sentence_order_example_loss, 237 | sentence_order_log_probs, sentence_order_labels) = args[:7] 238 | 239 | 240 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs, 241 | [-1, masked_lm_log_probs.shape[-1]]) 242 | masked_lm_predictions = tf.argmax( 243 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 244 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 245 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 246 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) 247 | masked_lm_accuracy = tf.metrics.accuracy( 248 | labels=masked_lm_ids, 249 | predictions=masked_lm_predictions, 250 | weights=masked_lm_weights) 251 | masked_lm_mean_loss = tf.metrics.mean( 252 | values=masked_lm_example_loss, weights=masked_lm_weights) 253 | 254 | metrics = { 255 | "masked_lm_accuracy": masked_lm_accuracy, 256 | "masked_lm_loss": masked_lm_mean_loss, 257 | } 258 | 259 | sentence_order_log_probs = tf.reshape( 260 | sentence_order_log_probs, [-1, sentence_order_log_probs.shape[-1]]) 261 | sentence_order_predictions = tf.argmax( 262 | sentence_order_log_probs, axis=-1, output_type=tf.int32) 263 | sentence_order_labels = tf.reshape(sentence_order_labels, [-1]) 264 | sentence_order_accuracy = tf.metrics.accuracy( 265 | labels=sentence_order_labels, 266 | predictions=sentence_order_predictions) 267 | sentence_order_mean_loss = tf.metrics.mean( 268 | values=sentence_order_example_loss) 269 | metrics.update({ 270 | "sentence_order_accuracy": sentence_order_accuracy, 271 | "sentence_order_loss": sentence_order_mean_loss 272 | }) 273 | return metrics 274 | 275 | metric_values = [ 276 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 277 | masked_lm_weights, sentence_order_example_loss, 278 | sentence_order_log_probs, sentence_order_labels 279 | ] 280 | 281 | eval_metrics = (metric_fn, metric_values) 282 | 283 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 284 | mode=mode, 285 | loss=total_loss, 286 | eval_metrics=eval_metrics, 287 | scaffold_fn=scaffold_fn) 288 | else: 289 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 290 | 291 | return output_spec 292 | 293 | return model_fn 294 | 295 | 296 | def get_masked_lm_output(albert_config, input_tensor, output_weights, positions, 297 | label_ids, label_weights): 298 | """Get loss and log probs for the masked LM.""" 299 | input_tensor = gather_indexes(input_tensor, positions) 300 | 301 | 302 | with tf.variable_scope("cls/predictions"): 303 | # We apply one more non-linear transformation before the output layer. 304 | # This matrix is not used after pre-training. 305 | with tf.variable_scope("transform"): 306 | input_tensor = tf.layers.dense( 307 | input_tensor, 308 | units=albert_config.embedding_size, 309 | activation=modeling.get_activation(albert_config.hidden_act), 310 | kernel_initializer=modeling.create_initializer( 311 | albert_config.initializer_range)) 312 | input_tensor = modeling.layer_norm(input_tensor) 313 | 314 | # The output weights are the same as the input embeddings, but there is 315 | # an output-only bias for each token. 316 | output_bias = tf.get_variable( 317 | "output_bias", 318 | shape=[albert_config.vocab_size], 319 | initializer=tf.zeros_initializer()) 320 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 321 | logits = tf.nn.bias_add(logits, output_bias) 322 | log_probs = tf.nn.log_softmax(logits, axis=-1) 323 | 324 | label_ids = tf.reshape(label_ids, [-1]) 325 | label_weights = tf.reshape(label_weights, [-1]) 326 | 327 | one_hot_labels = tf.one_hot( 328 | label_ids, depth=albert_config.vocab_size, dtype=tf.float32) 329 | 330 | # The `positions` tensor might be zero-padded (if the sequence is too 331 | # short to have the maximum number of predictions). The `label_weights` 332 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 333 | # padding predictions. 334 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 335 | numerator = tf.reduce_sum(label_weights * per_example_loss) 336 | denominator = tf.reduce_sum(label_weights) + 1e-5 337 | loss = numerator / denominator 338 | 339 | return (loss, per_example_loss, log_probs) 340 | 341 | 342 | def get_sentence_order_output(albert_config, input_tensor, labels): 343 | """Get loss and log probs for the next sentence prediction.""" 344 | 345 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 346 | # "random sentence". This weight matrix is not used after pre-training. 347 | with tf.variable_scope("cls/seq_relationship"): 348 | output_weights = tf.get_variable( 349 | "output_weights", 350 | shape=[2, albert_config.hidden_size], 351 | initializer=modeling.create_initializer( 352 | albert_config.initializer_range)) 353 | output_bias = tf.get_variable( 354 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 355 | 356 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 357 | logits = tf.nn.bias_add(logits, output_bias) 358 | log_probs = tf.nn.log_softmax(logits, axis=-1) 359 | labels = tf.reshape(labels, [-1]) 360 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 361 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 362 | loss = tf.reduce_mean(per_example_loss) 363 | return (loss, per_example_loss, log_probs) 364 | 365 | 366 | def gather_indexes(sequence_tensor, positions): 367 | """Gathers the vectors at the specific positions over a minibatch.""" 368 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 369 | batch_size = sequence_shape[0] 370 | seq_length = sequence_shape[1] 371 | width = sequence_shape[2] 372 | 373 | flat_offsets = tf.reshape( 374 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 375 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 376 | flat_sequence_tensor = tf.reshape(sequence_tensor, 377 | [batch_size * seq_length, width]) 378 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 379 | return output_tensor 380 | 381 | 382 | def input_fn_builder(input_files, 383 | max_seq_length, 384 | max_predictions_per_seq, 385 | is_training, 386 | num_cpu_threads=4): 387 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 388 | 389 | def input_fn(params): 390 | """The actual input function.""" 391 | batch_size = params["batch_size"] 392 | 393 | name_to_features = { 394 | "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 395 | "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64), 396 | "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 397 | # Note: We keep this feature name `next_sentence_labels` to be 398 | # compatible with the original data created by lanzhzh@. However, in 399 | # the ALBERT case it does represent sentence_order_labels. 400 | "next_sentence_labels": tf.FixedLenFeature([1], tf.int64), 401 | } 402 | 403 | if FLAGS.masked_lm_budget: 404 | name_to_features.update({ 405 | "token_boundary": 406 | tf.FixedLenFeature([max_seq_length], tf.int64)}) 407 | else: 408 | name_to_features.update({ 409 | "masked_lm_positions": 410 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 411 | "masked_lm_ids": 412 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 413 | "masked_lm_weights": 414 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32)}) 415 | 416 | # For training, we want a lot of parallel reading and shuffling. 417 | # For eval, we want no shuffling and parallel reading doesn't matter. 418 | if is_training: 419 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 420 | d = d.repeat() 421 | d = d.shuffle(buffer_size=len(input_files)) 422 | 423 | # `cycle_length` is the number of parallel files that get read. 424 | cycle_length = min(num_cpu_threads, len(input_files)) 425 | 426 | # `sloppy` mode means that the interleaving is not exact. This adds 427 | # even more randomness to the training pipeline. 428 | d = d.apply( 429 | tf.contrib.data.parallel_interleave( 430 | tf.data.TFRecordDataset, 431 | sloppy=is_training, 432 | cycle_length=cycle_length)) 433 | d = d.shuffle(buffer_size=100) 434 | else: 435 | d = tf.data.TFRecordDataset(input_files) 436 | # Since we evaluate for a fixed number of steps we don't want to encounter 437 | # out-of-range exceptions. 438 | d = d.repeat() 439 | 440 | # We must `drop_remainder` on training because the TPU requires fixed 441 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 442 | # and we *don't* want to drop the remainder, otherwise we wont cover 443 | # every sample. 444 | d = d.apply( 445 | tf.data.experimental.map_and_batch_with_legacy_function( 446 | lambda record: _decode_record(record, name_to_features), 447 | batch_size=batch_size, 448 | num_parallel_batches=num_cpu_threads, 449 | drop_remainder=True)) 450 | tf.logging.info(d) 451 | return d 452 | 453 | return input_fn 454 | 455 | 456 | def _decode_record(record, name_to_features): 457 | """Decodes a record to a TensorFlow example.""" 458 | example = tf.parse_single_example(record, name_to_features) 459 | 460 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 461 | # So cast all int64 to int32. 462 | for name in list(example.keys()): 463 | t = example[name] 464 | if t.dtype == tf.int64: 465 | t = tf.to_int32(t) 466 | example[name] = t 467 | 468 | return example 469 | 470 | 471 | def main(_): 472 | tf.logging.set_verbosity(tf.logging.INFO) 473 | 474 | if not FLAGS.do_train and not FLAGS.do_eval: 475 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 476 | 477 | albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file) 478 | 479 | tf.gfile.MakeDirs(FLAGS.output_dir) 480 | 481 | input_files = [] 482 | for input_pattern in FLAGS.input_file.split(","): 483 | input_files.extend(tf.gfile.Glob(input_pattern)) 484 | 485 | tf.logging.info("*** Input Files ***") 486 | for input_file in input_files: 487 | tf.logging.info(" %s" % input_file) 488 | 489 | tpu_cluster_resolver = None 490 | if FLAGS.use_tpu and FLAGS.tpu_name: 491 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 492 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 493 | 494 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 495 | run_config = tf.contrib.tpu.RunConfig( 496 | cluster=tpu_cluster_resolver, 497 | master=FLAGS.master, 498 | model_dir=FLAGS.output_dir, 499 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 500 | tpu_config=tf.contrib.tpu.TPUConfig( 501 | iterations_per_loop=FLAGS.iterations_per_loop, 502 | num_shards=FLAGS.num_tpu_cores, 503 | per_host_input_for_training=is_per_host)) 504 | 505 | model_fn = model_fn_builder( 506 | albert_config=albert_config, 507 | init_checkpoint=FLAGS.init_checkpoint, 508 | learning_rate=FLAGS.learning_rate, 509 | num_train_steps=FLAGS.num_train_steps, 510 | num_warmup_steps=FLAGS.num_warmup_steps, 511 | use_tpu=FLAGS.use_tpu, 512 | use_one_hot_embeddings=FLAGS.use_tpu, 513 | optimizer=FLAGS.optimizer, 514 | poly_power=FLAGS.poly_power, 515 | start_warmup_step=FLAGS.start_warmup_step) 516 | 517 | # If TPU is not available, this will fall back to normal Estimator on CPU 518 | # or GPU. 519 | estimator = tf.contrib.tpu.TPUEstimator( 520 | use_tpu=FLAGS.use_tpu, 521 | model_fn=model_fn, 522 | config=run_config, 523 | train_batch_size=FLAGS.train_batch_size, 524 | eval_batch_size=FLAGS.eval_batch_size) 525 | 526 | if FLAGS.do_train: 527 | tf.logging.info("***** Running training *****") 528 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 529 | train_input_fn = input_fn_builder( 530 | input_files=input_files, 531 | max_seq_length=FLAGS.max_seq_length, 532 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 533 | is_training=True) 534 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 535 | 536 | if FLAGS.do_eval: 537 | tf.logging.info("***** Running evaluation *****") 538 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 539 | global_step = -1 540 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 541 | writer = tf.gfile.GFile(output_eval_file, "w") 542 | tf.gfile.MakeDirs(FLAGS.export_dir) 543 | eval_input_fn = input_fn_builder( 544 | input_files=input_files, 545 | max_seq_length=FLAGS.max_seq_length, 546 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 547 | is_training=False) 548 | while global_step < FLAGS.num_train_steps: 549 | if estimator.latest_checkpoint() is None: 550 | tf.logging.info("No checkpoint found yet. Sleeping.") 551 | time.sleep(1) 552 | else: 553 | result = estimator.evaluate( 554 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 555 | global_step = result["global_step"] 556 | tf.logging.info("***** Eval results *****") 557 | for key in sorted(result.keys()): 558 | tf.logging.info(" %s = %s", key, str(result[key])) 559 | writer.write("%s = %s\n" % (key, str(result[key]))) 560 | 561 | if __name__ == "__main__": 562 | flags.mark_flag_as_required("input_file") 563 | flags.mark_flag_as_required("albert_config_file") 564 | flags.mark_flag_as_required("output_dir") 565 | tf.app.run() -------------------------------------------------------------------------------- /run_pretraining_google_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Run masked LM/next sentence masked_lm pre-training for ALBERT.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import os 24 | import time 25 | 26 | from six.moves import range 27 | import tensorflow as tf 28 | 29 | import modeling_google_fast as modeling 30 | import optimization_google as optimization 31 | 32 | flags = tf.flags 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | ## Required parameters 37 | flags.DEFINE_string( 38 | "albert_config_file", None, 39 | "The config json file corresponding to the pre-trained ALBERT model. " 40 | "This specifies the model architecture.") 41 | 42 | flags.DEFINE_string( 43 | "input_file", None, 44 | "Input TF example files (can be a glob or comma separated).") 45 | 46 | flags.DEFINE_string( 47 | "output_dir", None, 48 | "The output directory where the model checkpoints will be written.") 49 | 50 | flags.DEFINE_string( 51 | "export_dir", None, 52 | "The output directory where the saved models will be written.") 53 | ## Other parameters 54 | flags.DEFINE_string( 55 | "init_checkpoint", None, 56 | "Initial checkpoint (usually from a pre-trained ALBERT model).") 57 | 58 | flags.DEFINE_integer( 59 | "max_seq_length", 512, 60 | "The maximum total input sequence length after WordPiece tokenization. " 61 | "Sequences longer than this will be truncated, and sequences shorter " 62 | "than this will be padded. Must match data generation.") 63 | 64 | flags.DEFINE_integer( 65 | "max_predictions_per_seq", 20, 66 | "Maximum number of masked LM predictions per sequence. " 67 | "Must match data generation.") 68 | 69 | flags.DEFINE_bool("do_train", True, "Whether to run training.") 70 | 71 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 72 | 73 | flags.DEFINE_integer("train_batch_size", 4096, "Total batch size for training.") 74 | 75 | flags.DEFINE_integer("eval_batch_size", 64, "Total batch size for eval.") 76 | 77 | flags.DEFINE_enum("optimizer", "lamb", ["adamw", "lamb"], 78 | "The optimizer for training.") 79 | 80 | flags.DEFINE_float("learning_rate", 0.00176, "The initial learning rate.") 81 | 82 | flags.DEFINE_float("poly_power", 1.0, "The power of poly decay.") 83 | 84 | flags.DEFINE_integer("num_train_steps", 125000, "Number of training steps.") 85 | 86 | flags.DEFINE_integer("num_warmup_steps", 3125, "Number of warmup steps.") 87 | 88 | flags.DEFINE_integer("start_warmup_step", 0, "The starting step of warmup.") 89 | 90 | flags.DEFINE_integer("save_checkpoints_steps", 5000, 91 | "How often to save the model checkpoint.") 92 | 93 | flags.DEFINE_integer("iterations_per_loop", 1000, 94 | "How many steps to make in each estimator call.") 95 | 96 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 97 | 98 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 99 | 100 | flags.DEFINE_bool("init_from_group0", False, "Whether to initialize" 101 | "parameters of other groups from group 0") 102 | 103 | tf.flags.DEFINE_string( 104 | "tpu_name", None, 105 | "The Cloud TPU to use for training. This should be either the name " 106 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 107 | "url.") 108 | 109 | tf.flags.DEFINE_string( 110 | "tpu_zone", None, 111 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 112 | "specified, we will attempt to automatically detect the GCE project from " 113 | "metadata.") 114 | 115 | tf.flags.DEFINE_string( 116 | "gcp_project", None, 117 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 118 | "specified, we will attempt to automatically detect the GCE project from " 119 | "metadata.") 120 | 121 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 122 | 123 | flags.DEFINE_integer( 124 | "num_tpu_cores", 8, 125 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 126 | 127 | flags.DEFINE_float( 128 | "masked_lm_budget", 0, 129 | "If >0, the ratio of masked ngrams to unmasked ngrams. Default 0," 130 | "for offline masking") 131 | 132 | 133 | def model_fn_builder(albert_config, init_checkpoint, learning_rate, 134 | num_train_steps, num_warmup_steps, use_tpu, 135 | use_one_hot_embeddings, optimizer, poly_power, 136 | start_warmup_step): 137 | """Returns `model_fn` closure for TPUEstimator.""" 138 | 139 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 140 | """The `model_fn` for TPUEstimator.""" 141 | 142 | tf.logging.info("*** Features ***") 143 | for name in sorted(features.keys()): 144 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 145 | 146 | input_ids = features["input_ids"] 147 | input_mask = features["input_mask"] 148 | segment_ids = features["segment_ids"] 149 | masked_lm_positions = features["masked_lm_positions"] 150 | masked_lm_ids = features["masked_lm_ids"] 151 | masked_lm_weights = features["masked_lm_weights"] 152 | # Note: We keep this feature name `next_sentence_labels` to be compatible 153 | # with the original data created by lanzhzh@. However, in the ALBERT case 154 | # it does represent sentence_order_labels. 155 | sentence_order_labels = features["next_sentence_labels"] 156 | 157 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 158 | 159 | model = modeling.AlbertModel( 160 | config=albert_config, 161 | is_training=is_training, 162 | input_ids=input_ids, 163 | input_mask=input_mask, 164 | token_type_ids=segment_ids, 165 | use_one_hot_embeddings=use_one_hot_embeddings) 166 | 167 | (masked_lm_loss, masked_lm_example_loss, 168 | masked_lm_log_probs) = get_masked_lm_output(albert_config, 169 | model.get_sequence_output(), 170 | model.get_embedding_table(), 171 | masked_lm_positions, 172 | masked_lm_ids, 173 | masked_lm_weights) 174 | 175 | (sentence_order_loss, sentence_order_example_loss, 176 | sentence_order_log_probs) = get_sentence_order_output( 177 | albert_config, model.get_pooled_output(), sentence_order_labels) 178 | 179 | total_loss = masked_lm_loss + sentence_order_loss 180 | 181 | tvars = tf.trainable_variables() 182 | 183 | initialized_variable_names = {} 184 | scaffold_fn = None 185 | if init_checkpoint: 186 | tf.logging.info("number of hidden group %d to initialize", 187 | albert_config.num_hidden_groups) 188 | num_of_initialize_group = 1 189 | if FLAGS.init_from_group0: 190 | num_of_initialize_group = albert_config.num_hidden_groups 191 | if albert_config.net_structure_type > 0: 192 | num_of_initialize_group = albert_config.num_hidden_layers 193 | (assignment_map, initialized_variable_names 194 | ) = modeling.get_assignment_map_from_checkpoint( 195 | tvars, init_checkpoint, num_of_initialize_group) 196 | if use_tpu: 197 | 198 | def tpu_scaffold(): 199 | for gid in range(num_of_initialize_group): 200 | tf.logging.info("initialize the %dth layer", gid) 201 | tf.logging.info(assignment_map[gid]) 202 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map[gid]) 203 | return tf.train.Scaffold() 204 | 205 | scaffold_fn = tpu_scaffold 206 | else: 207 | for gid in range(num_of_initialize_group): 208 | tf.logging.info("initialize the %dth layer", gid) 209 | tf.logging.info(assignment_map[gid]) 210 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map[gid]) 211 | 212 | tf.logging.info("**** Trainable Variables ****") 213 | for var in tvars: 214 | init_string = "" 215 | if var.name in initialized_variable_names: 216 | init_string = ", *INIT_FROM_CKPT*" 217 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 218 | init_string) 219 | 220 | output_spec = None 221 | if mode == tf.estimator.ModeKeys.TRAIN: 222 | train_op = optimization.create_optimizer( 223 | total_loss, learning_rate, num_train_steps, num_warmup_steps, 224 | use_tpu, optimizer, poly_power, start_warmup_step) 225 | 226 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 227 | mode=mode, 228 | loss=total_loss, 229 | train_op=train_op, 230 | scaffold_fn=scaffold_fn) 231 | elif mode == tf.estimator.ModeKeys.EVAL: 232 | 233 | def metric_fn(*args): 234 | """Computes the loss and accuracy of the model.""" 235 | (masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 236 | masked_lm_weights, sentence_order_example_loss, 237 | sentence_order_log_probs, sentence_order_labels) = args[:7] 238 | 239 | 240 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs, 241 | [-1, masked_lm_log_probs.shape[-1]]) 242 | masked_lm_predictions = tf.argmax( 243 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 244 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 245 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 246 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) 247 | masked_lm_accuracy = tf.metrics.accuracy( 248 | labels=masked_lm_ids, 249 | predictions=masked_lm_predictions, 250 | weights=masked_lm_weights) 251 | masked_lm_mean_loss = tf.metrics.mean( 252 | values=masked_lm_example_loss, weights=masked_lm_weights) 253 | 254 | metrics = { 255 | "masked_lm_accuracy": masked_lm_accuracy, 256 | "masked_lm_loss": masked_lm_mean_loss, 257 | } 258 | 259 | sentence_order_log_probs = tf.reshape( 260 | sentence_order_log_probs, [-1, sentence_order_log_probs.shape[-1]]) 261 | sentence_order_predictions = tf.argmax( 262 | sentence_order_log_probs, axis=-1, output_type=tf.int32) 263 | sentence_order_labels = tf.reshape(sentence_order_labels, [-1]) 264 | sentence_order_accuracy = tf.metrics.accuracy( 265 | labels=sentence_order_labels, 266 | predictions=sentence_order_predictions) 267 | sentence_order_mean_loss = tf.metrics.mean( 268 | values=sentence_order_example_loss) 269 | metrics.update({ 270 | "sentence_order_accuracy": sentence_order_accuracy, 271 | "sentence_order_loss": sentence_order_mean_loss 272 | }) 273 | return metrics 274 | 275 | metric_values = [ 276 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 277 | masked_lm_weights, sentence_order_example_loss, 278 | sentence_order_log_probs, sentence_order_labels 279 | ] 280 | 281 | eval_metrics = (metric_fn, metric_values) 282 | 283 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 284 | mode=mode, 285 | loss=total_loss, 286 | eval_metrics=eval_metrics, 287 | scaffold_fn=scaffold_fn) 288 | else: 289 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 290 | 291 | return output_spec 292 | 293 | return model_fn 294 | 295 | 296 | def get_masked_lm_output(albert_config, input_tensor, output_weights, positions, 297 | label_ids, label_weights): 298 | """Get loss and log probs for the masked LM.""" 299 | input_tensor = gather_indexes(input_tensor, positions) 300 | 301 | 302 | with tf.variable_scope("cls/predictions"): 303 | # We apply one more non-linear transformation before the output layer. 304 | # This matrix is not used after pre-training. 305 | with tf.variable_scope("transform"): 306 | input_tensor = tf.layers.dense( 307 | input_tensor, 308 | units=albert_config.embedding_size, 309 | activation=modeling.get_activation(albert_config.hidden_act), 310 | kernel_initializer=modeling.create_initializer( 311 | albert_config.initializer_range)) 312 | input_tensor = modeling.layer_norm(input_tensor) 313 | 314 | # The output weights are the same as the input embeddings, but there is 315 | # an output-only bias for each token. 316 | output_bias = tf.get_variable( 317 | "output_bias", 318 | shape=[albert_config.vocab_size], 319 | initializer=tf.zeros_initializer()) 320 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 321 | logits = tf.nn.bias_add(logits, output_bias) 322 | log_probs = tf.nn.log_softmax(logits, axis=-1) 323 | 324 | label_ids = tf.reshape(label_ids, [-1]) 325 | label_weights = tf.reshape(label_weights, [-1]) 326 | 327 | one_hot_labels = tf.one_hot( 328 | label_ids, depth=albert_config.vocab_size, dtype=tf.float32) 329 | 330 | # The `positions` tensor might be zero-padded (if the sequence is too 331 | # short to have the maximum number of predictions). The `label_weights` 332 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 333 | # padding predictions. 334 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 335 | numerator = tf.reduce_sum(label_weights * per_example_loss) 336 | denominator = tf.reduce_sum(label_weights) + 1e-5 337 | loss = numerator / denominator 338 | 339 | return (loss, per_example_loss, log_probs) 340 | 341 | 342 | def get_sentence_order_output(albert_config, input_tensor, labels): 343 | """Get loss and log probs for the next sentence prediction.""" 344 | 345 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 346 | # "random sentence". This weight matrix is not used after pre-training. 347 | with tf.variable_scope("cls/seq_relationship"): 348 | output_weights = tf.get_variable( 349 | "output_weights", 350 | shape=[2, albert_config.hidden_size], 351 | initializer=modeling.create_initializer( 352 | albert_config.initializer_range)) 353 | output_bias = tf.get_variable( 354 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 355 | 356 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 357 | logits = tf.nn.bias_add(logits, output_bias) 358 | log_probs = tf.nn.log_softmax(logits, axis=-1) 359 | labels = tf.reshape(labels, [-1]) 360 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 361 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 362 | loss = tf.reduce_mean(per_example_loss) 363 | return (loss, per_example_loss, log_probs) 364 | 365 | 366 | def gather_indexes(sequence_tensor, positions): 367 | """Gathers the vectors at the specific positions over a minibatch.""" 368 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 369 | batch_size = sequence_shape[0] 370 | seq_length = sequence_shape[1] 371 | width = sequence_shape[2] 372 | 373 | flat_offsets = tf.reshape( 374 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 375 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 376 | flat_sequence_tensor = tf.reshape(sequence_tensor, 377 | [batch_size * seq_length, width]) 378 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 379 | return output_tensor 380 | 381 | 382 | def input_fn_builder(input_files, 383 | max_seq_length, 384 | max_predictions_per_seq, 385 | is_training, 386 | num_cpu_threads=4): 387 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 388 | 389 | def input_fn(params): 390 | """The actual input function.""" 391 | batch_size = params["batch_size"] 392 | 393 | name_to_features = { 394 | "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 395 | "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64), 396 | "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64), 397 | # Note: We keep this feature name `next_sentence_labels` to be 398 | # compatible with the original data created by lanzhzh@. However, in 399 | # the ALBERT case it does represent sentence_order_labels. 400 | "next_sentence_labels": tf.FixedLenFeature([1], tf.int64), 401 | } 402 | 403 | if FLAGS.masked_lm_budget: 404 | name_to_features.update({ 405 | "token_boundary": 406 | tf.FixedLenFeature([max_seq_length], tf.int64)}) 407 | else: 408 | name_to_features.update({ 409 | "masked_lm_positions": 410 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 411 | "masked_lm_ids": 412 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 413 | "masked_lm_weights": 414 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32)}) 415 | 416 | # For training, we want a lot of parallel reading and shuffling. 417 | # For eval, we want no shuffling and parallel reading doesn't matter. 418 | if is_training: 419 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 420 | d = d.repeat() 421 | d = d.shuffle(buffer_size=len(input_files)) 422 | 423 | # `cycle_length` is the number of parallel files that get read. 424 | cycle_length = min(num_cpu_threads, len(input_files)) 425 | 426 | # `sloppy` mode means that the interleaving is not exact. This adds 427 | # even more randomness to the training pipeline. 428 | d = d.apply( 429 | tf.contrib.data.parallel_interleave( 430 | tf.data.TFRecordDataset, 431 | sloppy=is_training, 432 | cycle_length=cycle_length)) 433 | d = d.shuffle(buffer_size=100) 434 | else: 435 | d = tf.data.TFRecordDataset(input_files) 436 | # Since we evaluate for a fixed number of steps we don't want to encounter 437 | # out-of-range exceptions. 438 | d = d.repeat() 439 | 440 | # We must `drop_remainder` on training because the TPU requires fixed 441 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 442 | # and we *don't* want to drop the remainder, otherwise we wont cover 443 | # every sample. 444 | d = d.apply( 445 | tf.data.experimental.map_and_batch_with_legacy_function( 446 | lambda record: _decode_record(record, name_to_features), 447 | batch_size=batch_size, 448 | num_parallel_batches=num_cpu_threads, 449 | drop_remainder=True)) 450 | tf.logging.info(d) 451 | return d 452 | 453 | return input_fn 454 | 455 | 456 | def _decode_record(record, name_to_features): 457 | """Decodes a record to a TensorFlow example.""" 458 | example = tf.parse_single_example(record, name_to_features) 459 | 460 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 461 | # So cast all int64 to int32. 462 | for name in list(example.keys()): 463 | t = example[name] 464 | if t.dtype == tf.int64: 465 | t = tf.to_int32(t) 466 | example[name] = t 467 | 468 | return example 469 | 470 | 471 | def main(_): 472 | tf.logging.set_verbosity(tf.logging.INFO) 473 | 474 | if not FLAGS.do_train and not FLAGS.do_eval: 475 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 476 | 477 | albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file) 478 | 479 | tf.gfile.MakeDirs(FLAGS.output_dir) 480 | 481 | input_files = [] 482 | for input_pattern in FLAGS.input_file.split(","): 483 | input_files.extend(tf.gfile.Glob(input_pattern)) 484 | 485 | tf.logging.info("*** Input Files ***") 486 | for input_file in input_files: 487 | tf.logging.info(" %s" % input_file) 488 | 489 | tpu_cluster_resolver = None 490 | if FLAGS.use_tpu and FLAGS.tpu_name: 491 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 492 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 493 | 494 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 495 | run_config = tf.contrib.tpu.RunConfig( 496 | cluster=tpu_cluster_resolver, 497 | master=FLAGS.master, 498 | model_dir=FLAGS.output_dir, 499 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 500 | tpu_config=tf.contrib.tpu.TPUConfig( 501 | iterations_per_loop=FLAGS.iterations_per_loop, 502 | num_shards=FLAGS.num_tpu_cores, 503 | per_host_input_for_training=is_per_host)) 504 | 505 | model_fn = model_fn_builder( 506 | albert_config=albert_config, 507 | init_checkpoint=FLAGS.init_checkpoint, 508 | learning_rate=FLAGS.learning_rate, 509 | num_train_steps=FLAGS.num_train_steps, 510 | num_warmup_steps=FLAGS.num_warmup_steps, 511 | use_tpu=FLAGS.use_tpu, 512 | use_one_hot_embeddings=FLAGS.use_tpu, 513 | optimizer=FLAGS.optimizer, 514 | poly_power=FLAGS.poly_power, 515 | start_warmup_step=FLAGS.start_warmup_step) 516 | 517 | # If TPU is not available, this will fall back to normal Estimator on CPU 518 | # or GPU. 519 | estimator = tf.contrib.tpu.TPUEstimator( 520 | use_tpu=FLAGS.use_tpu, 521 | model_fn=model_fn, 522 | config=run_config, 523 | train_batch_size=FLAGS.train_batch_size, 524 | eval_batch_size=FLAGS.eval_batch_size) 525 | 526 | if FLAGS.do_train: 527 | tf.logging.info("***** Running training *****") 528 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 529 | train_input_fn = input_fn_builder( 530 | input_files=input_files, 531 | max_seq_length=FLAGS.max_seq_length, 532 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 533 | is_training=True) 534 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 535 | 536 | if FLAGS.do_eval: 537 | tf.logging.info("***** Running evaluation *****") 538 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 539 | global_step = -1 540 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 541 | writer = tf.gfile.GFile(output_eval_file, "w") 542 | tf.gfile.MakeDirs(FLAGS.export_dir) 543 | eval_input_fn = input_fn_builder( 544 | input_files=input_files, 545 | max_seq_length=FLAGS.max_seq_length, 546 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 547 | is_training=False) 548 | while global_step < FLAGS.num_train_steps: 549 | if estimator.latest_checkpoint() is None: 550 | tf.logging.info("No checkpoint found yet. Sleeping.") 551 | time.sleep(1) 552 | else: 553 | result = estimator.evaluate( 554 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 555 | global_step = result["global_step"] 556 | tf.logging.info("***** Eval results *****") 557 | for key in sorted(result.keys()): 558 | tf.logging.info(" %s = %s", key, str(result[key])) 559 | writer.write("%s = %s\n" % (key, str(result[key]))) 560 | 561 | if __name__ == "__main__": 562 | flags.mark_flag_as_required("input_file") 563 | flags.mark_flag_as_required("albert_config_file") 564 | flags.mark_flag_as_required("output_dir") 565 | tf.app.run() -------------------------------------------------------------------------------- /similarity.py: -------------------------------------------------------------------------------- 1 | """ 2 | 进行文本相似度预测的示例。可以直接运行进行预测。 3 | 参考了项目:https://github.com/chdd/bert-utils 4 | 5 | """ 6 | 7 | 8 | import tensorflow as tf 9 | import args 10 | import tokenization 11 | import modeling 12 | from run_classifier import InputFeatures, InputExample, DataProcessor, create_model, convert_examples_to_features 13 | 14 | 15 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' 16 | 17 | 18 | class SimProcessor(DataProcessor): 19 | def get_sentence_examples(self, questions): 20 | examples = [] 21 | for index, data in enumerate(questions): 22 | guid = 'test-%d' % index 23 | text_a = tokenization.convert_to_unicode(str(data[0])) 24 | text_b = tokenization.convert_to_unicode(str(data[1])) 25 | label = str(0) 26 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 27 | return examples 28 | 29 | def get_labels(self): 30 | return ['0', '1'] 31 | 32 | 33 | """ 34 | 模型类,负责载入checkpoint初始化模型 35 | """ 36 | class BertSim: 37 | def __init__(self, batch_size=args.batch_size): 38 | self.mode = None 39 | self.max_seq_length = args.max_seq_len 40 | self.tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True) 41 | self.batch_size = batch_size 42 | self.estimator = None 43 | self.processor = SimProcessor() 44 | tf.logging.set_verbosity(tf.logging.INFO) 45 | 46 | 47 | 48 | #载入estimator,构造模型 49 | def start_model(self): 50 | self.estimator = self.get_estimator() 51 | 52 | 53 | def model_fn_builder(self, bert_config, num_labels, init_checkpoint, learning_rate, 54 | num_train_steps, num_warmup_steps, 55 | use_one_hot_embeddings): 56 | """Returns `model_fn` closurimport_tfe for TPUEstimator.""" 57 | 58 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 59 | from tensorflow.python.estimator.model_fn import EstimatorSpec 60 | 61 | tf.logging.info("*** Features ***") 62 | for name in sorted(features.keys()): 63 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 64 | 65 | input_ids = features["input_ids"] 66 | input_mask = features["input_mask"] 67 | segment_ids = features["segment_ids"] 68 | label_ids = features["label_ids"] 69 | 70 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 71 | 72 | (total_loss, per_example_loss, logits, probabilities) = create_model( 73 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 74 | num_labels, use_one_hot_embeddings) 75 | 76 | tvars = tf.trainable_variables() 77 | initialized_variable_names = {} 78 | 79 | if init_checkpoint: 80 | (assignment_map, initialized_variable_names) \ 81 | = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 82 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 83 | 84 | tf.logging.info("**** Trainable Variables ****") 85 | for var in tvars: 86 | init_string = "" 87 | if var.name in initialized_variable_names: 88 | init_string = ", *INIT_FROM_CKPT*" 89 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 90 | init_string) 91 | output_spec = EstimatorSpec(mode=mode, predictions=probabilities) 92 | 93 | return output_spec 94 | 95 | return model_fn 96 | 97 | def get_estimator(self): 98 | 99 | from tensorflow.python.estimator.estimator import Estimator 100 | from tensorflow.python.estimator.run_config import RunConfig 101 | 102 | bert_config = modeling.BertConfig.from_json_file(args.config_name) 103 | label_list = self.processor.get_labels() 104 | if self.mode == tf.estimator.ModeKeys.TRAIN: 105 | init_checkpoint = args.ckpt_name 106 | else: 107 | init_checkpoint = args.output_dir 108 | 109 | model_fn = self.model_fn_builder( 110 | bert_config=bert_config, 111 | num_labels=len(label_list), 112 | init_checkpoint=init_checkpoint, 113 | learning_rate=args.learning_rate, 114 | num_train_steps=None, 115 | num_warmup_steps=None, 116 | use_one_hot_embeddings=False) 117 | 118 | config = tf.ConfigProto() 119 | config.gpu_options.allow_growth = True 120 | config.gpu_options.per_process_gpu_memory_fraction = args.gpu_memory_fraction 121 | config.log_device_placement = False 122 | 123 | return Estimator(model_fn=model_fn, config=RunConfig(session_config=config), model_dir=args.output_dir, 124 | params={'batch_size': self.batch_size}) 125 | 126 | def predict_sentences(self,sentences): 127 | results= self.estimator.predict(input_fn=input_fn_builder(self,sentences), yield_single_examples=False) 128 | #打印预测结果 129 | for i in results: 130 | print(i) 131 | 132 | def _truncate_seq_pair(self, tokens_a, tokens_b, max_length): 133 | """Truncates a sequence pair in place to the maximum length.""" 134 | 135 | # This is a simple heuristic which will always truncate the longer sequence 136 | # one token at a time. This makes more sense than truncating an equal percent 137 | # of tokens from each, since if one sequence is very short then each token 138 | # that's truncated likely contains more information than a longer sequence. 139 | while True: 140 | total_length = len(tokens_a) + len(tokens_b) 141 | if total_length <= max_length: 142 | break 143 | if len(tokens_a) > len(tokens_b): 144 | tokens_a.pop() 145 | else: 146 | tokens_b.pop() 147 | 148 | def convert_single_example(self, ex_index, example, label_list, max_seq_length, tokenizer): 149 | """Converts a single `InputExample` into a single `InputFeatures`.""" 150 | label_map = {} 151 | for (i, label) in enumerate(label_list): 152 | label_map[label] = i 153 | 154 | tokens_a = tokenizer.tokenize(example.text_a) 155 | tokens_b = None 156 | if example.text_b: 157 | tokens_b = tokenizer.tokenize(example.text_b) 158 | 159 | if tokens_b: 160 | # Modifies `tokens_a` and `tokens_b` in place so that the total 161 | # length is less than the specified length. 162 | # Account for [CLS], [SEP], [SEP] with "- 3" 163 | self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 164 | else: 165 | # Account for [CLS] and [SEP] with "- 2" 166 | if len(tokens_a) > max_seq_length - 2: 167 | tokens_a = tokens_a[0:(max_seq_length - 2)] 168 | 169 | # The convention in BERT is: 170 | # (a) For sequence pairs: 171 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 172 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 173 | # (b) For single sequences: 174 | # tokens: [CLS] the dog is hairy . [SEP] 175 | # type_ids: 0 0 0 0 0 0 0 176 | # 177 | # Where "type_ids" are used to indicate whether this is the first 178 | # sequence or the second sequence. The embedding vectors for `type=0` and 179 | # `type=1` were learned during pre-training and are added to the wordpiece 180 | # embedding vector (and position vector). This is not *strictly* necessary 181 | # since the [SEP] token unambiguously separates the sequences, but it makes 182 | # it easier for the model to learn the concept of sequences. 183 | # 184 | # For classification tasks, the first vector (corresponding to [CLS]) is 185 | # used as as the "sentence vector". Note that this only makes sense because 186 | # the entire model is fine-tuned. 187 | tokens = [] 188 | segment_ids = [] 189 | tokens.append("[CLS]") 190 | segment_ids.append(0) 191 | for token in tokens_a: 192 | tokens.append(token) 193 | segment_ids.append(0) 194 | tokens.append("[SEP]") 195 | segment_ids.append(0) 196 | 197 | if tokens_b: 198 | for token in tokens_b: 199 | tokens.append(token) 200 | segment_ids.append(1) 201 | tokens.append("[SEP]") 202 | segment_ids.append(1) 203 | 204 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 205 | 206 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 207 | # tokens are attended to. 208 | input_mask = [1] * len(input_ids) 209 | 210 | # Zero-pad up to the sequence length. 211 | while len(input_ids) < max_seq_length: 212 | input_ids.append(0) 213 | input_mask.append(0) 214 | segment_ids.append(0) 215 | 216 | assert len(input_ids) == max_seq_length 217 | assert len(input_mask) == max_seq_length 218 | assert len(segment_ids) == max_seq_length 219 | 220 | label_id = label_map[example.label] 221 | if ex_index < 5: 222 | tf.logging.info("*** Example ***") 223 | tf.logging.info("guid: %s" % (example.guid)) 224 | tf.logging.info("tokens: %s" % " ".join( 225 | [tokenization.printable_text(x) for x in tokens])) 226 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 227 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 228 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 229 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 230 | 231 | feature = InputFeatures( 232 | input_ids=input_ids, 233 | input_mask=input_mask, 234 | segment_ids=segment_ids, 235 | label_id=label_id) 236 | return feature 237 | 238 | 239 | 240 | 241 | def input_fn_builder(bertSim,sentences): 242 | def predict_input_fn(): 243 | return (tf.data.Dataset.from_generator( 244 | generate_from_input, 245 | output_types={ 246 | 'input_ids': tf.int32, 247 | 'input_mask': tf.int32, 248 | 'segment_ids': tf.int32, 249 | 'label_ids': tf.int32}, 250 | output_shapes={ 251 | 'input_ids': (None, bertSim.max_seq_length), 252 | 'input_mask': (None, bertSim.max_seq_length), 253 | 'segment_ids': (None, bertSim.max_seq_length), 254 | 'label_ids': (1,)}).prefetch(10)) 255 | 256 | def generate_from_input(): 257 | processor = bertSim.processor 258 | predict_examples = processor.get_sentence_examples(sentences) 259 | features = convert_examples_to_features(predict_examples, processor.get_labels(), args.max_seq_len, 260 | bertSim.tokenizer) 261 | yield { 262 | 'input_ids': [f.input_ids for f in features], 263 | 'input_mask': [f.input_mask for f in features], 264 | 'segment_ids': [f.segment_ids for f in features], 265 | 'label_ids': [f.label_id for f in features] 266 | } 267 | 268 | return predict_input_fn 269 | 270 | 271 | if __name__ == '__main__': 272 | sim = BertSim() 273 | sim.start_model() 274 | sim.predict_sentences([("我喜欢妈妈做的汤", "妈妈做的汤我很喜欢喝")]) 275 | -------------------------------------------------------------------------------- /test_changes.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from modeling import embedding_lookup_factorized,transformer_model 4 | import os 5 | 6 | """ 7 | 测试albert主要的改进点:词嵌入的因式分解、层间参数共享、段落间连贯性 8 | test main change of albert from bert 9 | """ 10 | batch_size = 2048 11 | sequence_length = 512 12 | vocab_size = 30000 13 | hidden_size = 1024 14 | num_attention_heads = int(hidden_size / 64) 15 | 16 | def get_total_parameters(): 17 | """ 18 | get total parameters of a graph 19 | :return: 20 | """ 21 | total_parameters = 0 22 | for variable in tf.trainable_variables(): 23 | # shape is an array of tf.Dimension 24 | shape = variable.get_shape() 25 | # print(shape) 26 | # print(len(shape)) 27 | variable_parameters = 1 28 | for dim in shape: 29 | # print(dim) 30 | variable_parameters *= dim.value 31 | # print(variable_parameters) 32 | total_parameters += variable_parameters 33 | return total_parameters 34 | 35 | def test_factorized_embedding(): 36 | """ 37 | test of Factorized embedding parameterization 38 | :return: 39 | """ 40 | input_ids=tf.zeros((batch_size, sequence_length),dtype=tf.int32) 41 | output, embedding_table, embedding_table_2=embedding_lookup_factorized(input_ids,vocab_size,hidden_size) 42 | print("output:",output) 43 | 44 | def test_share_parameters(): 45 | """ 46 | test of share parameters across all layers: how many parameter after share parameter across layers of transformer. 47 | :return: 48 | """ 49 | def total_parameters_transformer(share_parameter_across_layers): 50 | input_tensor=tf.zeros((batch_size, sequence_length, hidden_size),dtype=tf.float32) 51 | print("transformer_model. input:",input_tensor) 52 | transformer_result=transformer_model(input_tensor,hidden_size=hidden_size,num_attention_heads=num_attention_heads,share_parameter_across_layers=share_parameter_across_layers) 53 | print("transformer_result:",transformer_result) 54 | total_parameters=get_total_parameters() 55 | print('total_parameters(not share):',total_parameters) 56 | 57 | share_parameter_across_layers=False 58 | total_parameters_transformer(share_parameter_across_layers) # total parameters, not share: 125,976,576 = 125 million 59 | 60 | tf.reset_default_graph() # Clears the default graph stack and resets the global default graph 61 | share_parameter_across_layers=True 62 | total_parameters_transformer(share_parameter_across_layers) # total parameters, share: 10,498,048 = 10.5 million 63 | 64 | def test_sentence_order_prediction(): 65 | """ 66 | sentence order prediction. 67 | 68 | check method of create_instances_from_document_albert from create_pretrining_data.py 69 | 70 | :return: 71 | """ 72 | # 添加运行权限 73 | os.system("chmod +x create_pretrain_data.sh") 74 | 75 | os.system("./create_pretrain_data.sh") 76 | 77 | 78 | # 1.test of Factorized embedding parameterization 79 | #test_factorized_embedding() 80 | 81 | # 2. test of share parameters across all layers: how many parameter after share parameter across layers of transformer. 82 | # before share parameter: 125,976,576; after share parameter: 83 | #test_share_parameters() 84 | 85 | # 3. test of sentence order prediction(SOP) 86 | test_sentence_order_prediction() 87 | 88 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | #print("items:",items) #['[CLS]', '日', '##期', ',', '但', '被', '##告', '金', '##东', '##福', '载', '##明', '[MASK]', 'U', '##N', '##K', ']', '保', '##证', '本', '##月', '1', '##4', '[MASK]', '到', '##位', ',', '2', '##0', '##1', '##5', '年', '6', '[MASK]', '1', '##1', '日', '[', 'U', '##N', '##K', ']', ',', '原', '##告', '[MASK]', '认', '##可', '于', '2', '##0', '##1', '##5', '[MASK]', '6', '月', '[MASK]', '[MASK]', '日', '##向', '被', '##告', '主', '##张', '权', '##利', '。', '而', '[MASK]', '[MASK]', '自', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '年', '6', '月', '1', '##1', '日', '[SEP]', '原', '##告', '于', '2', '##0', '##1', '##6', '[MASK]', '6', '[MASK]', '2', '##4', '日', '起', '##诉', ',', '主', '##张', '保', '##证', '责', '##任', ',', '已', '超', '##过', '保', '##证', '期', '##限', '[MASK]', '保', '##证', '人', '依', '##法', '不', '##再', '承', '##担', '保', '##证', '[MASK]', '[MASK]', '[MASK]', '[SEP]'] 140 | for i,item in enumerate(items): 141 | #print(i,"item:",item) # ##期 142 | output.append(vocab[item]) 143 | return output 144 | 145 | 146 | def convert_tokens_to_ids(vocab, tokens): 147 | return convert_by_vocab(vocab, tokens) 148 | 149 | 150 | def convert_ids_to_tokens(inv_vocab, ids): 151 | return convert_by_vocab(inv_vocab, ids) 152 | 153 | 154 | def whitespace_tokenize(text): 155 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 156 | text = text.strip() 157 | if not text: 158 | return [] 159 | tokens = text.split() 160 | return tokens 161 | 162 | 163 | class FullTokenizer(object): 164 | """Runs end-to-end tokenziation.""" 165 | 166 | def __init__(self, vocab_file, do_lower_case=True): 167 | self.vocab = load_vocab(vocab_file) 168 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 169 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 170 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 171 | 172 | def tokenize(self, text): 173 | split_tokens = [] 174 | for token in self.basic_tokenizer.tokenize(text): 175 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 176 | split_tokens.append(sub_token) 177 | 178 | return split_tokens 179 | 180 | def convert_tokens_to_ids(self, tokens): 181 | return convert_by_vocab(self.vocab, tokens) 182 | 183 | def convert_ids_to_tokens(self, ids): 184 | return convert_by_vocab(self.inv_vocab, ids) 185 | 186 | 187 | class BasicTokenizer(object): 188 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 189 | 190 | def __init__(self, do_lower_case=True): 191 | """Constructs a BasicTokenizer. 192 | 193 | Args: 194 | do_lower_case: Whether to lower case the input. 195 | """ 196 | self.do_lower_case = do_lower_case 197 | 198 | def tokenize(self, text): 199 | """Tokenizes a piece of text.""" 200 | text = convert_to_unicode(text) 201 | text = self._clean_text(text) 202 | 203 | # This was added on November 1st, 2018 for the multilingual and Chinese 204 | # models. This is also applied to the English models now, but it doesn't 205 | # matter since the English models were not trained on any Chinese data 206 | # and generally don't have any Chinese data in them (there are Chinese 207 | # characters in the vocabulary because Wikipedia does have some Chinese 208 | # words in the English Wikipedia.). 209 | text = self._tokenize_chinese_chars(text) 210 | 211 | orig_tokens = whitespace_tokenize(text) 212 | split_tokens = [] 213 | for token in orig_tokens: 214 | if self.do_lower_case: 215 | token = token.lower() 216 | token = self._run_strip_accents(token) 217 | split_tokens.extend(self._run_split_on_punc(token)) 218 | 219 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 220 | return output_tokens 221 | 222 | def _run_strip_accents(self, text): 223 | """Strips accents from a piece of text.""" 224 | text = unicodedata.normalize("NFD", text) 225 | output = [] 226 | for char in text: 227 | cat = unicodedata.category(char) 228 | if cat == "Mn": 229 | continue 230 | output.append(char) 231 | return "".join(output) 232 | 233 | def _run_split_on_punc(self, text): 234 | """Splits punctuation on a piece of text.""" 235 | chars = list(text) 236 | i = 0 237 | start_new_word = True 238 | output = [] 239 | while i < len(chars): 240 | char = chars[i] 241 | if _is_punctuation(char): 242 | output.append([char]) 243 | start_new_word = True 244 | else: 245 | if start_new_word: 246 | output.append([]) 247 | start_new_word = False 248 | output[-1].append(char) 249 | i += 1 250 | 251 | return ["".join(x) for x in output] 252 | 253 | def _tokenize_chinese_chars(self, text): 254 | """Adds whitespace around any CJK character.""" 255 | output = [] 256 | for char in text: 257 | cp = ord(char) 258 | if self._is_chinese_char(cp): 259 | output.append(" ") 260 | output.append(char) 261 | output.append(" ") 262 | else: 263 | output.append(char) 264 | return "".join(output) 265 | 266 | def _is_chinese_char(self, cp): 267 | """Checks whether CP is the codepoint of a CJK character.""" 268 | # This defines a "chinese character" as anything in the CJK Unicode block: 269 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 270 | # 271 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 272 | # despite its name. The modern Korean Hangul alphabet is a different block, 273 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 274 | # space-separated words, so they are not treated specially and handled 275 | # like the all of the other languages. 276 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 277 | (cp >= 0x3400 and cp <= 0x4DBF) or # 278 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 279 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 280 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 281 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 282 | (cp >= 0xF900 and cp <= 0xFAFF) or # 283 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 284 | return True 285 | 286 | return False 287 | 288 | def _clean_text(self, text): 289 | """Performs invalid character removal and whitespace cleanup on text.""" 290 | output = [] 291 | for char in text: 292 | cp = ord(char) 293 | if cp == 0 or cp == 0xfffd or _is_control(char): 294 | continue 295 | if _is_whitespace(char): 296 | output.append(" ") 297 | else: 298 | output.append(char) 299 | return "".join(output) 300 | 301 | 302 | class WordpieceTokenizer(object): 303 | """Runs WordPiece tokenziation.""" 304 | 305 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 306 | self.vocab = vocab 307 | self.unk_token = unk_token 308 | self.max_input_chars_per_word = max_input_chars_per_word 309 | 310 | def tokenize(self, text): 311 | """Tokenizes a piece of text into its word pieces. 312 | 313 | This uses a greedy longest-match-first algorithm to perform tokenization 314 | using the given vocabulary. 315 | 316 | For example: 317 | input = "unaffable" 318 | output = ["un", "##aff", "##able"] 319 | 320 | Args: 321 | text: A single token or whitespace separated tokens. This should have 322 | already been passed through `BasicTokenizer. 323 | 324 | Returns: 325 | A list of wordpiece tokens. 326 | """ 327 | 328 | text = convert_to_unicode(text) 329 | 330 | output_tokens = [] 331 | for token in whitespace_tokenize(text): 332 | chars = list(token) 333 | if len(chars) > self.max_input_chars_per_word: 334 | output_tokens.append(self.unk_token) 335 | continue 336 | 337 | is_bad = False 338 | start = 0 339 | sub_tokens = [] 340 | while start < len(chars): 341 | end = len(chars) 342 | cur_substr = None 343 | while start < end: 344 | substr = "".join(chars[start:end]) 345 | if start > 0: 346 | substr = "##" + substr 347 | if substr in self.vocab: 348 | cur_substr = substr 349 | break 350 | end -= 1 351 | if cur_substr is None: 352 | is_bad = True 353 | break 354 | sub_tokens.append(cur_substr) 355 | start = end 356 | 357 | if is_bad: 358 | output_tokens.append(self.unk_token) 359 | else: 360 | output_tokens.extend(sub_tokens) 361 | return output_tokens 362 | 363 | 364 | def _is_whitespace(char): 365 | """Checks whether `chars` is a whitespace character.""" 366 | # \t, \n, and \r are technically contorl characters but we treat them 367 | # as whitespace since they are generally considered as such. 368 | if char == " " or char == "\t" or char == "\n" or char == "\r": 369 | return True 370 | cat = unicodedata.category(char) 371 | if cat == "Zs": 372 | return True 373 | return False 374 | 375 | 376 | def _is_control(char): 377 | """Checks whether `chars` is a control character.""" 378 | # These are technically control characters but we count them as whitespace 379 | # characters. 380 | if char == "\t" or char == "\n" or char == "\r": 381 | return False 382 | cat = unicodedata.category(char) 383 | if cat in ("Cc", "Cf"): 384 | return True 385 | return False 386 | 387 | 388 | def _is_punctuation(char): 389 | """Checks whether `chars` is a punctuation character.""" 390 | cp = ord(char) 391 | # We treat all non-letter/number ASCII as punctuation. 392 | # Characters such as "^", "$", and "`" are not in the Unicode 393 | # Punctuation class but we treat them as punctuation anyways, for 394 | # consistency. 395 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 396 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 397 | return True 398 | cat = unicodedata.category(char) 399 | if cat.startswith("P"): 400 | return True 401 | return False 402 | -------------------------------------------------------------------------------- /tokenization_google.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | # coding=utf-8 18 | """Tokenization classes.""" 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import collections 25 | import re 26 | import unicodedata 27 | import six 28 | from six.moves import range 29 | import tensorflow as tf 30 | import sentencepiece as spm 31 | 32 | SPIECE_UNDERLINE = u"▁".encode("utf-8") 33 | 34 | 35 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 36 | """Checks whether the casing config is consistent with the checkpoint name.""" 37 | 38 | # The casing has to be passed in by the user and there is no explicit check 39 | # as to whether it matches the checkpoint. The casing information probably 40 | # should have been stored in the bert_config.json file, but it's not, so 41 | # we have to heuristically detect it to validate. 42 | 43 | if not init_checkpoint: 44 | return 45 | 46 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", 47 | six.ensure_str(init_checkpoint)) 48 | if m is None: 49 | return 50 | 51 | model_name = m.group(1) 52 | 53 | lower_models = [ 54 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 55 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 56 | ] 57 | 58 | cased_models = [ 59 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 60 | "multi_cased_L-12_H-768_A-12" 61 | ] 62 | 63 | is_bad_config = False 64 | if model_name in lower_models and not do_lower_case: 65 | is_bad_config = True 66 | actual_flag = "False" 67 | case_name = "lowercased" 68 | opposite_flag = "True" 69 | 70 | if model_name in cased_models and do_lower_case: 71 | is_bad_config = True 72 | actual_flag = "True" 73 | case_name = "cased" 74 | opposite_flag = "False" 75 | 76 | if is_bad_config: 77 | raise ValueError( 78 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 79 | "However, `%s` seems to be a %s model, so you " 80 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 81 | "how the model was pre-training. If this error is wrong, please " 82 | "just comment out this check." % (actual_flag, init_checkpoint, 83 | model_name, case_name, opposite_flag)) 84 | 85 | 86 | def preprocess_text(inputs, remove_space=True, lower=False): 87 | """preprocess data by removing extra space and normalize data.""" 88 | outputs = inputs 89 | if remove_space: 90 | outputs = " ".join(inputs.strip().split()) 91 | 92 | if six.PY2 and isinstance(outputs, str): 93 | try: 94 | outputs = six.ensure_text(outputs, "utf-8") 95 | except UnicodeDecodeError: 96 | outputs = six.ensure_text(outputs, "latin-1") 97 | 98 | outputs = unicodedata.normalize("NFKD", outputs) 99 | outputs = "".join([c for c in outputs if not unicodedata.combining(c)]) 100 | if lower: 101 | outputs = outputs.lower() 102 | 103 | return outputs 104 | 105 | 106 | def encode_pieces(sp_model, text, return_unicode=True, sample=False): 107 | """turn sentences into word pieces.""" 108 | 109 | if six.PY2 and isinstance(text, six.text_type): 110 | text = six.ensure_binary(text, "utf-8") 111 | 112 | if not sample: 113 | pieces = sp_model.EncodeAsPieces(text) 114 | else: 115 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1) 116 | new_pieces = [] 117 | for piece in pieces: 118 | piece = printable_text(piece) 119 | if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit(): 120 | cur_pieces = sp_model.EncodeAsPieces( 121 | six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b"")) 122 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: 123 | if len(cur_pieces[0]) == 1: 124 | cur_pieces = cur_pieces[1:] 125 | else: 126 | cur_pieces[0] = cur_pieces[0][1:] 127 | cur_pieces.append(piece[-1]) 128 | new_pieces.extend(cur_pieces) 129 | else: 130 | new_pieces.append(piece) 131 | 132 | # note(zhiliny): convert back to unicode for py2 133 | if six.PY2 and return_unicode: 134 | ret_pieces = [] 135 | for piece in new_pieces: 136 | if isinstance(piece, str): 137 | piece = six.ensure_text(piece, "utf-8") 138 | ret_pieces.append(piece) 139 | new_pieces = ret_pieces 140 | 141 | return new_pieces 142 | 143 | 144 | def encode_ids(sp_model, text, sample=False): 145 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample) 146 | ids = [sp_model.PieceToId(piece) for piece in pieces] 147 | return ids 148 | 149 | 150 | def convert_to_unicode(text): 151 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 152 | if six.PY3: 153 | if isinstance(text, str): 154 | return text 155 | elif isinstance(text, bytes): 156 | return six.ensure_text(text, "utf-8", "ignore") 157 | else: 158 | raise ValueError("Unsupported string type: %s" % (type(text))) 159 | elif six.PY2: 160 | if isinstance(text, str): 161 | return six.ensure_text(text, "utf-8", "ignore") 162 | elif isinstance(text, six.text_type): 163 | return text 164 | else: 165 | raise ValueError("Unsupported string type: %s" % (type(text))) 166 | else: 167 | raise ValueError("Not running on Python2 or Python 3?") 168 | 169 | 170 | def printable_text(text): 171 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 172 | 173 | # These functions want `str` for both Python2 and Python3, but in one case 174 | # it's a Unicode string and in the other it's a byte string. 175 | if six.PY3: 176 | if isinstance(text, str): 177 | return text 178 | elif isinstance(text, bytes): 179 | return six.ensure_text(text, "utf-8", "ignore") 180 | else: 181 | raise ValueError("Unsupported string type: %s" % (type(text))) 182 | elif six.PY2: 183 | if isinstance(text, str): 184 | return text 185 | elif isinstance(text, six.text_type): 186 | return six.ensure_binary(text, "utf-8") 187 | else: 188 | raise ValueError("Unsupported string type: %s" % (type(text))) 189 | else: 190 | raise ValueError("Not running on Python2 or Python 3?") 191 | 192 | 193 | def load_vocab(vocab_file): 194 | """Loads a vocabulary file into a dictionary.""" 195 | vocab = collections.OrderedDict() 196 | with tf.gfile.GFile(vocab_file, "r") as reader: 197 | while True: 198 | token = convert_to_unicode(reader.readline()) 199 | if not token: 200 | break 201 | token = token.strip() # previous: token.strip().split()[0] 202 | if token not in vocab: 203 | vocab[token] = len(vocab) 204 | return vocab 205 | 206 | 207 | def convert_by_vocab(vocab, items): 208 | """Converts a sequence of [tokens|ids] using the vocab.""" 209 | output = [] 210 | for item in items: 211 | output.append(vocab[item]) 212 | return output 213 | 214 | 215 | def convert_tokens_to_ids(vocab, tokens): 216 | return convert_by_vocab(vocab, tokens) 217 | 218 | 219 | def convert_ids_to_tokens(inv_vocab, ids): 220 | return convert_by_vocab(inv_vocab, ids) 221 | 222 | 223 | def whitespace_tokenize(text): 224 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 225 | text = text.strip() 226 | if not text: 227 | return [] 228 | tokens = text.split() 229 | return tokens 230 | 231 | 232 | class FullTokenizer(object): 233 | """Runs end-to-end tokenziation.""" 234 | 235 | def __init__(self, vocab_file, do_lower_case=True, spm_model_file=None): 236 | self.vocab = None 237 | self.sp_model = None 238 | print("spm_model_file:",spm_model_file,";vocab_file:",vocab_file) 239 | if spm_model_file: 240 | print("#Use spm_model_file") 241 | self.sp_model = spm.SentencePieceProcessor() 242 | tf.logging.info("loading sentence piece model") 243 | self.sp_model.Load(spm_model_file) 244 | # Note(mingdachen): For the purpose of consisent API, we are 245 | # generating a vocabulary for the sentence piece tokenizer. 246 | self.vocab = {self.sp_model.IdToPiece(i): i for i 247 | in range(self.sp_model.GetPieceSize())} 248 | else: 249 | print("#Use vocab_file") 250 | self.vocab = load_vocab(vocab_file) 251 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 252 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 253 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 254 | 255 | def tokenize(self, text): 256 | if self.sp_model: 257 | split_tokens = encode_pieces(self.sp_model, text, return_unicode=False) 258 | else: 259 | split_tokens = [] 260 | for token in self.basic_tokenizer.tokenize(text): 261 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 262 | split_tokens.append(sub_token) 263 | 264 | return split_tokens 265 | 266 | def convert_tokens_to_ids(self, tokens): 267 | if self.sp_model: 268 | tf.logging.info("using sentence piece tokenzier.") 269 | return [self.sp_model.PieceToId( 270 | printable_text(token)) for token in tokens] 271 | else: 272 | return convert_by_vocab(self.vocab, tokens) 273 | 274 | def convert_ids_to_tokens(self, ids): 275 | if self.sp_model: 276 | tf.logging.info("using sentence piece tokenzier.") 277 | return [self.sp_model.IdToPiece(id_) for id_ in ids] 278 | else: 279 | return convert_by_vocab(self.inv_vocab, ids) 280 | 281 | 282 | class BasicTokenizer(object): 283 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 284 | 285 | def __init__(self, do_lower_case=True): 286 | """Constructs a BasicTokenizer. 287 | 288 | Args: 289 | do_lower_case: Whether to lower case the input. 290 | """ 291 | self.do_lower_case = do_lower_case 292 | 293 | def tokenize(self, text): 294 | """Tokenizes a piece of text.""" 295 | text = convert_to_unicode(text) 296 | text = self._clean_text(text) 297 | 298 | # This was added on November 1st, 2018 for the multilingual and Chinese 299 | # models. This is also applied to the English models now, but it doesn't 300 | # matter since the English models were not trained on any Chinese data 301 | # and generally don't have any Chinese data in them (there are Chinese 302 | # characters in the vocabulary because Wikipedia does have some Chinese 303 | # words in the English Wikipedia.). 304 | text = self._tokenize_chinese_chars(text) 305 | 306 | orig_tokens = whitespace_tokenize(text) 307 | split_tokens = [] 308 | for token in orig_tokens: 309 | if self.do_lower_case: 310 | token = token.lower() 311 | token = self._run_strip_accents(token) 312 | split_tokens.extend(self._run_split_on_punc(token)) 313 | 314 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 315 | return output_tokens 316 | 317 | def _run_strip_accents(self, text): 318 | """Strips accents from a piece of text.""" 319 | text = unicodedata.normalize("NFD", text) 320 | output = [] 321 | for char in text: 322 | cat = unicodedata.category(char) 323 | if cat == "Mn": 324 | continue 325 | output.append(char) 326 | return "".join(output) 327 | 328 | def _run_split_on_punc(self, text): 329 | """Splits punctuation on a piece of text.""" 330 | chars = list(text) 331 | i = 0 332 | start_new_word = True 333 | output = [] 334 | while i < len(chars): 335 | char = chars[i] 336 | if _is_punctuation(char): 337 | output.append([char]) 338 | start_new_word = True 339 | else: 340 | if start_new_word: 341 | output.append([]) 342 | start_new_word = False 343 | output[-1].append(char) 344 | i += 1 345 | 346 | return ["".join(x) for x in output] 347 | 348 | def _tokenize_chinese_chars(self, text): 349 | """Adds whitespace around any CJK character.""" 350 | output = [] 351 | for char in text: 352 | cp = ord(char) 353 | if self._is_chinese_char(cp): 354 | output.append(" ") 355 | output.append(char) 356 | output.append(" ") 357 | else: 358 | output.append(char) 359 | return "".join(output) 360 | 361 | def _is_chinese_char(self, cp): 362 | """Checks whether CP is the codepoint of a CJK character.""" 363 | # This defines a "chinese character" as anything in the CJK Unicode block: 364 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 365 | # 366 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 367 | # despite its name. The modern Korean Hangul alphabet is a different block, 368 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 369 | # space-separated words, so they are not treated specially and handled 370 | # like the all of the other languages. 371 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 372 | (cp >= 0x3400 and cp <= 0x4DBF) or # 373 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 374 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 375 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 376 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 377 | (cp >= 0xF900 and cp <= 0xFAFF) or # 378 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 379 | return True 380 | 381 | return False 382 | 383 | def _clean_text(self, text): 384 | """Performs invalid character removal and whitespace cleanup on text.""" 385 | output = [] 386 | for char in text: 387 | cp = ord(char) 388 | if cp == 0 or cp == 0xfffd or _is_control(char): 389 | continue 390 | if _is_whitespace(char): 391 | output.append(" ") 392 | else: 393 | output.append(char) 394 | return "".join(output) 395 | 396 | 397 | class WordpieceTokenizer(object): 398 | """Runs WordPiece tokenziation.""" 399 | 400 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 401 | self.vocab = vocab 402 | self.unk_token = unk_token 403 | self.max_input_chars_per_word = max_input_chars_per_word 404 | 405 | def tokenize(self, text): 406 | """Tokenizes a piece of text into its word pieces. 407 | 408 | This uses a greedy longest-match-first algorithm to perform tokenization 409 | using the given vocabulary. 410 | 411 | For example: 412 | input = "unaffable" 413 | output = ["un", "##aff", "##able"] 414 | 415 | Args: 416 | text: A single token or whitespace separated tokens. This should have 417 | already been passed through `BasicTokenizer. 418 | 419 | Returns: 420 | A list of wordpiece tokens. 421 | """ 422 | 423 | text = convert_to_unicode(text) 424 | 425 | output_tokens = [] 426 | for token in whitespace_tokenize(text): 427 | chars = list(token) 428 | if len(chars) > self.max_input_chars_per_word: 429 | output_tokens.append(self.unk_token) 430 | continue 431 | 432 | is_bad = False 433 | start = 0 434 | sub_tokens = [] 435 | while start < len(chars): 436 | end = len(chars) 437 | cur_substr = None 438 | while start < end: 439 | substr = "".join(chars[start:end]) 440 | if start > 0: 441 | substr = "##" + six.ensure_str(substr) 442 | if substr in self.vocab: 443 | cur_substr = substr 444 | break 445 | end -= 1 446 | if cur_substr is None: 447 | is_bad = True 448 | break 449 | sub_tokens.append(cur_substr) 450 | start = end 451 | 452 | if is_bad: 453 | output_tokens.append(self.unk_token) 454 | else: 455 | output_tokens.extend(sub_tokens) 456 | return output_tokens 457 | 458 | 459 | def _is_whitespace(char): 460 | """Checks whether `chars` is a whitespace character.""" 461 | # \t, \n, and \r are technically control characters but we treat them 462 | # as whitespace since they are generally considered as such. 463 | if char == " " or char == "\t" or char == "\n" or char == "\r": 464 | return True 465 | cat = unicodedata.category(char) 466 | if cat == "Zs": 467 | return True 468 | return False 469 | 470 | 471 | def _is_control(char): 472 | """Checks whether `chars` is a control character.""" 473 | # These are technically control characters but we count them as whitespace 474 | # characters. 475 | if char == "\t" or char == "\n" or char == "\r": 476 | return False 477 | cat = unicodedata.category(char) 478 | if cat in ("Cc", "Cf"): 479 | return True 480 | return False 481 | 482 | 483 | def _is_punctuation(char): 484 | """Checks whether `chars` is a punctuation character.""" 485 | cp = ord(char) 486 | # We treat all non-letter/number ASCII as punctuation. 487 | # Characters such as "^", "$", and "`" are not in the Unicode 488 | # Punctuation class but we treat them as punctuation anyways, for 489 | # consistency. 490 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 491 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 492 | return True 493 | cat = unicodedata.category(char) 494 | if cat.startswith("P"): 495 | return True 496 | return False 497 | --------------------------------------------------------------------------------