├── img
│   └── shared_task_QA_code
├── example
│   ├── nlpcc2020_baseline.zip
│   └── cluewsc2020_predict.json
├── baselines
│   ├── models
│   │   ├── bert_wsc_csl
│   │   │   ├── requirements.txt
│   │   │   ├── prev_trained_models
│   │   │   │   └── RoBERTa-tiny-clue
│   │   │   │       └── bert_config.json
│   │   │   ├── __init__.py
│   │   │   ├── tpu
│   │   │   │   ├── run_classifier_inews.sh
│   │   │   │   ├── run_classifier_lcqmc.sh
│   │   │   │   ├── run_classifier_xnli.sh
│   │   │   │   ├── run_classifier_thucnews.sh
│   │   │   │   ├── run_classifier_jdcomment.sh
│   │   │   │   └── run_classifier_tnews.sh
│   │   │   ├── CONTRIBUTING.md
│   │   │   ├── optimization_test.py
│   │   │   ├── .gitignore
│   │   │   ├── run_classifier_csl.sh
│   │   │   ├── run_classifier_wsc.sh
│   │   │   ├── sample_text.txt
│   │   │   ├── run_classifier_clue.sh
│   │   │   ├── tokenization_test.py
│   │   │   ├── optimization.py
│   │   │   ├── tf_metrics.py
│   │   │   ├── wsc_output
│   │   │   │   └── wsc_predict.json
│   │   │   ├── modeling_test.py
│   │   │   ├── conlleval.py
│   │   │   ├── LICENSE
│   │   │   ├── multilingual.md
│   │   │   └── tokenization.py
│   │   ├── bert_mrc
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── run_cmrc.sh
│   │   │   ├── LICENSE
│   │   │   ├── cmrc2018_evaluate.py
│   │   │   ├── optimization.py
│   │   │   └── tokenization.py
│   │   ├── README.md
│   │   └── bert_ner
│   │       ├── prev_trained_models
│   │       │   └── RoBERTa-tiny-clue
│   │       │       └── bert_config.json
│   │       ├── len_count.json
│   │       ├── label2id.json
│   │       ├── readme.md
│   │       ├── run_classifier_tiny.sh
│   │       ├── data-ming.py
│   │       ├── data_processor_seq.py
│   │       ├── predict_sequence_label.py
│   │       ├── optimization.py
│   │       └── tokenization.py
│   └── CLUEdataset
│       └── ner
│           ├── README.md
│           └── cluener_predict.json
├── Schedule4clue.md
├── .gitignore
└── README.md

/img/shared_task_QA_code:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLUEbenchmark/LightLM/HEAD/img/shared_task_QA_code
--------------------------------------------------------------------------------
/example/nlpcc2020_baseline.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLUEbenchmark/LightLM/HEAD/example/nlpcc2020_baseline.zip
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/requirements.txt:
--------------------------------------------------------------------------------
tensorflow >= 1.11.0    # CPU version of TensorFlow.
# tensorflow-gpu >= 1.11.0  # GPU version of TensorFlow.

--------------------------------------------------------------------------------
/baselines/models/bert_mrc/README.md:
--------------------------------------------------------------------------------
# mrc
Run one of the other tasks first so that the pretrained model and data are downloaded; this directory does not include its own download script.

This will be modified later.

This code is copied from
https://github.com/johndpope/CMRC2018-DRCD-BERT

--------------------------------------------------------------------------------
/baselines/CLUEdataset/ner/README.md:
--------------------------------------------------------------------------------
CLUENER: fine-grained named entity recognition

The data is annotated with 10 label categories:
Address (address),
Book (book),
Company (company),
Game (game),
Government (government),
Movie (movie),
Name (name),
Organization (organization),
Position (position),
Scene (scene)

For a detailed description of the data, the baseline models, and evaluation results, see https://github.com/CLUEbenchmark/CLUENER

For technical discussion or questions, please open an issue or PR in the project, or email ChineseGLUE@163.com

SOTA results on the test set are listed on the leaderboard: www.CLUEbenchmark.com
--------------------------------------------------------------------------------
/baselines/CLUEdataset/ner/cluener_predict.json:
--------------------------------------------------------------------------------
{"id": 0, "label": {"address": {"丹棱县": [[11, 13]]}, "name": {"胡文和": [[41, 43]]}}}
{"id": 1, "label": {"address": {"阿布贾": [[12, 14]]}, "goverment": {"尼日利亚海军": [[0, 5]], "尼日利亚通讯社": [[16, 22]]}}}
{"id": 2, "label": {"game": {"辐射3-Bethesda": [[5, 16]]}}}
{"id": 3, "label": {"scene": {"巴厘岛": [[9, 11]]}}}
{"id": 4, "label": {?????}}
{"id": 5, "label": {?????}}
{"id": 6, "label": {?????}}
--------------------------------------------------------------------------------
/baselines/models/README.md:
--------------------------------------------------------------------------------
The baselines/models directory provides prediction scripts for the four tasks in this competition. Usage:

```
Change into the baselines/models directory, then:
1. bert_wsc_csl holds the WSC and CSL tasks: bash run_classifier_clue.sh
   - Results are written to the wsc_output and csl_output directories.
2. bert_mrc holds the CMRC2018 baseline: bash run_cmrc.sh
   - Results are written to cmrc2018_predict.json.
3. bert_ner holds the CLUENER baseline: bash run_classifier_tiny.sh
   - Results are written to ner_output.

```
Predicting only requires loading the corresponding model; this demo automatically uses "RoBERTa-tiny-clue".

Submission: package the results above as nlpcc-xxx.zip and submit it. Do not rename any of the output files; otherwise the system cannot recognize them.
--------------------------------------------------------------------------------
/baselines/models/bert_ner/prev_trained_models/RoBERTa-tiny-clue/bert_config.json:
--------------------------------------------------------------------------------
{
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 312,
  "initializer_range": 0.02,
  "intermediate_size": 1248,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 4,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 8021
}
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/prev_trained_models/RoBERTa-tiny-clue/bert_config.json:
--------------------------------------------------------------------------------
{
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 312,
  "initializer_range": 0.02,
  "intermediate_size": 1248,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 4,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 8021
}
--------------------------------------------------------------------------------
/baselines/models/bert_mrc/__init__.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/__init__.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

--------------------------------------------------------------------------------
/baselines/models/bert_ner/len_count.json:
--------------------------------------------------------------------------------
{
  "50": 574,
  "18": 138,
  "22": 239,
  "40": 530,
  "44": 593,
  "48": 644,
  "20": 156,
  "27": 113,
  "30": 181,
  "47": 578,
  "12": 77,
  "23": 105,
  "35": 303,
  "45": 666,
  "15": 96,
  "43": 636,
  "41": 544,
  "36": 319,
  "49": 673,
  "39": 455,
  "19": 124,
  "42": 600,
  "25": 133,
  "34": 248,
  "9": 43,
  "6": 20,
  "29": 133,
  "13": 68,
  "10": 56,
  "16": 121,
  "38": 435,
  "37": 383,
  "17": 171,
  "11": 92,
  "33": 226,
  "46": 561,
  "31": 177,
  "28": 147,
  "32": 219,
  "14": 84,
  "24": 100,
  "21": 151,
  "26": 101,
  "3": 9,
  "8": 30,
  "5": 12,
  "7": 15,
  "4": 9,
  "2": 3
}
--------------------------------------------------------------------------------
/baselines/models/bert_ner/label2id.json:
--------------------------------------------------------------------------------
{
  "O": 0,
  "S_address": 1,
  "B_address": 2,
  "M_address": 3,
  "E_address": 4,
  "S_book": 5,
  "B_book": 6,
  "M_book": 7,
  "E_book": 8,
  "S_company": 9,
  "B_company": 10,
  "M_company": 11,
  "E_company": 12,
  "S_game": 13,
  "B_game": 14,
  "M_game": 15,
  "E_game": 16,
  "S_government": 17,
  "B_government": 18,
  "M_government": 19,
  "E_government": 20,
  "S_movie": 21,
  "B_movie": 22,
  "M_movie": 23,
  "E_movie": 24,
  "S_name": 25,
  "B_name": 26,
  "M_name": 27,
  "E_name": 28,
  "S_organization": 29,
  "B_organization": 30,
  "M_organization": 31,
  "E_organization": 32,
  "S_position": 33,
  "B_position": 34,
  "M_position": 35,
  "E_position": 36,
  "S_scene": 37,
  "B_scene": 38,
  "M_scene": 39,
  "E_scene": 40
}
--------------------------------------------------------------------------------
/Schedule4clue.md:
--------------------------------------------------------------------------------
- [x] 3.12 List the timeline milestones for the competition
- [ ] 3.13 Write a detailed draft of the competition description (Chinese version).
- [ ] Review within the group
- [ ] Ask Dr. Lan, Prof. Yang, Prof. Li and the other advisors for feedback
- [ ] Decide on the leaderboard layout and the corresponding weekly-competition format - before 3.15
- [ ] Evaluation datasets: decide which datasets to use, and whether the final competition and the regular test rounds use separate datasets - settle the plan before 3.15, then refine the details
- [ ] 3.25-4.6: use publicly available datasets as a warm-up
- [ ] 4.6-5.15: we provide new datasets for the final evaluation [should part of the test set be held back and released only on 5.15?]
- [ ] We need a justification for the dataset choices above
- [ ] Prepare the baselines and the corresponding data and evaluation scripts - before 2020.3.20
- [ ] Test the system and the evaluation, and run the whole pipeline end to end - before 2020.3.25, with slack until 4.6; testing can be done in that window

--- The tasks above are the more urgent ones.

- [ ] List recommended references for this competition, optionally with commentary on selected papers - finish before 3.20
  1. ALBERT [would changing only NSP already help?]
  2. ELECTRA
  3. Distilling the Knowledge in a Neural Network (Hinton)
  4. TinyBERT: Distilling BERT for Natural Language Understanding
  5. Attention-Guided Answer Distillation for Machine Reading Comprehension
  6. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks
- [ ] Evaluation criterion: fix a model-size budget and get results as good as possible within it [settle on a concrete size value]
- [ ]

------ The items above must be finished before 3.25

- [ ] Write the evaluation documentation and check the results submitted by participants

--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_inews.sh:
--------------------------------------------------------------------------------
CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
TASK_NAME="inews"
export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/al/bert-base/chinese_L-12_H-768_A-12/
export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/$TASK_NAME
export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME

python $CURRENT_DIR/../run_classifier.py \
  --task_name=$TASK_NAME \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
  --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
  --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
  --max_seq_length=512 \
  --train_batch_size=16 \
  --learning_rate=2e-5 \
  --num_train_epochs=8.0 \
  --output_dir=$OUTPUT_DIR \
  --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://10.1.101.2:8470

--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_lcqmc.sh:
--------------------------------------------------------------------------------
CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
TASK_NAME="lcqmc"
export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/bert-base/chinese_L-12_H-768_A-12
export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/hard_$TASK_NAME
export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME

python $CURRENT_DIR/../run_classifier.py \
  --task_name=$TASK_NAME \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
  --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
  --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=16 \
  --learning_rate=2e-5 \
  --num_train_epochs=8.0 \
  --output_dir=$OUTPUT_DIR \
  --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://172.20.0.2:8470

--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_xnli.sh:
--------------------------------------------------------------------------------
CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
TASK_NAME="xnli"
export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/al/bert-base/chinese_L-12_H-768_A-12/
export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/$TASK_NAME
export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME

python $CURRENT_DIR/../run_classifier.py \
  --task_name=$TASK_NAME \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
  --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
  --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
  --max_seq_length=512 \
  --train_batch_size=16 \
  --learning_rate=2e-5 \
  --num_train_epochs=8.0 \
  --output_dir=$OUTPUT_DIR \
  --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://10.1.101.2:8470

--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_thucnews.sh:
--------------------------------------------------------------------------------
CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
TASK_NAME="thucnews"
export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/al/bert-base/chinese_L-12_H-768_A-12/
export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/$TASK_NAME
export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME

python $CURRENT_DIR/../run_classifier.py \
  --task_name=$TASK_NAME \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
  --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
  --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
  --max_seq_length=512 \
  --train_batch_size=16 \
  --learning_rate=2e-5 \
  --num_train_epochs=8.0 \
  --output_dir=$OUTPUT_DIR \
  --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://10.1.101.2:8470

--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_jdcomment.sh:
--------------------------------------------------------------------------------
CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
TASK_NAME="jdcomment"
export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/bert-base/chinese_L-12_H-768_A-12
export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/hard_${TASK_NAME}
export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME
echo $DATA_DIR
python3 $CURRENT_DIR/../run_classifier.py \
  --task_name=$TASK_NAME \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
  --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
  --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=$OUTPUT_DIR \
  --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://172.18.0.2:8470

--------------------------------------------------------------------------------
/baselines/models/bert_mrc/run_cmrc.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#########################################################################
# File Name: run.sh
# Author: Junyi Li
# Personal page: dukeenglish.github.io
# Created Time: 21:36:42 2020-03-30
#########################################################################
PATH_TO_BERT=../bert_wsc_csl/prev_trained_models/RoBERTa-tiny-clue

DATA_DIR=../../CLUEdataset/cmrc

MODEL_DIR=.
python run_cmrc2018_drcd_baseline.py \
  --vocab_file=${PATH_TO_BERT}/vocab.txt \
  --bert_config_file=${PATH_TO_BERT}/bert_config.json \
  --init_checkpoint=${PATH_TO_BERT}/bert_model.ckpt \
  --do_train=True \
  --train_file=${DATA_DIR}/train.json \
  --do_predict=True \
  --predict_file=${DATA_DIR}/test.json \
  --train_batch_size=32 \
  --num_train_epochs=2 \
  --max_seq_length=512 \
  --doc_stride=128 \
  --learning_rate=3e-5 \
  --save_checkpoints_steps=1000 \
  --output_dir=${MODEL_DIR} \
  --do_lower_case=False \
  --use_tpu=False

--------------------------------------------------------------------------------
/baselines/models/bert_mrc/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Yiming Cui

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# How to Contribute

BERT needs to maintain permanent compatibility with the pre-trained model files,
so we do not plan to make any major changes to this library (other than what was
promised in the README). However, we can accept small patches related to
re-factoring and documentation. To submit contributions, there are just a few
small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows
[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
--------------------------------------------------------------------------------
/baselines/models/bert_ner/readme.md:
--------------------------------------------------------------------------------
# How to train and submit for testing

Environment: Python 3 and TensorFlow 1.x (e.g. 1.14).

## Model 1: run RoBERTa-wwm-large with a single command

nohup bash run_classifier_roberta_wwm_large.sh &

## Model 2

Step 1: generate the tf_record files.
Edit the input and output paths of the functions in data_processor_seq.py.
```
python data_processor_seq.py
```

Step 2: train the NER model.
Edit the config dictionary in train_sequence_label.py (model parameters, file paths, etc.).
```
python train_sequence_label.py
```

Step 3: load the trained model and run prediction.
Edit model_path (where the model was saved) and the prediction file path in predict_sequence_label.py.
```
python predict_sequence_label.py
```

### Evaluation
The metric is F1-score. Edit the pre and gold files in score.py (this works on the validation set; gold labels are not provided for the test phase).
```
python score.py
```

| Model | Online F1 |
|:-------------:|:-----:|
| bilstm+crf | 70.00 |
| bert-base | 78.82 |
| roberta-wwm-large-ext | **80.42** |
|Human Performance|63.41|

Per-entity results:

| Entity | bilstm+crf | bert-base | roberta-wwm-large-ext | Human Performance |
|:-------------:|:-----:|:-----:|:-----:|:-----:|
| Person Name | 74.04 | 88.75 | **89.09** | 74.49 |
| Organization | 75.96 | 79.43 | **82.34** | 65.41 |
| Position | 70.16 | 78.89 | **79.62** | 55.38 |
| Company | 72.27 | 81.42 | **83.02** | 49.32 |
| Address | 45.50 | 60.89 | **62.63** | 43.04 |
| Game | 85.27 | 86.42 | **86.80** | 80.39 |
| Government | 77.25 | 87.03 | **88.17** | 79.27 |
| Scene | 52.42 | 65.10 | **70.49** | 51.85 |
| Book | 67.20 | 73.68 | **74.60** | 71.70 |
| Movie | 78.97 | 85.82 | **87.46** | 63.21 |

For more detailed results, see our technical report: https://arxiv.org/abs/2001.04351

--------------------------------------------------------------------------------
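The bert_ner readme evaluates with F1 through score.py, which is not part of this dump. As a minimal sketch, assuming strict exact-match scoring over spans in the cluener_predict.json format (label type, mention text, and inclusive [start, end] character offsets), entity-level micro F1 could be computed as below. The helper names `entities_from_lines` and `entity_f1` are hypothetical, and the official score.py may count matches differently (for example per entity type, or by mention text instead of offsets).

```python
import json
from typing import Dict, Iterable, Set, Tuple

# One entity = (sample id, entity type, start offset, end offset).
# Offsets are assumed to be inclusive character positions, as in cluener_predict.json.
Entity = Tuple[int, str, int, int]


def entities_from_lines(lines: Iterable[str]) -> Set[Entity]:
    """Collect every labelled span from CLUENER-style JSON lines."""
    found: Set[Entity] = set()
    index = 0
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        # Fall back to the line position when a record has no "id" field
        # (assumption: prediction and gold files list samples in the same order).
        sample_id = record.get("id", index)
        for ent_type, mentions in record.get("label", {}).items():
            for spans in mentions.values():
                for start, end in spans:
                    found.add((sample_id, ent_type, start, end))
        index += 1
    return found


def entity_f1(pred: Set[Entity], gold: Set[Entity]) -> Dict[str, float]:
    """Micro-averaged precision, recall and F1; a span counts only on an exact match."""
    tp = len(pred & gold)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Usage: treat sample 0 of cluener_predict.json as gold, and score a
# prediction that finds the address but misses the person name.
gold_lines = ['{"id": 0, "label": {"address": {"丹棱县": [[11, 13]]}, "name": {"胡文和": [[41, 43]]}}}']
pred_lines = ['{"id": 0, "label": {"address": {"丹棱县": [[11, 13]]}}}']
scores = entity_f1(entities_from_lines(pred_lines), entities_from_lines(gold_lines))
print(scores)
```

Because each entity is keyed by (id, type, start, end), a mention with correct offsets but the wrong type counts as both a false positive and a false negative, which is the usual strict convention for NER F1.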
/baselines/models/bert_wsc_csl/optimization_test.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import optimization
import tensorflow as tf


class OptimizationTest(tf.test.TestCase):

  def test_adam(self):
    with self.test_session() as sess:
      w = tf.get_variable(
          "w",
          shape=[3],
          initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
      x = tf.constant([0.4, 0.2, -0.5])
      loss = tf.reduce_mean(tf.square(x - w))
      tvars = tf.trainable_variables()
      grads = tf.gradients(loss, tvars)
      global_step = tf.train.get_or_create_global_step()
      optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
      train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
      init_op = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
      sess.run(init_op)
      for _ in range(100):
        sess.run(train_op)
      w_np = sess.run(w)
      self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
  tf.test.main()
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/.gitignore:
--------------------------------------------------------------------------------
# Initially taken from Github's Python gitignore file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/baselines/models/bert_ner/run_classifier_tiny.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# @Author: bo.shi
# @Date:   2019-11-04 09:56:36
# @Last Modified by:   bo.shi
# @Last Modified time: 2019-12-05 11:23:30

TASK_NAME="ner"
MODEL_NAME="RoBERTa-tiny-clue"
CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
export CUDA_VISIBLE_DEVICES="0"
export PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_models
export ROBERTA_CLUE_DIR=$PRETRAINED_MODELS_DIR/$MODEL_NAME
export CLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset

# download and unzip dataset
if [ ! -d $CLUE_DATA_DIR ]; then
  mkdir -p $CLUE_DATA_DIR
  echo "makedir $CLUE_DATA_DIR"
fi
cd $CLUE_DATA_DIR
if [ ! -d $TASK_NAME ]; then
  mkdir $TASK_NAME
  echo "makedir $CLUE_DATA_DIR/$TASK_NAME"
fi
cd $TASK_NAME
if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then
  rm *
  wget https://storage.googleapis.com/cluebenchmark/tasks/cluener_public.zip
  unzip cluener_public.zip
  rm cluener_public.zip
else
  echo "data exists"
fi
echo "Finish download dataset."

# download model
if [ ! -d $ROBERTA_CLUE_DIR ]; then
  mkdir -p $ROBERTA_CLUE_DIR
  echo "makedir $ROBERTA_CLUE_DIR"
fi
cd $ROBERTA_CLUE_DIR
if [ ! -f "bert_config.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "bert_model.ckpt.index" ] || [ ! -f "bert_model.ckpt.meta" ] || [ ! -f "bert_model.ckpt.data-00000-of-00001" ]; then
  rm *
  wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny-clue.zip
  unzip RoBERTa-tiny-clue.zip
  rm RoBERTa-tiny-clue.zip
else
  echo "model exists"
fi
echo "Finish download model."

# run task
cd $CURRENT_DIR
echo "Start running..."

python run_classifier_ner.py \
  --task_name=$TASK_NAME \
  --do_train=False \
  --do_predict=True \
  --data_dir=$CLUE_DATA_DIR/$TASK_NAME \
  --vocab_file=$ROBERTA_CLUE_DIR/vocab.txt \
  --bert_config_file=$ROBERTA_CLUE_DIR/bert_config.json \
  --init_checkpoint=$ROBERTA_CLUE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=4.0 \
  --output_dir=$CURRENT_DIR/${TASK_NAME}_output

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# mac
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /baselines/models/bert_ner/data-ming.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf8 3 | """ 4 | @author: Cong Yu 5 | @time: 2020-01-09 18:51 6 | """ 7 | import re 8 | import json 9 | 10 | 11 | def prepare_label(): 12 | text = """ 13 | 地址(address): 544 14 | 书名(book): 258 15 | 公司(company): 479 16 | 游戏(game): 281 17 | 政府(government): 262 18 | 电影(movie): 307 19 | 姓名(name): 710 20 | 组织机构(organization): 515 21 | 职位(position): 573 22 | 景点(scene): 288 23 | """ 24 | 25 | a = re.findall(r"\((.*?)\)", text.strip()) 26 | print(a) 27 | label2id = {"O": 0} 28 | index = 1 29 | for i in a: 30 | label2id["S_" + i] = index 31 | label2id["B_" + i] = index + 1 32 | label2id["M_" + i] = index + 2 33 | label2id["E_" + i] = index + 3 34 | index += 4 35 | 36 | open("label2id.json", "w").write(json.dumps(label2id, ensure_ascii=False, indent=2)) 37 | 38 | 39 | def prepare_len_count(): 40 | len_count = {} 41 | 42 | for line in open("data/thuctc_train.json"): 43 | if line.strip(): 44 | _ = json.loads(line.strip()) 45 | len_ = len(_["text"]) 46 | if len_count.get(len_): 47 | len_count[len_] += 1 48 | else: 49 | len_count[len_] = 1 50 | 51 | for line in open("data/thuctc_valid.json"): 52 | if line.strip(): 53 | _ = json.loads(line.strip()) 54 | len_ =
len(_["text"]) 55 | if len_count.get(len_): 56 | len_count[len_] += 1 57 | else: 58 | len_count[len_] = 1 59 | 60 | print("len_count", json.dumps(len_count, indent=2)) 61 | open("len_count.json", "w").write(json.dumps(len_count, indent=2)) 62 | 63 | 64 | def label_count(path): 65 | labels = ['address', 'book', 'company', 'game', 'government', 'movie', 'name', 'organization', 'position', 'scene'] 66 | label2desc = { 67 | "address": "地址", 68 | "book": "书名", 69 | "company": "公司", 70 | "game": "游戏", 71 | "government": "政府", 72 | "movie": "电影", 73 | "name": "姓名", 74 | "organization": "组织机构", 75 | "position": "职位", 76 | "scene": "景点" 77 | } 78 | label_count_dict = {i: 0 for i in labels} 79 | for line in open(path): 80 | if line.strip(): 81 | _ = json.loads(line.strip()) 82 | for k, v in _["label"].items(): 83 | label_count_dict[k] += len(v) 84 | for k, v in label_count_dict.items(): 85 | print("{}({}):{}".format(label2desc[k], k, v)) 86 | print("\n") 87 | 88 | 89 | # prepare_label() 90 | # label_count("data/thuctc_train.json") 91 | # label_count("data/thuctc_valid.json") 92 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/tpu/run_classifier_tnews.sh: -------------------------------------------------------------------------------- 1 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P) 2 | CURRENT_TIME=$(date "+%Y%m%d-%H%M%S") 3 | TASK_NAME="tnews" 4 | 5 | GS="gs" # change it to yours 6 | TPU_IP="1.1.1.1" # change it to yours 7 | # please create the folders below first 8 | export PREV_TRAINED_MODEL_DIR=$GS/prev_trained_models/nlp/bert-base/chinese_L-12_H-768_A-12 9 | export DATA_DIR=$GS/nlp/chineseGLUEdatasets.v0.0.1/${TASK_NAME} 10 | export OUTPUT_DIR=$GS/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME 11 | 12 | 13 | MODEL_NAME="chinese_L-12_H-768_A-12" 14 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P) 15 | export CUDA_VISIBLE_DEVICES="0" 16 | export
BERT_PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_model 17 | export BERT_BASE_DIR=$BERT_PRETRAINED_MODELS_DIR/$MODEL_NAME 18 | export GLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset 19 | 20 | # download and unzip dataset 21 | if [ ! -d $GLUE_DATA_DIR ]; then 22 | mkdir -p $GLUE_DATA_DIR 23 | echo "makedir $GLUE_DATA_DIR" 24 | fi 25 | cd $GLUE_DATA_DIR 26 | if [ ! -d $TASK_NAME ]; then 27 | mkdir $TASK_NAME 28 | echo "makedir $GLUE_DATA_DIR/$TASK_NAME" 29 | fi 30 | cd $TASK_NAME 31 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then 32 | rm * 33 | wget https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip 34 | unzip tnews_public.zip 35 | rm tnews_public.zip 36 | else 37 | echo "data exists" 38 | fi 39 | echo "Finish download dataset." 40 | 41 | # download model 42 | if [ ! -d $BERT_PRETRAINED_MODELS_DIR ]; then 43 | mkdir -p $BERT_PRETRAINED_MODELS_DIR 44 | echo "makedir $BERT_PRETRAINED_MODELS_DIR" 45 | fi 46 | cd $BERT_PRETRAINED_MODELS_DIR 47 | if [ ! -d $MODEL_NAME ]; then 48 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip 49 | unzip chinese_L-12_H-768_A-12.zip 50 | rm chinese_L-12_H-768_A-12.zip 51 | else 52 | cd $MODEL_NAME 53 | if [ ! -f "bert_config.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "bert_model.ckpt.index" ] || [ ! -f "bert_model.ckpt.meta" ] || [ ! -f "bert_model.ckpt.data-00000-of-00001" ]; then 54 | cd .. 55 | rm -rf $MODEL_NAME 56 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip 57 | unzip chinese_L-12_H-768_A-12.zip 58 | rm chinese_L-12_H-768_A-12.zip 59 | else 60 | echo "model exists" 61 | fi 62 | fi 63 | echo "Finish download model." 64 | 65 | # upload model and data 66 | gsutil -m cp $BERT_PRETRAINED_MODELS_DIR/* $PREV_TRAINED_MODEL_DIR 67 | 68 | gsutil -m cp $GLUE_DATA_DIR/$TASK_NAME/* $DATA_DIR 69 | 70 | cd $CURRENT_DIR 71 | echo "Start running..." 
72 | python $CURRENT_DIR/../run_classifier.py \ 73 | --task_name=$TASK_NAME \ 74 | --do_train=true \ 75 | --do_eval=true \ 76 | --data_dir=$DATA_DIR \ 77 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \ 78 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \ 79 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \ 80 | --max_seq_length=128 \ 81 | --train_batch_size=32 \ 82 | --learning_rate=2e-5 \ 83 | --num_train_epochs=3.0 \ 84 | --output_dir=$OUTPUT_DIR \ 85 | --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://$TPU_IP:8470 86 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/run_classifier_csl.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # @Author: Li Yudong 3 | # @Date: 2019-11-28 4 | # @Last Modified by: bo.shi 5 | # @Last Modified time: 2019-12-05 11:00:57 6 | 7 | TASK_NAME="csl" 8 | MODEL_NAME="chinese_L-12_H-768_A-12" 9 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P) 10 | export CUDA_VISIBLE_DEVICES="0" 11 | export BERT_PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_model 12 | export BERT_BASE_DIR=$BERT_PRETRAINED_MODELS_DIR/$MODEL_NAME 13 | export GLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset 14 | 15 | # download and unzip dataset 16 | if [ ! -d $GLUE_DATA_DIR ]; then 17 | mkdir -p $GLUE_DATA_DIR 18 | echo "makedir $GLUE_DATA_DIR" 19 | fi 20 | cd $GLUE_DATA_DIR 21 | if [ ! -d $TASK_NAME ]; then 22 | mkdir $TASK_NAME 23 | echo "makedir $GLUE_DATA_DIR/$TASK_NAME" 24 | fi 25 | cd $TASK_NAME 26 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then 27 | rm * 28 | wget https://storage.googleapis.com/cluebenchmark/tasks/csl_public.zip 29 | unzip csl_public.zip 30 | rm csl_public.zip 31 | else 32 | echo "data exists" 33 | fi 34 | echo "Finish download dataset." 35 | 36 | # download model 37 | if [ ! 
-d $BERT_PRETRAINED_MODELS_DIR ]; then 38 | mkdir -p $BERT_PRETRAINED_MODELS_DIR 39 | echo "makedir $BERT_PRETRAINED_MODELS_DIR" 40 | fi 41 | cd $BERT_PRETRAINED_MODELS_DIR 42 | if [ ! -d $MODEL_NAME ]; then 43 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip 44 | unzip chinese_L-12_H-768_A-12.zip 45 | rm chinese_L-12_H-768_A-12.zip 46 | else 47 | cd $MODEL_NAME 48 | if [ ! -f "bert_config.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "bert_model.ckpt.index" ] || [ ! -f "bert_model.ckpt.meta" ] || [ ! -f "bert_model.ckpt.data-00000-of-00001" ]; then 49 | cd .. 50 | rm -rf $MODEL_NAME 51 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip 52 | unzip chinese_L-12_H-768_A-12.zip 53 | rm chinese_L-12_H-768_A-12.zip 54 | else 55 | echo "model exists" 56 | fi 57 | fi 58 | echo "Finish download model." 59 | 60 | # run task 61 | cd $CURRENT_DIR 62 | echo "Start running..." 63 | if [ $# == 0 ]; then 64 | python run_classifier.py \ 65 | --task_name=$TASK_NAME \ 66 | --do_train=true \ 67 | --do_eval=true \ 68 | --data_dir=$GLUE_DATA_DIR/$TASK_NAME \ 69 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 70 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 71 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 72 | --max_seq_length=128 \ 73 | --train_batch_size=32 \ 74 | --learning_rate=2e-5 \ 75 | --num_train_epochs=3.0 \ 76 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output/ 77 | elif [ $1 == "predict" ]; then 78 | echo "Start predict..." 
79 | python run_classifier.py \ 80 | --task_name=$TASK_NAME \ 81 | --do_train=false \ 82 | --do_eval=false \ 83 | --do_predict=true \ 84 | --data_dir=$GLUE_DATA_DIR/$TASK_NAME \ 85 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 86 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 87 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 88 | --max_seq_length=128 \ 89 | --train_batch_size=32 \ 90 | --learning_rate=2e-5 \ 91 | --num_train_epochs=3.0 \ 92 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output/ 93 | fi 94 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/run_classifier_wsc.sh: -------------------------------------------------------------------------------- 1 | # @Author: bo.shi 2 | # @Date: 2019-12-01 22:28:41 3 | # @Last Modified by: bo.shi 4 | # @Last Modified time: 2019-12-05 11:01:11 5 | #!/usr/bin/env bash 6 | 7 | TASK_NAME="cluewsc2020" 8 | MODEL_NAME="chinese_L-12_H-768_A-12" 9 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P) 10 | export CUDA_VISIBLE_DEVICES="0" 11 | export BERT_PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_model 12 | export BERT_BASE_DIR=$BERT_PRETRAINED_MODELS_DIR/$MODEL_NAME 13 | export GLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset 14 | 15 | # download and unzip dataset 16 | if [ ! -d $GLUE_DATA_DIR ]; then 17 | mkdir -p $GLUE_DATA_DIR 18 | echo "makedir $GLUE_DATA_DIR" 19 | fi 20 | cd $GLUE_DATA_DIR 21 | if [ ! -d $TASK_NAME ]; then 22 | mkdir $TASK_NAME 23 | echo "makedir $GLUE_DATA_DIR/$TASK_NAME" 24 | fi 25 | cd $TASK_NAME 26 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then 27 | rm * 28 | wget https://storage.googleapis.com/cluebenchmark/tasks/cluewsc2020_public.zip 29 | unzip cluewsc2020_public.zip 30 | rm cluewsc2020_public.zip 31 | else 32 | echo "data exists" 33 | fi 34 | echo "Finish download dataset." 35 | 36 | # download model 37 | if [ ! 
-d $BERT_PRETRAINED_MODELS_DIR ]; then 38 | mkdir -p $BERT_PRETRAINED_MODELS_DIR 39 | echo "makedir $BERT_PRETRAINED_MODELS_DIR" 40 | fi 41 | cd $BERT_PRETRAINED_MODELS_DIR 42 | if [ ! -d $MODEL_NAME ]; then 43 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip 44 | unzip chinese_L-12_H-768_A-12.zip 45 | rm chinese_L-12_H-768_A-12.zip 46 | else 47 | cd $MODEL_NAME 48 | if [ ! -f "bert_config.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "bert_model.ckpt.index" ] || [ ! -f "bert_model.ckpt.meta" ] || [ ! -f "bert_model.ckpt.data-00000-of-00001" ]; then 49 | cd .. 50 | rm -rf $MODEL_NAME 51 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip 52 | unzip chinese_L-12_H-768_A-12.zip 53 | rm chinese_L-12_H-768_A-12.zip 54 | else 55 | echo "model exists" 56 | fi 57 | fi 58 | echo "Finish download model." 59 | 60 | # run task 61 | cd $CURRENT_DIR 62 | echo "Start running..." 63 | if [ $# == 0 ]; then 64 | python run_classifier.py \ 65 | --task_name=$TASK_NAME \ 66 | --do_train=true \ 67 | --do_eval=true \ 68 | --data_dir=$GLUE_DATA_DIR/$TASK_NAME \ 69 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 70 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 71 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 72 | --max_seq_length=128 \ 73 | --train_batch_size=32 \ 74 | --learning_rate=2e-5 \ 75 | --num_train_epochs=3.0 \ 76 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output/ 77 | elif [ $1 == "predict" ]; then 78 | echo "Start predict..." 
79 | python run_classifier.py \ 80 | --task_name=$TASK_NAME \ 81 | --do_train=false \ 82 | --do_eval=false \ 83 | --do_predict=true \ 84 | --data_dir=$GLUE_DATA_DIR/$TASK_NAME \ 85 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 86 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 87 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 88 | --max_seq_length=128 \ 89 | --train_batch_size=32 \ 90 | --learning_rate=2e-5 \ 91 | --num_train_epochs=3.0 \ 92 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output/ 93 | fi 94 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 
10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 
34 | -------------------------------------------------------------------------------- /baselines/models/bert_mrc/cmrc2018_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Evaluation script for CMRC 2018 4 | version: v5 - special 5 | Note: 6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets 7 | v5: formatted output, add usage description 8 | v4: fixed segmentation issues 9 | ''' 10 | from __future__ import print_function 11 | from collections import Counter, OrderedDict 12 | import string 13 | import re 14 | import argparse 15 | import json 16 | import sys 17 | reload(sys) 18 | sys.setdefaultencoding('utf8') 19 | import nltk 20 | import pdb 21 | 22 | # split Chinese with English 23 | def mixed_segmentation(in_str, rm_punc=False): 24 | in_str = str(in_str).decode('utf-8').lower().strip() 25 | segs_out = [] 26 | temp_str = "" 27 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 28 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 29 | '「','」','(',')','-','~','『','』'] 30 | for char in in_str: 31 | if rm_punc and char in sp_char: 32 | continue 33 | if re.search(ur'[\u4e00-\u9fa5]', char) or char in sp_char: 34 | if temp_str != "": 35 | ss = nltk.word_tokenize(temp_str) 36 | segs_out.extend(ss) 37 | temp_str = "" 38 | segs_out.append(char) 39 | else: 40 | temp_str += char 41 | 42 | #handling last part 43 | if temp_str != "": 44 | ss = nltk.word_tokenize(temp_str) 45 | segs_out.extend(ss) 46 | 47 | return segs_out 48 | 49 | 50 | # remove punctuation 51 | def remove_punctuation(in_str): 52 | in_str = str(in_str).decode('utf-8').lower().strip() 53 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 54 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 55 | '「','」','(',')','-','~','『','』'] 56 | out_segs = [] 57 | for char in in_str: 58 | if char in sp_char: 59 | continue 60 | else: 61 | out_segs.append(char) 62 | return ''.join(out_segs) 63 | 64 |
65 | # find longest common string 66 | def find_lcs(s1, s2): 67 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)] 68 | mmax = 0 69 | p = 0 70 | for i in range(len(s1)): 71 | for j in range(len(s2)): 72 | if s1[i] == s2[j]: 73 | m[i+1][j+1] = m[i][j]+1 74 | if m[i+1][j+1] > mmax: 75 | mmax=m[i+1][j+1] 76 | p=i+1 77 | return s1[p-mmax:p], mmax 78 | 79 | # 80 | def evaluate(ground_truth_file, prediction_file): 81 | f1 = 0 82 | em = 0 83 | total_count = 0 84 | skip_count = 0 85 | for instance in ground_truth_file["data"]: 86 | #context_id = instance['context_id'].strip() 87 | #context_text = instance['context_text'].strip() 88 | for para in instance["paragraphs"]: 89 | for qas in para['qas']: 90 | total_count += 1 91 | query_id = qas['id'].strip() 92 | query_text = qas['question'].strip() 93 | answers = [x["text"] for x in qas['answers']] 94 | 95 | if query_id not in prediction_file: 96 | sys.stderr.write('Unanswered question: {}\n'.format(query_id)) 97 | skip_count += 1 98 | continue 99 | 100 | prediction = str(prediction_file[query_id]).decode('utf-8') 101 | f1 += calc_f1_score(answers, prediction) 102 | em += calc_em_score(answers, prediction) 103 | 104 | f1_score = 100.0 * f1 / total_count 105 | em_score = 100.0 * em / total_count 106 | return f1_score, em_score, total_count, skip_count 107 | 108 | 109 | def calc_f1_score(answers, prediction): 110 | f1_scores = [] 111 | for ans in answers: 112 | ans_segs = mixed_segmentation(ans, rm_punc=True) 113 | prediction_segs = mixed_segmentation(prediction, rm_punc=True) 114 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs) 115 | if lcs_len == 0: 116 | f1_scores.append(0) 117 | continue 118 | precision = 1.0*lcs_len/len(prediction_segs) 119 | recall = 1.0*lcs_len/len(ans_segs) 120 | f1 = (2*precision*recall)/(precision+recall) 121 | f1_scores.append(f1) 122 | return max(f1_scores) 123 | 124 | 125 | def calc_em_score(answers, prediction): 126 | em = 0 127 | for ans in answers: 128 | ans_ = 
remove_punctuation(ans) 129 | prediction_ = remove_punctuation(prediction) 130 | if ans_ == prediction_: 131 | em = 1 132 | break 133 | return em 134 | 135 | if __name__ == '__main__': 136 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018') 137 | parser.add_argument('dataset_file', help='Official dataset file') 138 | parser.add_argument('prediction_file', help='Your prediction File') 139 | args = parser.parse_args() 140 | ground_truth_file = json.load(open(args.dataset_file, 'rb')) 141 | prediction_file = json.load(open(args.prediction_file, 'rb')) 142 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file) 143 | AVG = (EM+F1)*0.5 144 | output_result = OrderedDict() 145 | output_result['AVERAGE'] = '%.3f' % AVG 146 | output_result['F1'] = '%.3f' % F1 147 | output_result['EM'] = '%.3f' % EM 148 | output_result['TOTAL'] = TOTAL 149 | output_result['SKIP'] = SKIP 150 | output_result['FILE'] = args.prediction_file 151 | print(json.dumps(output_result)) 152 | 153 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/run_classifier_clue.sh: -------------------------------------------------------------------------------- 1 | # @Author: bo.shi 2 | # @Date: 2020-03-15 16:11:00 3 | # @Last Modified by: bo.shi 4 | # @Last Modified time: 2020-03-17 13:06:02 5 | #!/usr/bin/env bash 6 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P) 7 | CLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset 8 | CLUE_PREV_TRAINED_MODEL_DIR=$CURRENT_DIR/prev_trained_models 9 | 10 | download_data(){ 11 | TASK_NAME=$1 12 | if [ ! -d $CLUE_DATA_DIR ]; then 13 | mkdir -p $CLUE_DATA_DIR 14 | echo "makedir $CLUE_DATA_DIR" 15 | fi 16 | cd $CLUE_DATA_DIR 17 | if [ ! -d ${TASK_NAME} ]; then 18 | mkdir $TASK_NAME 19 | echo "make dataset dir $CLUE_DATA_DIR/$TASK_NAME" 20 | fi 21 | cd $TASK_NAME 22 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! 
-f "test.json" ]; then 23 | rm * 24 | if [ $TASK_NAME = "wsc" ];then 25 | wget https://storage.googleapis.com/cluebenchmark/tasks/clue${TASK_NAME}2020_public.zip 26 | unzip clue${TASK_NAME}2020_public.zip 27 | rm clue${TASK_NAME}2020_public.zip 28 | else 29 | wget https://storage.googleapis.com/cluebenchmark/tasks/${TASK_NAME}_public.zip 30 | unzip ${TASK_NAME}_public.zip 31 | rm ${TASK_NAME}_public.zip 32 | fi 33 | else 34 | echo "data exists" 35 | fi 36 | echo "Finish download dataset." 37 | } 38 | 39 | download_model(){ 40 | MODEL_NAME=$1 41 | if [ ! -d $CLUE_PREV_TRAINED_MODEL_DIR ]; then 42 | mkdir -p $CLUE_PREV_TRAINED_MODEL_DIR 43 | echo "make prev_trained_model dir $CLUE_PREV_TRAINED_MODEL_DIR" 44 | fi 45 | cd $CLUE_PREV_TRAINED_MODEL_DIR 46 | mkdir -p $MODEL_NAME 47 | if cd $MODEL_NAME; then 48 | # clear any partial download before fetching the checkpoint 49 | rm -f * 50 | 51 | if [ "$MODEL_NAME" = "RoBERTa-tiny-clue" ]; then 52 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny-clue.zip 53 | unzip RoBERTa-tiny-clue.zip 54 | rm RoBERTa-tiny-clue.zip 55 | elif [ "$MODEL_NAME" = "RoBERTa-tiny-pair" ]; then 56 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny-pair.zip 57 | unzip RoBERTa-tiny-pair.zip 58 | rm RoBERTa-tiny-pair.zip 59 | elif [ "$MODEL_NAME" = "RoBERTa-tiny3L768-clue" ]; then 60 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny3L768-clue.zip 61 | unzip RoBERTa-tiny3L768-clue.zip 62 | rm RoBERTa-tiny3L768-clue.zip 63 | elif [ "$MODEL_NAME" = "RoBERTa-tiny3L312-clue" ]; then 64 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny3L312-clue.zip 65 | unzip RoBERTa-tiny3L312-clue.zip 66 | rm RoBERTa-tiny3L312-clue.zip 67 | elif [ "$MODEL_NAME" = "RoBERTa-large-clue" ]; then 68 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-large-clue.zip 69 | unzip RoBERTa-large-clue.zip 70 | rm RoBERTa-large-clue.zip 71 | elif [
"$MODEL_NAME" = "RoBERTa-large-pair" ]; then 72 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-large-pair.zip 73 | unzip RoBERTa-large-pair.zip 74 | rm RoBERTa-large-pair.zip 75 | else 76 | echo "unknown model_name, choose from [ RoBERTa-tiny-clue , RoBERTa-tiny-pair , RoBERTa-tiny3L768-clue , RoBERTa-tiny3L312-clue , RoBERTa-large-clue , RoBERTa-large-pair]" 77 | fi 78 | fi 79 | } 80 | 81 | run_task() { 82 | TASK_NAME=$1 83 | MODEL_NAME=$2 84 | download_data $TASK_NAME 85 | #if [ ! -d $CLUE_PREV_TRAINED_MODEL_DIR/$MODEL_NAME ]; then 86 | # download_model $MODEL_NAME 87 | #fi 88 | DATA_DIR=$CLUE_DATA_DIR/${TASK_NAME} 89 | PREV_TRAINED_MODEL_DIR=$CLUE_PREV_TRAINED_MODEL_DIR/$MODEL_NAME 90 | MAX_SEQ_LENGTH=$3 91 | TRAIN_BATCH_SIZE=$4 92 | LEARNING_RATE=$5 93 | NUM_TRAIN_EPOCHS=$6 94 | SAVE_CHECKPOINTS_STEPS=$7 95 | # TPU_IP=$8 96 | OUTPUT_DIR=$CURRENT_DIR/${TASK_NAME}_output/ 97 | COMMON_ARGS=" 98 | --task_name=$TASK_NAME \ 99 | --data_dir=$DATA_DIR \ 100 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \ 101 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \ 102 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \ 103 | --max_seq_length=$MAX_SEQ_LENGTH \ 104 | --train_batch_size=$TRAIN_BATCH_SIZE \ 105 | --learning_rate=$LEARNING_RATE \ 106 | --num_train_epochs=$NUM_TRAIN_EPOCHS \ 107 | --save_checkpoints_steps=$SAVE_CHECKPOINTS_STEPS \ 108 | --output_dir=$OUTPUT_DIR \ 109 | --keep_checkpoint_max=0 \ 110 | " 111 | cd $CURRENT_DIR 112 | echo "Start running..." 113 | python run_classifier.py \ 114 | $COMMON_ARGS \ 115 | --do_train=true \ 116 | --do_eval=false \ 117 | --do_predict=false 118 | 119 | echo "Start predict..." 
120 | python run_classifier.py \ 121 | $COMMON_ARGS \ 122 | --do_train=false \ 123 | --do_eval=true \ 124 | --do_predict=true 125 | } 126 | 127 | 128 | run_task csl RoBERTa-tiny-clue 128 16 1e-5 0 100 129 | run_task cluewsc2020 RoBERTa-tiny-clue 128 16 1e-5 0 10 130 | #run_task csl RoBERTa-tiny-clue 128 16 1e-5 5 100 131 | #run_task wsc RoBERTa-tiny-clue 128 16 1e-5 10 10 132 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | import tokenization 22 | import six 23 | import tensorflow as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | if six.PY2: 35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 36 | else: 37 | vocab_writer.write("".join( 38 | [x + "\n" for x in vocab_tokens]).encode("utf-8")) 39 | 40 | vocab_file = vocab_writer.name 41 | 42 | tokenizer = tokenization.FullTokenizer(vocab_file) 43 | os.unlink(vocab_file) 44 | 45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 47 | 48 | self.assertAllEqual( 49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 50 | 51 | def test_chinese(self): 52 | tokenizer = tokenization.BasicTokenizer() 53 | 54 | self.assertAllEqual( 55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 56 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 57 | 58 | def test_basic_tokenizer_lower(self): 59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 60 | 61 | self.assertAllEqual( 62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 63 | ["hello", "!", "how", "are", "you", "?"]) 64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 65 | 66 | def test_basic_tokenizer_no_lower(self): 67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 68 | 69 | self.assertAllEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 71 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 72 | 73 | def test_wordpiece_tokenizer(self): 74 | vocab_tokens = [ 75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 76 | "##ing" 77 | ] 78 | 79 | vocab = {} 80 | for (i, token) in enumerate(vocab_tokens): 81 | vocab[token] = i 82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 83 | 84 | self.assertAllEqual(tokenizer.tokenize(""), []) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwanted running"), 88 | ["un", "##want", "##ed", "runn", "##ing"]) 89 | 90 | self.assertAllEqual( 91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 92 | 93 | def test_convert_tokens_to_ids(self): 94 | vocab_tokens = [ 95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 96 | "##ing" 97 | ] 98 | 99 | vocab = {} 100 | for (i, token) in enumerate(vocab_tokens): 101 | vocab[token] = i 102 | 103 | self.assertAllEqual( 104 | tokenization.convert_tokens_to_ids( 105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(tokenization._is_whitespace(u" ")) 109 | self.assertTrue(tokenization._is_whitespace(u"\t")) 110 | self.assertTrue(tokenization._is_whitespace(u"\r")) 111 | self.assertTrue(tokenization._is_whitespace(u"\n")) 112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(tokenization._is_whitespace(u"A")) 115 | self.assertFalse(tokenization._is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(tokenization._is_control(u"\u0005")) 119 | 120 | self.assertFalse(tokenization._is_control(u"A")) 121 | self.assertFalse(tokenization._is_control(u" ")) 122 | self.assertFalse(tokenization._is_control(u"\t")) 123 | self.assertFalse(tokenization._is_control(u"\r")) 124 | self.assertFalse(tokenization._is_control(u"\U0001F4A9")) 125 | 126 | def test_is_punctuation(self): 127 | 
self.assertTrue(tokenization._is_punctuation(u"-")) 128 | self.assertTrue(tokenization._is_punctuation(u"$")) 129 | self.assertTrue(tokenization._is_punctuation(u"`")) 130 | self.assertTrue(tokenization._is_punctuation(u".")) 131 | 132 | self.assertFalse(tokenization._is_punctuation(u"A")) 133 | self.assertFalse(tokenization._is_punctuation(u" ")) 134 | 135 | 136 | if __name__ == "__main__": 137 | tf.test.main() 138 | -------------------------------------------------------------------------------- /baselines/models/bert_ner/data_processor_seq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf8 3 | """ 4 | @author: Cong Yu 5 | @time: 2019-12-07 17:03 6 | """ 7 | import json 8 | import tokenization 9 | import collections 10 | import tensorflow as tf 11 | 12 | 13 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 14 | """Truncates a sequence pair in place to the maximum length.""" 15 | 16 | # This is a simple heuristic which will always truncate the longer sequence 17 | # one token at a time. This makes more sense than truncating an equal percent 18 | # of tokens from each, since if one sequence is very short then each token 19 | # that's truncated likely contains more information than a longer sequence. 
20 | while True: 21 | total_length = len(tokens_a) + len(tokens_b) 22 | if total_length <= max_length: 23 | break 24 | if len(tokens_a) > len(tokens_b): 25 | tokens_a.pop() 26 | else: 27 | tokens_b.pop() 28 | 29 | 30 | def process_one_example(tokenizer, label2id, text, label, max_seq_len=128): 31 | # textlist = text.split(' ') 32 | # labellist = label.split(' ') 33 | textlist = list(text) 34 | labellist = list(label) 35 | tokens = [] 36 | labels = [] 37 | for i, word in enumerate(textlist): 38 | token = tokenizer.tokenize(word) 39 | tokens.extend(token) 40 | label_1 = labellist[i] 41 | for m in range(len(token)): 42 | if m == 0: 43 | labels.append(label_1) 44 | else: 45 | print("some unknown token...") 46 | labels.append(label_1) # extra sub-tokens inherit the label of their source character 47 | # Keep at most max_seq_len - 2 tokens, leaving room for the [CLS] (start) and [SEP] (end) markers 48 | if len(tokens) >= max_seq_len - 1: 49 | tokens = tokens[0:(max_seq_len - 2)] 50 | labels = labels[0:(max_seq_len - 2)] 51 | ntokens = [] 52 | segment_ids = [] 53 | label_ids = [] 54 | ntokens.append("[CLS]") # mark the start of the sentence with [CLS] 55 | segment_ids.append(0) 56 | # [CLS] and [SEP] could get their own labels or share one fixed label; they never change and barely take part in training (i.e. the x -> l mapping is always the same) 57 | label_ids.append(0) # label2id["[CLS]"] 58 | for i, token in enumerate(tokens): 59 | ntokens.append(token) 60 | segment_ids.append(0) 61 | label_ids.append(label2id[labels[i]]) 62 | ntokens.append("[SEP]") 63 | segment_ids.append(0) 64 | # append("O") or append("[SEP]") not sure!
65 | label_ids.append(0) # label2id["[SEP]"] 66 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 67 | input_mask = [1] * len(input_ids) 68 | while len(input_ids) < max_seq_len: 69 | input_ids.append(0) 70 | input_mask.append(0) 71 | segment_ids.append(0) 72 | label_ids.append(0) 73 | ntokens.append("**NULL**") 74 | assert len(input_ids) == max_seq_len 75 | assert len(input_mask) == max_seq_len 76 | assert len(segment_ids) == max_seq_len 77 | assert len(label_ids) == max_seq_len 78 | 79 | feature = (input_ids, input_mask, segment_ids, label_ids) 80 | return feature 81 | 82 | 83 | def prepare_tf_record_data(tokenizer, max_seq_len, label2id, path, out_path): 84 | """ 85 | Convert a JSON-lines NER file into a TFRecord file of sequence-labelling features. Entity spans are turned into character-level S_/B_/M_/E_ tags (O elsewhere); examples are written in file order (no shuffling). 86 | """ 87 | writer = tf.python_io.TFRecordWriter(out_path) 88 | example_count = 0 89 | 90 | for line in open(path): 91 | if not line.strip(): 92 | continue 93 | _ = json.loads(line.strip()) 94 | len_ = len(_["text"]) 95 | labels = ["O"] * len_ 96 | for k, v in _["label"].items(): 97 | for kk, vv in v.items(): 98 | for vvv in vv: 99 | span = vvv 100 | s = span[0] 101 | e = span[1] + 1 102 | # print(s, e) 103 | if e - s == 1: 104 | labels[s] = "S_" + k 105 | else: 106 | labels[s] = "B_" + k 107 | for i in range(s + 1, e - 1): 108 | labels[i] = "M_" + k 109 | labels[e - 1] = "E_" + k 110 | # print() 111 | # feature = process_one_example(tokenizer, label2id, row[column_name_x1], row[column_name_y], 112 | # max_seq_len=max_seq_len) 113 | feature = process_one_example(tokenizer, label2id, list(_["text"]), labels, 114 | max_seq_len=max_seq_len) 115 | 116 | def create_int_feature(values): 117 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 118 | return f 119 | 120 | features = collections.OrderedDict() 121 | # sequence-labelling task 122 | features["input_ids"] = create_int_feature(feature[0]) 123 | features["input_mask"] = create_int_feature(feature[1]) 124 | features["segment_ids"] = create_int_feature(feature[2]) 125 | features["label_ids"] =
create_int_feature(feature[3]) 126 | if example_count < 5: 127 | print("*** Example ***") 128 | print(_["text"]) 129 | print(_["label"]) 130 | print("input_ids: %s" % " ".join([str(x) for x in feature[0]])) 131 | print("input_mask: %s" % " ".join([str(x) for x in feature[1]])) 132 | print("segment_ids: %s" % " ".join([str(x) for x in feature[2]])) 133 | print("label: %s " % " ".join([str(x) for x in feature[3]])) 134 | 135 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 136 | writer.write(tf_example.SerializeToString()) 137 | example_count += 1 138 | 139 | # if example_count == 20: 140 | # break 141 | if example_count % 3000 == 0: 142 | print(example_count) 143 | print("total example:", example_count) 144 | writer.close() 145 | 146 | 147 | if __name__ == "__main__": 148 | vocab_file = "./vocab.txt" 149 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file) 150 | label2id = json.loads(open("label2id.json").read()) 151 | 152 | max_seq_len = 64 153 | 154 | prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_train.json", 155 | out_path="data/train.tf_record") 156 | prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_valid.json", 157 | out_path="data/dev.tf_record") 158 | -------------------------------------------------------------------------------- /baselines/models/bert_ner/predict_sequence_label.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf8 3 | """ 4 | @author: Cong Yu 5 | @time: 2019-12-07 20:51 6 | """ 7 | import os 8 | import re 9 | import json 10 | import tensorflow as tf 11 | import tokenization 12 | 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 14 | 15 | vocab_file = "./vocab.txt" 16 | tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file) 17 | label2id = json.loads(open("./label2id.json").read()) 18 | id2label = {v: k for k, v in label2id.items()} # map id -> label explicitly instead of relying on dict key order 19 | 20 | 21 | def
process_one_example_p(tokenizer, text, max_seq_len=128): 22 | textlist = list(text) 23 | tokens = [] 24 | # labels = [] 25 | for i, word in enumerate(textlist): 26 | token = tokenizer.tokenize(word) 27 | # print(token) 28 | tokens.extend(token) 29 | if len(tokens) >= max_seq_len - 1: 30 | tokens = tokens[0:(max_seq_len - 2)] 31 | # labels = labels[0:(max_seq_len - 2)] 32 | ntokens = [] 33 | segment_ids = [] 34 | label_ids = [] 35 | ntokens.append("[CLS]") # mark the start of the sentence with [CLS] 36 | segment_ids.append(0) 37 | for i, token in enumerate(tokens): 38 | ntokens.append(token) 39 | segment_ids.append(0) 40 | # label_ids.append(label2id[labels[i]]) 41 | ntokens.append("[SEP]") 42 | segment_ids.append(0) 43 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 44 | input_mask = [1] * len(input_ids) 45 | while len(input_ids) < max_seq_len: 46 | input_ids.append(0) 47 | input_mask.append(0) 48 | segment_ids.append(0) 49 | label_ids.append(0) 50 | ntokens.append("**NULL**") 51 | assert len(input_ids) == max_seq_len 52 | assert len(input_mask) == max_seq_len 53 | assert len(segment_ids) == max_seq_len 54 | 55 | feature = (input_ids, input_mask, segment_ids) 56 | return feature 57 | 58 | 59 | def load_model(model_folder): 60 | # We retrieve our checkpoint full path 61 | try: 62 | checkpoint = tf.train.get_checkpoint_state(model_folder) 63 | input_checkpoint = checkpoint.model_checkpoint_path 64 | print("[INFO] input_checkpoint:", input_checkpoint) 65 | except Exception as e: 66 | input_checkpoint = model_folder 67 | print("[INFO] Model folder", model_folder, repr(e)) 68 | 69 | # We clear device placements so that TensorFlow decides which device to load each operation on 70 | clear_devices = True 71 | tf.reset_default_graph() 72 | # We import the meta graph and retrieve a Saver 73 | saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices) 74 | 75 | # We start a session and restore the graph weights 76 | sess_ = tf.Session() 77 | saver.restore(sess_,
input_checkpoint) 78 | 79 | # opts = sess_.graph.get_operations() 80 | # for v in opts: 81 | # print(v.name) 82 | return sess_ 83 | 84 | 85 | model_path = "./ner_bert_base/" 86 | sess = load_model(model_path) 87 | input_ids = sess.graph.get_tensor_by_name("input_ids:0") 88 | input_mask = sess.graph.get_tensor_by_name("input_mask:0") 89 | segment_ids = sess.graph.get_tensor_by_name("segment_ids:0") 90 | keep_prob = sess.graph.get_tensor_by_name("keep_prob:0") 91 | p = sess.graph.get_tensor_by_name("loss/ReverseSequence_1:0") 92 | 93 | 94 | def predict(text): 95 | data = [text] 96 | # each text is truncated to at most 62 tokens (max_seq_len=64 minus [CLS]/[SEP]) and predicted as a batch 97 | features = [] 98 | for i in data: 99 | feature = process_one_example_p(tokenizer_, i, max_seq_len=64) 100 | features.append(feature) 101 | feed = {input_ids: [feature[0] for feature in features], 102 | input_mask: [feature[1] for feature in features], 103 | segment_ids: [feature[2] for feature in features], 104 | keep_prob: 1.0 105 | } 106 | 107 | [probs] = sess.run([p], feed) 108 | result = [] 109 | for index, prob in enumerate(probs): 110 | for v in prob[1:len(data[index]) + 1]: 111 | result.append(id2label[int(v)]) 112 | print(result) 113 | labels = {} 114 | start = None 115 | index = 0 116 | for w, t in zip("".join(data), result): 117 | if re.search("^[BS]", t): 118 | if start is not None: 119 | label = result[index - 1][2:] 120 | if labels.get(label): 121 | te_ = text[start:index] 122 | # print(te_, labels) 123 | labels[label][te_] = [[start, index - 1]] 124 | else: 125 | te_ = text[start:index] 126 | # print(te_, labels) 127 | labels[label] = {te_: [[start, index - 1]]} 128 | start = index 129 | # print(start) 130 | if re.search("^O", t): 131 | if start is not None: 132 | # print(start) 133 | label = result[index - 1][2:] 134 | if labels.get(label): 135 | te_ = text[start:index] 136 | # print(te_, labels) 137 | labels[label][te_] = [[start, index - 1]] 138 | else: 139 | te_ = text[start:index] 140 |
# print(te_, labels) 141 | labels[label] = {te_: [[start, index - 1]]} 142 | # else: 143 | # print(start, labels) 144 | start = None 145 | index += 1 146 | if start is not None: 147 | # print(start) 148 | label = result[start][2:] 149 | if labels.get(label): 150 | te_ = text[start:index] 151 | # print(te_, labels) 152 | labels[label][te_] = [[start, index - 1]] 153 | else: 154 | te_ = text[start:index] 155 | # print(te_, labels) 156 | labels[label] = {te_: [[start, index - 1]]} 157 | # print(labels) 158 | return labels 159 | 160 | 161 | def submit(path): 162 | data = [] 163 | for line in open(path): 164 | if not line.strip(): 165 | continue 166 | _ = json.loads(line.strip()) 167 | res = predict(_["text"]) 168 | data.append(json.dumps({"label": res}, ensure_ascii=False)) 169 | open("ner_predict.json", "w").write("\n".join(data)) 170 | 171 | 172 | if __name__ == "__main__": 173 | text_ = "梅塔利斯在乌克兰联赛、杯赛及联盟杯中保持9场不败,状态相当出色;" 174 | res_ = predict(text_) 175 | print(res_) 176 | 177 | submit("data/thuctc_valid.json") 178 | -------------------------------------------------------------------------------- /baselines/models/bert_ner/optimization.py: -------------------------------------------------------------------------------- 1 | """Functions and classes related to optimization (weight updates).""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import re 8 | import tensorflow as tf 9 | 10 | 11 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 12 | """Creates an optimizer training op.""" 13 | global_step = tf.train.get_or_create_global_step() 14 | 15 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 16 | 17 | # Implements linear decay of the learning rate. 
18 | learning_rate = tf.train.polynomial_decay( 19 | learning_rate, 20 | global_step, 21 | num_train_steps, 22 | end_learning_rate=0.0, 23 | power=1.0, 24 | cycle=False) 25 | 26 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 27 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 28 | if num_warmup_steps: 29 | global_steps_int = tf.cast(global_step, tf.int32) 30 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 31 | 32 | global_steps_float = tf.cast(global_steps_int, tf.float32) 33 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 34 | 35 | warmup_percent_done = global_steps_float / warmup_steps_float 36 | warmup_learning_rate = init_lr * warmup_percent_done 37 | 38 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 39 | learning_rate = ( 40 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 41 | 42 | # It is recommended that you use this optimizer for fine tuning, since this 43 | # is how the model was trained (note that the Adam m/v variables are NOT 44 | # loaded from init_checkpoint.) 45 | optimizer = AdamWeightDecayOptimizer( 46 | learning_rate=learning_rate, 47 | weight_decay_rate=0.01, 48 | beta_1=0.9, 49 | beta_2=0.999, 50 | epsilon=1e-6, 51 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 52 | 53 | if use_tpu: 54 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 55 | 56 | tvars = tf.trainable_variables() 57 | grads = tf.gradients(loss, tvars) 58 | 59 | # This is how the model was pre-trained. 60 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 61 | 62 | train_op = optimizer.apply_gradients( 63 | zip(grads, tvars), global_step=global_step) 64 | 65 | # Normally the global step update is done inside of `apply_gradients`. 66 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 67 | # a different optimizer, you should probably take this line out. 
68 | new_global_step = global_step + 1 69 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 70 | return train_op, learning_rate 71 | 72 | 73 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 74 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 75 | 76 | def __init__(self, 77 | learning_rate, 78 | weight_decay_rate=0.0, 79 | beta_1=0.9, 80 | beta_2=0.999, 81 | epsilon=1e-6, 82 | exclude_from_weight_decay=None, 83 | name="AdamWeightDecayOptimizer"): 84 | """Constructs an AdamWeightDecayOptimizer.""" 85 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 86 | 87 | self.learning_rate = learning_rate 88 | self.weight_decay_rate = weight_decay_rate 89 | self.beta_1 = beta_1 90 | self.beta_2 = beta_2 91 | self.epsilon = epsilon 92 | self.exclude_from_weight_decay = exclude_from_weight_decay 93 | 94 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 95 | """See base class.""" 96 | assignments = [] 97 | for (grad, param) in grads_and_vars: 98 | if grad is None or param is None: 99 | continue 100 | 101 | param_name = self._get_variable_name(param.name) 102 | 103 | m = tf.get_variable( 104 | name=param_name + "/adam_m", 105 | shape=param.shape.as_list(), 106 | dtype=tf.float32, 107 | trainable=False, 108 | initializer=tf.zeros_initializer()) 109 | v = tf.get_variable( 110 | name=param_name + "/adam_v", 111 | shape=param.shape.as_list(), 112 | dtype=tf.float32, 113 | trainable=False, 114 | initializer=tf.zeros_initializer()) 115 | 116 | # Standard Adam update.
117 | next_m = ( 118 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 119 | next_v = ( 120 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 121 | tf.square(grad))) 122 | 123 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 124 | 125 | # Just adding the square of the weights to the loss function is *not* 126 | # the correct way of using L2 regularization/weight decay with Adam, 127 | # since that will interact with the m and v parameters in strange ways. 128 | # 129 | # Instead we want to decay the weights in a manner that doesn't interact 130 | # with the m/v parameters. This is equivalent to adding the square 131 | # of the weights to the loss with plain (non-momentum) SGD. 132 | if self._do_use_weight_decay(param_name): 133 | update += self.weight_decay_rate * param 134 | 135 | update_with_lr = self.learning_rate * update 136 | 137 | next_param = param - update_with_lr 138 | 139 | assignments.extend( 140 | [param.assign(next_param), 141 | m.assign(next_m), 142 | v.assign(next_v)]) 143 | return tf.group(*assignments, name=name) 144 | 145 | def _do_use_weight_decay(self, param_name): 146 | """Whether to use L2 weight decay for `param_name`.""" 147 | if not self.weight_decay_rate: 148 | return False 149 | if self.exclude_from_weight_decay: 150 | for r in self.exclude_from_weight_decay: 151 | if re.search(r, param_name) is not None: 152 | return False 153 | return True 154 | 155 | def _get_variable_name(self, param_name): 156 | """Get the variable name from the tensor name.""" 157 | m = re.match("^(.*):\\d+$", param_name) 158 | if m is not None: 159 | param_name = m.group(1) 160 | return param_name 161 | -------------------------------------------------------------------------------- /baselines/models/bert_mrc/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 
82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs an AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update.
131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want to decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 
82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs an AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update.
131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want to decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightLM: Benchmark for High-Performance Small Models 2 | 3 | Shared Tasks in NLPCC 2020.
4 | 5 | Task 1 - Light Pre-Training Chinese Language Model for NLP Task 6 | 7 | Submissions to the leaderboard are now open; see the example directory for a sample submission. 8 | 9 | Registration: 10 | 11 | https://www.cluebenchmarks.com/NLPCC.html 12 | 13 | NLPCC 2020 website: 14 | 15 | http://tcci.ccf.org.cn/conference/2020/cfpt.php 16 | 17 | Task guideline: 18 | 19 | http://tcci.ccf.org.cn/conference/2020/dldoc/taskgline01.pdf 20 | 21 | Competition discussion group: 22 | 23 | ![img](https://github.com/CLUEbenchmark/LightLM/blob/master/img/shared_task_QA_code) 24 | 25 | ## Task Overview 26 | 27 | **Task 1 - Light Pre-Training Chinese Language Model for NLP Task** 28 | 29 | 30 | 31 | The goal of this task is to train a lightweight language model that performs as well as normal-sized language models. Each submitted model is evaluated on several different downstream NLP tasks, and its number of parameters, accuracy, and inference time are considered together as the evaluation metrics. To address the shortage of Chinese corpora, we provide the largest Chinese corpus to date as a supplementary resource for this task; it will later be released to all researchers. 32 | 33 | 34 | 35 | ## **How to Participate** 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | Register online with the following steps: 44 | 45 | (1.1) Visit [www.CLUEbenchmark.com](http://www.CLUEbenchmark.com), click the 【注册】 (Register) button at the top right corner of the page, and log in. 46 | 47 | (1.2) Select the 【NLPCC测评】 (NLPCC evaluation) tab in the top navigation bar, register for the task in 【比赛注册】 (competition registration), and click submit. 48 | 49 | # Baseline Models, Code, and One-Click Run Scripts for This Evaluation 50 | The baselines/models directory provides prediction scripts for the four tasks of this competition. Usage: 51 | ``` 52 | From the baselines/models directory: 53 | 1. bert_wsc_csl (WSC and CSL tasks): bash run_classifier_clue.sh 54 | - results are written to the wsc_output and csl_output directories 55 | 2.
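bert_mrc (CMRC 2018 baseline): bash run_cmrc.sh 56 | - results are written to cmrc2018_predict.json 57 | 3. bert_ner (CLUENER baseline): bash run_ner.sh 58 | - results are written to the ner_output directory 59 | 60 | ``` 61 | Prediction only requires loading the corresponding model; this demo uses "RoBERTa-tiny-clue" automatically. 62 | 63 | Submission: package the outputs above into nlpcc-xxx.zip and submit it. Do not rename any of the output files, or the system will not recognize them. See the example directory for a sample submission. 64 | 65 | CLUENER2020, WSC, CSL: see CLUEPretrainedModels 66 | 67 | CMRC 2018: see CLUE 68 | 69 | # Dataset Description 70 | 71 | ##### 1. CLUENER2020 Fine-grained Named Entity Recognition [Details](https://github.com/CLUEbenchmark/CLUENER2020) [NER data download](http://www.cluebenchmark.com/introduce.html) 72 | 73 | ``` 74 | Training set: 10748 75 | Dev set: 1343 76 | 77 | Per-label statistics (note: every entity in a sample is annotated, so a sample containing two address entities counts twice toward the address category): 78 | [Training set] label distribution: 79 | address: 2829 80 | book: 1131 81 | company: 2897 82 | game: 2325 83 | government: 1797 84 | movie: 1109 85 | name: 3661 86 | organization: 3075 87 | position: 3052 88 | scene: 1462 89 | 90 | [Dev set] label distribution: 91 | address: 364 92 | book: 152 93 | company: 366 94 | game: 287 95 | government: 244 96 | movie: 150 97 | name: 451 98 | organization: 344 99 | position: 425 100 | scene: 199 101 | ``` 102 | 103 | ##### 2. CLUEWSC2020: WSC, the Chinese version of the Winograd Schema Challenge (new version released 2020-03-25) CLUEWSC2020 dataset download 104 | 105 | The Winograd Schema Challenge (WSC) is a pronoun-disambiguation task. The content of the new version differs from the WSC in the original CLUE project. 106 | 107 | The task is to decide which noun a pronoun in the sentence refers to. Each question is posed as a true/false judgment, for example: 108 | 109 | Sentence: 这时候放在床上枕头旁边的手机响了,我感到奇怪,因为欠费已被停机两个月,现在它突然响了。 ("At that moment the phone next to the pillow on the bed rang. I was surprised, because its service had been suspended for two months over unpaid bills, and now it suddenly rang.") Does “它” (it) refer to “床” (bed), “枕头” (pillow), or “手机” (phone)? 110 | 111 | Data source: provided by the CLUE benchmark; sentences are extracted from literary works by modern and contemporary Chinese writers, then manually selected and annotated by linguists. 112 | 113 | Data format: 114 | 115 | {"target": 116 | {"span2_index": 37, 117 | "span1_index": 5, 118 | "span1_text": "床", 119 | "span2_text": "它"}, 120 | "idx": 261, 121 | "label": "false", 122 | "text": "这时候放在床上枕头旁边的手机响了,我感到奇怪,因为欠费已被停机两个月,现在它突然响了。"} 123 | "true" means the pronoun (span2_text) does refer to the noun in span1_text; "false" means it does not. 124 | 125 | Dataset size: 126 | - Training set: 1244 127 | - Dev set: 304 128 | 129 | ##### 3.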
CSL Keyword Recognition [Details](https://github.com/CLUEbenchmark/CLUE) [CSL dataset download](https://storage.googleapis.com/cluebenchmark/tasks/csl_public.zip) 130 | 131 | The [Chinese Scientific Literature dataset (CSL)](https://github.com/P01son6415/chinese-scientific-literature-dataset) is built from Chinese paper abstracts and their keywords; the papers come from selected core Chinese journals in the social and natural sciences. Fake keywords generated with tf-idf are mixed with the papers' real keywords to form abstract-keyword pairs, and the task is to decide, given the abstract, whether all of the keywords are real. 132 | 133 | ``` 134 | Size: training set (20,000), dev set (3,000), test set (3,000) 135 | Example: 136 | {"id": 1, "abst": "为解决传统均匀FFT波束形成算法引起的3维声呐成像分辨率降低的问题,该文提出分区域FFT波束形成算法.远场条件下,以保证成像分辨率为约束条件,以划分数量最少为目标,采用遗传算法作为优化手段将成像区域划分为多个区域.在每个区域内选取一个波束方向,获得每一个接收阵元收到该方向回波时的解调输出,以此为原始数据在该区域内进行传统均匀FFT波束形成.对FFT计算过程进行优化,降低新算法的计算量,使其满足3维成像声呐实时性的要求.仿真与实验结果表明,采用分区域FFT波束形成算法的成像分辨率较传统均匀FFT波束形成算法有显著提高,且满足实时性要求.", "keyword": ["水声学", "FFT", "波束形成", "3维成像声呐"], "label": "1"} 137 | Each sample has four fields, in order: sample ID, paper abstract, keywords, and a true/false label. 138 | 139 | ``` 140 | 141 | 142 | ##### 4. CMRC2018 Reading Comprehension for Simplified Chinese [Details](https://github.com/CLUEbenchmark/CLUE) [CMRC2018 dataset download](https://storage.googleapis.com/cluebenchmark/tasks/cmrc2018_public.zip) 143 | 144 | https://hfl-rc.github.io/cmrc2018/ 145 | 146 | ``` 147 | Size: training set (2,403 passages, 10,142 questions), trial set (256 passages, 1,002 questions), dev set (848 passages, 3,219 questions) 148 | Example: 149 | { 150 | "version": "1.0", 151 | "data": [ 152 | { 153 | "title": "傻钱策略", 154 | "context_id": "TRIAL_0", 155 | "context_text": "工商协进会报告,12月消费者信心上升到78.1,明显高于11月的72。另据《华尔街日报》报道,2013年是1995年以来美国股市表现最好的一年。这一年里,投资美国股市的明智做法是追着“傻钱”跑。所谓的“傻钱”策略,其实就是买入并持有美国股票这样的普通组合。这个策略要比对冲基金和其它专业投资者使用的更为复杂的投资方法效果好得多。", 156 | "qas":[ 157 | { 158 | "query_id": "TRIAL_0_QUERY_0", 159 | "query_text": "什么是傻钱策略?", 160 | "answers": [ 161 | "所谓的“傻钱”策略,其实就是买入并持有美国股票这样的普通组合", 162 | "其实就是买入并持有美国股票这样的普通组合", 163 | "买入并持有美国股票这样的普通组合" 164 | ] 165 | }, 166 | { 167 | "query_id": "TRIAL_0_QUERY_1", 168 | "query_text": "12月的消费者信心指数是多少?", 169 | "answers": [ 170 | "78.1", 171 | "78.1", 172 | "78.1" 173 | ] 174 | }, 175 | { 176 | "query_id": "TRIAL_0_QUERY_2", 177 | "query_text": "消费者信心指数由什么机构发布?", 178 | "answers": [ 179
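| "工商协进会", 180 | "工商协进会", 181 | "工商协进会" 182 | ] 183 | } 184 | ] 185 | } 186 | ] 187 | } 188 | ``` 189 | 190 | ## Biweekly Awards 191 | 192 | To encourage participation, we are offering biweekly awards at three checkpoints: 4.19, 5.3, and 5.17. Each team can receive a biweekly award at most twice. 193 | Awards: 194 | 195 | 1. RMB 1,000 196 | 2. A CLUE biweekly competition certificate (viewable online) 197 | 3. 200 GB of pre-training corpus for the winning teams 198 | 199 | Eligibility: 200 | 201 | 1. The team must be ranked first as of 8 a.m. on 4.18, 5.3, or 5.16. 202 | 2. The team must give a talk on small-model training the following day, optionally sharing some competition-related tips. 203 | 204 | NOTE: 205 | 206 | 1. Both conditions above are required to receive an award. 207 | 2. If the first-place team is a repeat winner, the award passes to the next-ranked team. 208 | 209 | ## Reference 210 | 211 | The information on this page draws on: 212 | 213 | [1] http://tcci.ccf.org.cn/conference/2019/taskdata.php 214 | 215 | ## NOTE 216 | 217 | For any questions, please contact us at CLUEbenchmark@163.com or open an issue. 218 | 219 | ## Results 220 | 221 | Thank you all for participating. The evaluation rules for this competition are as follows: 222 | 1. Because speed measurements depend on hardware, we use the speed measured by each participant to decide whether a model meets the requirements. 223 | 2. Final scores are computed with the formula in the task guideline, using the speed and parameter count that we measured. Scores may contain small errors; an error is considered acceptable as long as it does not change the ranking. 224 | 3. Final rankings: 225 | 1st place: Huawei Cloud & Noah's Ark lab, score: 0.777126778 226 | 2nd place: Tencent Oteam, score: 0.768969257 227 | 3rd place: Xiaomi AI Lab, score: 0.758543871 228 | The results will be publicly posted for three days; please contact us if you have any questions. 229 | This was the first public competition formally organized by CLUE, so oversights were inevitable; we are very grateful for the guidance and help from participants at many universities and companies. 230 | Further details about the competition will be given in the final evaluation report. 231 | CLUE reserves the right of final interpretation of the results. We will contact the winners separately about award distribution. 232 | Congratulations to the winners, and thank you all for your support. 233 | 234 | ## **Important dates** 235 | 236 | - [x] 2020/03/10:announcement of shared tasks and call for participation; 237 | - [x] 2020/03/10:registration open; 238 | - [x] 2020/03/25:release of detailed task guidelines & training data; 239 | - [x] 2020/04/06:release of baseline and start of the competition; 240 | 241 | - [x] 2020/05/01:registration deadline; 242 | - [x] 2020/05/15:release of test data; 243 | - [x] 2020/05/20:participants’ results submission deadline; 244 | - [x] 2020/05/30:evaluation results release and call for system reports and conference papers; 245 | - [ ] 2020/06/30:conference paper submission deadline (only for shared tasks); 246 | - [ ] 2020/07/30:conference paper accept/reject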
notification; 247 | - [ ] 2020/08/10:camera-ready paper submission deadline; 248 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/tf_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multiclass 3 | from: 4 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py 5 | 6 | """ 7 | 8 | __author__ = "Guillaume Genthial" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix 13 | 14 | 15 | def precision(labels, predictions, num_classes, pos_indices=None, 16 | weights=None, average='micro'): 17 | """Multi-class precision metric for Tensorflow 18 | Parameters 19 | ---------- 20 | labels : Tensor of tf.int32 or tf.int64 21 | The true labels 22 | predictions : Tensor of tf.int32 or tf.int64 23 | The predictions, same shape as labels 24 | num_classes : int 25 | The number of classes 26 | pos_indices : list of int, optional 27 | The indices of the positive classes, default is all 28 | weights : Tensor of tf.int32, optional 29 | Mask, must be of compatible shape with labels 30 | average : str, optional 31 | 'micro': counts the total number of true positives, false 32 | positives, and false negatives for the classes in 33 | `pos_indices` and infer the metric from it. 34 | 'macro': will compute the metric separately for each class in 35 | `pos_indices` and average. Will not account for class 36 | imbalance. 37 | 'weighted': will compute the metric separately for each class in 38 | `pos_indices` and perform a weighted average by the total 39 | number of true labels for each class. 
40 | Returns 41 | ------- 42 | tuple of (scalar float Tensor, update_op) 43 | """ 44 | cm, op = _streaming_confusion_matrix( 45 | labels, predictions, num_classes, weights) 46 | pr, _, _ = metrics_from_confusion_matrix( 47 | cm, pos_indices, average=average) 48 | op, _, _ = metrics_from_confusion_matrix( 49 | op, pos_indices, average=average) 50 | return (pr, op) 51 | 52 | 53 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None, 54 | average='micro'): 55 | """Multi-class recall metric for Tensorflow 56 | Parameters 57 | ---------- 58 | labels : Tensor of tf.int32 or tf.int64 59 | The true labels 60 | predictions : Tensor of tf.int32 or tf.int64 61 | The predictions, same shape as labels 62 | num_classes : int 63 | The number of classes 64 | pos_indices : list of int, optional 65 | The indices of the positive classes, default is all 66 | weights : Tensor of tf.int32, optional 67 | Mask, must be of compatible shape with labels 68 | average : str, optional 69 | 'micro': counts the total number of true positives, false 70 | positives, and false negatives for the classes in 71 | `pos_indices` and infer the metric from it. 72 | 'macro': will compute the metric separately for each class in 73 | `pos_indices` and average. Will not account for class 74 | imbalance. 75 | 'weighted': will compute the metric separately for each class in 76 | `pos_indices` and perform a weighted average by the total 77 | number of true labels for each class. 
78 | Returns 79 | ------- 80 | tuple of (scalar float Tensor, update_op) 81 | """ 82 | cm, op = _streaming_confusion_matrix( 83 | labels, predictions, num_classes, weights) 84 | _, re, _ = metrics_from_confusion_matrix( 85 | cm, pos_indices, average=average) 86 | _, op, _ = metrics_from_confusion_matrix( 87 | op, pos_indices, average=average) 88 | return (re, op) 89 | 90 | 91 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None, 92 | average='micro'): 93 | return fbeta(labels, predictions, num_classes, pos_indices, weights, 94 | average) 95 | 96 | 97 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None, 98 | average='micro', beta=1): 99 | """Multi-class fbeta metric for Tensorflow 100 | Parameters 101 | ---------- 102 | labels : Tensor of tf.int32 or tf.int64 103 | The true labels 104 | predictions : Tensor of tf.int32 or tf.int64 105 | The predictions, same shape as labels 106 | num_classes : int 107 | The number of classes 108 | pos_indices : list of int, optional 109 | The indices of the positive classes, default is all 110 | weights : Tensor of tf.int32, optional 111 | Mask, must be of compatible shape with labels 112 | average : str, optional 113 | 'micro': counts the total number of true positives, false 114 | positives, and false negatives for the classes in 115 | `pos_indices` and infer the metric from it. 116 | 'macro': will compute the metric separately for each class in 117 | `pos_indices` and average. Will not account for class 118 | imbalance. 119 | 'weighted': will compute the metric separately for each class in 120 | `pos_indices` and perform a weighted average by the total 121 | number of true labels for each class. 
122 | beta : int, optional 123 | Weight of recall in the harmonic mean (beta > 1 favors recall) 124 | Returns 125 | ------- 126 | tuple of (scalar float Tensor, update_op) 127 | """ 128 | cm, op = _streaming_confusion_matrix( 129 | labels, predictions, num_classes, weights) 130 | _, _, fbeta = metrics_from_confusion_matrix( 131 | cm, pos_indices, average=average, beta=beta) 132 | _, _, op = metrics_from_confusion_matrix( 133 | op, pos_indices, average=average, beta=beta) 134 | return (fbeta, op) 135 | 136 | 137 | def safe_div(numerator, denominator): 138 | """Safe division, return 0 if denominator is 0""" 139 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator) 140 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype) 141 | denominator_is_zero = tf.equal(denominator, zeros) 142 | return tf.where(denominator_is_zero, zeros, numerator / denominator) 143 | 144 | 145 | def pr_re_fbeta(cm, pos_indices, beta=1): 146 | """Uses a confusion matrix to compute precision, recall and fbeta""" 147 | num_classes = cm.shape[0] 148 | neg_indices = [i for i in range(num_classes) if i not in pos_indices] 149 | cm_mask = np.ones([num_classes, num_classes]) 150 | cm_mask[neg_indices, neg_indices] = 0 151 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) 152 | 153 | cm_mask = np.ones([num_classes, num_classes]) 154 | cm_mask[:, neg_indices] = 0 155 | tot_pred = tf.reduce_sum(cm * cm_mask) 156 | 157 | cm_mask = np.ones([num_classes, num_classes]) 158 | cm_mask[neg_indices, :] = 0 159 | tot_gold = tf.reduce_sum(cm * cm_mask) 160 | 161 | pr = safe_div(diag_sum, tot_pred) 162 | re = safe_div(diag_sum, tot_gold) 163 | fbeta = safe_div((1.
+ beta**2) * pr * re, beta**2 * pr + re) 164 | 165 | return pr, re, fbeta 166 | 167 | 168 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro', 169 | beta=1): 170 | """Precision, Recall and F-beta from the confusion matrix 171 | Parameters 172 | ---------- 173 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes) 174 | The streaming confusion matrix. 175 | pos_indices : list of int, optional 176 | The indices of the positive classes 177 | beta : int, optional 178 | Weight of recall in the harmonic mean (beta > 1 favors recall) 179 | average : str, optional 180 | 'micro', 'macro' or 'weighted' 181 | """ 182 | num_classes = cm.shape[0] 183 | if pos_indices is None: 184 | pos_indices = [i for i in range(num_classes)] 185 | 186 | if average == 'micro': 187 | return pr_re_fbeta(cm, pos_indices, beta) 188 | elif average in {'macro', 'weighted'}: 189 | precisions, recalls, fbetas, n_golds = [], [], [], [] 190 | for idx in pos_indices: 191 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta) 192 | precisions.append(pr) 193 | recalls.append(re) 194 | fbetas.append(fbeta) 195 | cm_mask = np.zeros([num_classes, num_classes]) 196 | cm_mask[idx, :] = 1 197 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask))) 198 | 199 | if average == 'macro': 200 | pr = tf.reduce_mean(precisions) 201 | re = tf.reduce_mean(recalls) 202 | fbeta = tf.reduce_mean(fbetas) 203 | return pr, re, fbeta 204 | if average == 'weighted': 205 | n_gold = tf.reduce_sum(n_golds) 206 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds)) 207 | pr = safe_div(pr_sum, n_gold) 208 | re_sum = sum(r * n for r, n in zip(recalls, n_golds)) 209 | re = safe_div(re_sum, n_gold) 210 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds)) 211 | fbeta = safe_div(fbeta_sum, n_gold) 212 | return pr, re, fbeta 213 | 214 | else: 215 | raise NotImplementedError() -------------------------------------------------------------------------------- /example/cluewsc2020_predict.json:
-------------------------------------------------------------------------------- 1 | {"id": 0, "label": "false"} 2 | {"id": 1, "label": "false"} 3 | {"id": 2, "label": "false"} 4 | {"id": 3, "label": "false"} 5 | {"id": 4, "label": "false"} 6 | {"id": 5, "label": "false"} 7 | {"id": 6, "label": "false"} 8 | {"id": 7, "label": "false"} 9 | {"id": 8, "label": "false"} 10 | {"id": 9, "label": "false"} 11 | {"id": 10, "label": "false"} 12 | {"id": 11, "label": "false"} 13 | {"id": 12, "label": "false"} 14 | {"id": 13, "label": "false"} 15 | {"id": 14, "label": "false"} 16 | {"id": 15, "label": "false"} 17 | {"id": 16, "label": "false"} 18 | {"id": 17, "label": "false"} 19 | {"id": 18, "label": "false"} 20 | {"id": 19, "label": "false"} 21 | {"id": 20, "label": "false"} 22 | {"id": 21, "label": "false"} 23 | {"id": 22, "label": "false"} 24 | {"id": 23, "label": "false"} 25 | {"id": 24, "label": "false"} 26 | {"id": 25, "label": "false"} 27 | {"id": 26, "label": "false"} 28 | {"id": 27, "label": "false"} 29 | {"id": 28, "label": "false"} 30 | {"id": 29, "label": "false"} 31 | {"id": 30, "label": "false"} 32 | {"id": 31, "label": "false"} 33 | {"id": 32, "label": "false"} 34 | {"id": 33, "label": "false"} 35 | {"id": 34, "label": "false"} 36 | {"id": 35, "label": "false"} 37 | {"id": 36, "label": "false"} 38 | {"id": 37, "label": "false"} 39 | {"id": 38, "label": "false"} 40 | {"id": 39, "label": "false"} 41 | {"id": 40, "label": "false"} 42 | {"id": 41, "label": "false"} 43 | {"id": 42, "label": "false"} 44 | {"id": 43, "label": "false"} 45 | {"id": 44, "label": "false"} 46 | {"id": 45, "label": "false"} 47 | {"id": 46, "label": "false"} 48 | {"id": 47, "label": "false"} 49 | {"id": 48, "label": "false"} 50 | {"id": 49, "label": "false"} 51 | {"id": 50, "label": "false"} 52 | {"id": 51, "label": "false"} 53 | {"id": 52, "label": "false"} 54 | {"id": 53, "label": "false"} 55 | {"id": 54, "label": "false"} 56 | {"id": 55, "label": "false"} 57 | {"id": 56, "label": "false"} 
58 | {"id": 57, "label": "false"} 59 | {"id": 58, "label": "false"} 60 | {"id": 59, "label": "false"} 61 | {"id": 60, "label": "false"} 62 | {"id": 61, "label": "false"} 63 | {"id": 62, "label": "false"} 64 | {"id": 63, "label": "false"} 65 | {"id": 64, "label": "false"} 66 | {"id": 65, "label": "false"} 67 | {"id": 66, "label": "false"} 68 | {"id": 67, "label": "false"} 69 | {"id": 68, "label": "false"} 70 | {"id": 69, "label": "false"} 71 | {"id": 70, "label": "false"} 72 | {"id": 71, "label": "false"} 73 | {"id": 72, "label": "false"} 74 | {"id": 73, "label": "false"} 75 | {"id": 74, "label": "false"} 76 | {"id": 75, "label": "false"} 77 | {"id": 76, "label": "false"} 78 | {"id": 77, "label": "false"} 79 | {"id": 78, "label": "false"} 80 | {"id": 79, "label": "false"} 81 | {"id": 80, "label": "false"} 82 | {"id": 81, "label": "false"} 83 | {"id": 82, "label": "false"} 84 | {"id": 83, "label": "false"} 85 | {"id": 84, "label": "false"} 86 | {"id": 85, "label": "false"} 87 | {"id": 86, "label": "false"} 88 | {"id": 87, "label": "false"} 89 | {"id": 88, "label": "false"} 90 | {"id": 89, "label": "false"} 91 | {"id": 90, "label": "false"} 92 | {"id": 91, "label": "false"} 93 | {"id": 92, "label": "false"} 94 | {"id": 93, "label": "false"} 95 | {"id": 94, "label": "false"} 96 | {"id": 95, "label": "false"} 97 | {"id": 96, "label": "false"} 98 | {"id": 97, "label": "false"} 99 | {"id": 98, "label": "false"} 100 | {"id": 99, "label": "false"} 101 | {"id": 100, "label": "false"} 102 | {"id": 101, "label": "false"} 103 | {"id": 102, "label": "false"} 104 | {"id": 103, "label": "false"} 105 | {"id": 104, "label": "false"} 106 | {"id": 105, "label": "false"} 107 | {"id": 106, "label": "false"} 108 | {"id": 107, "label": "false"} 109 | {"id": 108, "label": "false"} 110 | {"id": 109, "label": "false"} 111 | {"id": 110, "label": "false"} 112 | {"id": 111, "label": "false"} 113 | {"id": 112, "label": "false"} 114 | {"id": 113, "label": "false"} 115 | {"id": 114, "label": 
"false"} 116 | {"id": 115, "label": "false"} 117 | {"id": 116, "label": "false"} 118 | {"id": 117, "label": "false"} 119 | {"id": 118, "label": "false"} 120 | {"id": 119, "label": "false"} 121 | {"id": 120, "label": "false"} 122 | {"id": 121, "label": "false"} 123 | {"id": 122, "label": "false"} 124 | {"id": 123, "label": "false"} 125 | {"id": 124, "label": "false"} 126 | {"id": 125, "label": "false"} 127 | {"id": 126, "label": "false"} 128 | {"id": 127, "label": "false"} 129 | {"id": 128, "label": "false"} 130 | {"id": 129, "label": "false"} 131 | {"id": 130, "label": "false"} 132 | {"id": 131, "label": "false"} 133 | {"id": 132, "label": "false"} 134 | {"id": 133, "label": "false"} 135 | {"id": 134, "label": "false"} 136 | {"id": 135, "label": "false"} 137 | {"id": 136, "label": "false"} 138 | {"id": 137, "label": "false"} 139 | {"id": 138, "label": "false"} 140 | {"id": 139, "label": "false"} 141 | {"id": 140, "label": "false"} 142 | {"id": 141, "label": "false"} 143 | {"id": 142, "label": "false"} 144 | {"id": 143, "label": "false"} 145 | {"id": 144, "label": "false"} 146 | {"id": 145, "label": "false"} 147 | {"id": 146, "label": "false"} 148 | {"id": 147, "label": "false"} 149 | {"id": 148, "label": "false"} 150 | {"id": 149, "label": "false"} 151 | {"id": 150, "label": "false"} 152 | {"id": 151, "label": "false"} 153 | {"id": 152, "label": "false"} 154 | {"id": 153, "label": "false"} 155 | {"id": 154, "label": "false"} 156 | {"id": 155, "label": "false"} 157 | {"id": 156, "label": "false"} 158 | {"id": 157, "label": "false"} 159 | {"id": 158, "label": "false"} 160 | {"id": 159, "label": "false"} 161 | {"id": 160, "label": "false"} 162 | {"id": 161, "label": "false"} 163 | {"id": 162, "label": "false"} 164 | {"id": 163, "label": "false"} 165 | {"id": 164, "label": "false"} 166 | {"id": 165, "label": "false"} 167 | {"id": 166, "label": "false"} 168 | {"id": 167, "label": "false"} 169 | {"id": 168, "label": "false"} 170 | {"id": 169, "label": "false"} 171 | 
{"id": 170, "label": "false"} 172 | {"id": 171, "label": "false"} 173 | {"id": 172, "label": "false"} 174 | {"id": 173, "label": "false"} 175 | {"id": 174, "label": "false"} 176 | {"id": 175, "label": "false"} 177 | {"id": 176, "label": "false"} 178 | {"id": 177, "label": "false"} 179 | {"id": 178, "label": "false"} 180 | {"id": 179, "label": "false"} 181 | {"id": 180, "label": "false"} 182 | {"id": 181, "label": "false"} 183 | {"id": 182, "label": "false"} 184 | {"id": 183, "label": "false"} 185 | {"id": 184, "label": "false"} 186 | {"id": 185, "label": "false"} 187 | {"id": 186, "label": "false"} 188 | {"id": 187, "label": "true"} 189 | {"id": 188, "label": "true"} 190 | {"id": 189, "label": "false"} 191 | {"id": 190, "label": "false"} 192 | {"id": 191, "label": "false"} 193 | {"id": 192, "label": "false"} 194 | {"id": 193, "label": "false"} 195 | {"id": 194, "label": "false"} 196 | {"id": 195, "label": "false"} 197 | {"id": 196, "label": "false"} 198 | {"id": 197, "label": "false"} 199 | {"id": 198, "label": "false"} 200 | {"id": 199, "label": "false"} 201 | {"id": 200, "label": "false"} 202 | {"id": 201, "label": "false"} 203 | {"id": 202, "label": "false"} 204 | {"id": 203, "label": "false"} 205 | {"id": 204, "label": "false"} 206 | {"id": 205, "label": "false"} 207 | {"id": 206, "label": "false"} 208 | {"id": 207, "label": "false"} 209 | {"id": 208, "label": "false"} 210 | {"id": 209, "label": "false"} 211 | {"id": 210, "label": "false"} 212 | {"id": 211, "label": "false"} 213 | {"id": 212, "label": "false"} 214 | {"id": 213, "label": "false"} 215 | {"id": 214, "label": "false"} 216 | {"id": 215, "label": "false"} 217 | {"id": 216, "label": "false"} 218 | {"id": 217, "label": "false"} 219 | {"id": 218, "label": "false"} 220 | {"id": 219, "label": "false"} 221 | {"id": 220, "label": "false"} 222 | {"id": 221, "label": "false"} 223 | {"id": 222, "label": "false"} 224 | {"id": 223, "label": "false"} 225 | {"id": 224, "label": "false"} 226 | {"id": 225, "label": 
"false"} 227 | {"id": 226, "label": "false"} 228 | {"id": 227, "label": "false"} 229 | {"id": 228, "label": "false"} 230 | {"id": 229, "label": "false"} 231 | {"id": 230, "label": "false"} 232 | {"id": 231, "label": "false"} 233 | {"id": 232, "label": "false"} 234 | {"id": 233, "label": "false"} 235 | {"id": 234, "label": "false"} 236 | {"id": 235, "label": "false"} 237 | {"id": 236, "label": "false"} 238 | {"id": 237, "label": "false"} 239 | {"id": 238, "label": "false"} 240 | {"id": 239, "label": "false"} 241 | {"id": 240, "label": "false"} 242 | {"id": 241, "label": "false"} 243 | {"id": 242, "label": "false"} 244 | {"id": 243, "label": "false"} 245 | {"id": 244, "label": "false"} 246 | {"id": 245, "label": "false"} 247 | {"id": 246, "label": "false"} 248 | {"id": 247, "label": "false"} 249 | {"id": 248, "label": "false"} 250 | {"id": 249, "label": "false"} 251 | {"id": 250, "label": "false"} 252 | {"id": 251, "label": "false"} 253 | {"id": 252, "label": "false"} 254 | {"id": 253, "label": "false"} 255 | {"id": 254, "label": "false"} 256 | {"id": 255, "label": "false"} 257 | {"id": 256, "label": "false"} 258 | {"id": 257, "label": "false"} 259 | {"id": 258, "label": "false"} 260 | {"id": 259, "label": "false"} 261 | {"id": 260, "label": "false"} 262 | {"id": 261, "label": "false"} 263 | {"id": 262, "label": "false"} 264 | {"id": 263, "label": "false"} 265 | {"id": 264, "label": "false"} 266 | {"id": 265, "label": "false"} 267 | {"id": 266, "label": "false"} 268 | {"id": 267, "label": "false"} 269 | {"id": 268, "label": "false"} 270 | {"id": 269, "label": "false"} 271 | {"id": 270, "label": "false"} 272 | {"id": 271, "label": "false"} 273 | {"id": 272, "label": "false"} 274 | {"id": 273, "label": "false"} 275 | {"id": 274, "label": "false"} 276 | {"id": 275, "label": "false"} 277 | {"id": 276, "label": "false"} 278 | {"id": 277, "label": "false"} 279 | {"id": 278, "label": "false"} 280 | {"id": 279, "label": "false"} 281 | {"id": 280, "label": "false"} 282 | 
{"id": 281, "label": "false"} 283 | {"id": 282, "label": "false"} 284 | {"id": 283, "label": "false"} 285 | {"id": 284, "label": "false"} 286 | {"id": 285, "label": "false"} 287 | {"id": 286, "label": "false"} 288 | {"id": 287, "label": "false"} 289 | {"id": 288, "label": "false"} 290 | {"id": 289, "label": "false"} 291 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/wsc_output/wsc_predict.json: -------------------------------------------------------------------------------- 1 | {"id": 0, "label": "false"} 2 | {"id": 1, "label": "false"} 3 | {"id": 2, "label": "false"} 4 | {"id": 3, "label": "false"} 5 | {"id": 4, "label": "false"} 6 | {"id": 5, "label": "false"} 7 | {"id": 6, "label": "false"} 8 | {"id": 7, "label": "false"} 9 | {"id": 8, "label": "false"} 10 | {"id": 9, "label": "false"} 11 | {"id": 10, "label": "false"} 12 | {"id": 11, "label": "false"} 13 | {"id": 12, "label": "false"} 14 | {"id": 13, "label": "false"} 15 | {"id": 14, "label": "false"} 16 | {"id": 15, "label": "false"} 17 | {"id": 16, "label": "false"} 18 | {"id": 17, "label": "false"} 19 | {"id": 18, "label": "false"} 20 | {"id": 19, "label": "false"} 21 | {"id": 20, "label": "false"} 22 | {"id": 21, "label": "false"} 23 | {"id": 22, "label": "false"} 24 | {"id": 23, "label": "false"} 25 | {"id": 24, "label": "false"} 26 | {"id": 25, "label": "false"} 27 | {"id": 26, "label": "false"} 28 | {"id": 27, "label": "false"} 29 | {"id": 28, "label": "false"} 30 | {"id": 29, "label": "false"} 31 | {"id": 30, "label": "false"} 32 | {"id": 31, "label": "false"} 33 | {"id": 32, "label": "false"} 34 | {"id": 33, "label": "false"} 35 | {"id": 34, "label": "false"} 36 | {"id": 35, "label": "false"} 37 | {"id": 36, "label": "false"} 38 | {"id": 37, "label": "false"} 39 | {"id": 38, "label": "false"} 40 | {"id": 39, "label": "false"} 41 | {"id": 40, "label": "false"} 42 | {"id": 41, "label": "false"} 43 | {"id": 42, "label": "false"} 44 | 
{"id": 43, "label": "false"} 45 | {"id": 44, "label": "false"} 46 | {"id": 45, "label": "false"} 47 | {"id": 46, "label": "false"} 48 | {"id": 47, "label": "false"} 49 | {"id": 48, "label": "false"} 50 | {"id": 49, "label": "false"} 51 | {"id": 50, "label": "false"} 52 | {"id": 51, "label": "false"} 53 | {"id": 52, "label": "false"} 54 | {"id": 53, "label": "false"} 55 | {"id": 54, "label": "false"} 56 | {"id": 55, "label": "false"} 57 | {"id": 56, "label": "false"} 58 | {"id": 57, "label": "false"} 59 | {"id": 58, "label": "false"} 60 | {"id": 59, "label": "false"} 61 | {"id": 60, "label": "false"} 62 | {"id": 61, "label": "false"} 63 | {"id": 62, "label": "false"} 64 | {"id": 63, "label": "false"} 65 | {"id": 64, "label": "false"} 66 | {"id": 65, "label": "false"} 67 | {"id": 66, "label": "false"} 68 | {"id": 67, "label": "false"} 69 | {"id": 68, "label": "false"} 70 | {"id": 69, "label": "false"} 71 | {"id": 70, "label": "false"} 72 | {"id": 71, "label": "false"} 73 | {"id": 72, "label": "false"} 74 | {"id": 73, "label": "false"} 75 | {"id": 74, "label": "false"} 76 | {"id": 75, "label": "false"} 77 | {"id": 76, "label": "false"} 78 | {"id": 77, "label": "false"} 79 | {"id": 78, "label": "false"} 80 | {"id": 79, "label": "false"} 81 | {"id": 80, "label": "false"} 82 | {"id": 81, "label": "false"} 83 | {"id": 82, "label": "false"} 84 | {"id": 83, "label": "false"} 85 | {"id": 84, "label": "false"} 86 | {"id": 85, "label": "false"} 87 | {"id": 86, "label": "false"} 88 | {"id": 87, "label": "false"} 89 | {"id": 88, "label": "false"} 90 | {"id": 89, "label": "false"} 91 | {"id": 90, "label": "false"} 92 | {"id": 91, "label": "false"} 93 | {"id": 92, "label": "false"} 94 | {"id": 93, "label": "false"} 95 | {"id": 94, "label": "false"} 96 | {"id": 95, "label": "false"} 97 | {"id": 96, "label": "false"} 98 | {"id": 97, "label": "false"} 99 | {"id": 98, "label": "false"} 100 | {"id": 99, "label": "false"} 101 | {"id": 100, "label": "false"} 102 | {"id": 101, "label": 
"false"} 103 | {"id": 102, "label": "false"} 104 | {"id": 103, "label": "false"} 105 | {"id": 104, "label": "false"} 106 | {"id": 105, "label": "false"} 107 | {"id": 106, "label": "false"} 108 | {"id": 107, "label": "false"} 109 | {"id": 108, "label": "false"} 110 | {"id": 109, "label": "false"} 111 | {"id": 110, "label": "false"} 112 | {"id": 111, "label": "false"} 113 | {"id": 112, "label": "false"} 114 | {"id": 113, "label": "false"} 115 | {"id": 114, "label": "false"} 116 | {"id": 115, "label": "false"} 117 | {"id": 116, "label": "false"} 118 | {"id": 117, "label": "false"} 119 | {"id": 118, "label": "false"} 120 | {"id": 119, "label": "false"} 121 | {"id": 120, "label": "false"} 122 | {"id": 121, "label": "false"} 123 | {"id": 122, "label": "false"} 124 | {"id": 123, "label": "false"} 125 | {"id": 124, "label": "false"} 126 | {"id": 125, "label": "false"} 127 | {"id": 126, "label": "false"} 128 | {"id": 127, "label": "false"} 129 | {"id": 128, "label": "false"} 130 | {"id": 129, "label": "false"} 131 | {"id": 130, "label": "false"} 132 | {"id": 131, "label": "false"} 133 | {"id": 132, "label": "false"} 134 | {"id": 133, "label": "false"} 135 | {"id": 134, "label": "false"} 136 | {"id": 135, "label": "false"} 137 | {"id": 136, "label": "false"} 138 | {"id": 137, "label": "false"} 139 | {"id": 138, "label": "false"} 140 | {"id": 139, "label": "false"} 141 | {"id": 140, "label": "false"} 142 | {"id": 141, "label": "false"} 143 | {"id": 142, "label": "false"} 144 | {"id": 143, "label": "false"} 145 | {"id": 144, "label": "false"} 146 | {"id": 145, "label": "false"} 147 | {"id": 146, "label": "false"} 148 | {"id": 147, "label": "false"} 149 | {"id": 148, "label": "false"} 150 | {"id": 149, "label": "false"} 151 | {"id": 150, "label": "false"} 152 | {"id": 151, "label": "false"} 153 | {"id": 152, "label": "false"} 154 | {"id": 153, "label": "false"} 155 | {"id": 154, "label": "false"} 156 | {"id": 155, "label": "false"} 157 | {"id": 156, "label": "false"} 158 | 
{"id": 157, "label": "false"} 159 | {"id": 158, "label": "false"} 160 | {"id": 159, "label": "false"} 161 | {"id": 160, "label": "false"} 162 | {"id": 161, "label": "false"} 163 | {"id": 162, "label": "false"} 164 | {"id": 163, "label": "false"} 165 | {"id": 164, "label": "false"} 166 | {"id": 165, "label": "false"} 167 | {"id": 166, "label": "false"} 168 | {"id": 167, "label": "false"} 169 | {"id": 168, "label": "false"} 170 | {"id": 169, "label": "false"} 171 | {"id": 170, "label": "false"} 172 | {"id": 171, "label": "false"} 173 | {"id": 172, "label": "false"} 174 | {"id": 173, "label": "false"} 175 | {"id": 174, "label": "false"} 176 | {"id": 175, "label": "false"} 177 | {"id": 176, "label": "false"} 178 | {"id": 177, "label": "false"} 179 | {"id": 178, "label": "false"} 180 | {"id": 179, "label": "false"} 181 | {"id": 180, "label": "false"} 182 | {"id": 181, "label": "false"} 183 | {"id": 182, "label": "false"} 184 | {"id": 183, "label": "false"} 185 | {"id": 184, "label": "false"} 186 | {"id": 185, "label": "false"} 187 | {"id": 186, "label": "false"} 188 | {"id": 187, "label": "true"} 189 | {"id": 188, "label": "true"} 190 | {"id": 189, "label": "false"} 191 | {"id": 190, "label": "false"} 192 | {"id": 191, "label": "false"} 193 | {"id": 192, "label": "false"} 194 | {"id": 193, "label": "false"} 195 | {"id": 194, "label": "false"} 196 | {"id": 195, "label": "false"} 197 | {"id": 196, "label": "false"} 198 | {"id": 197, "label": "false"} 199 | {"id": 198, "label": "false"} 200 | {"id": 199, "label": "false"} 201 | {"id": 200, "label": "false"} 202 | {"id": 201, "label": "false"} 203 | {"id": 202, "label": "false"} 204 | {"id": 203, "label": "false"} 205 | {"id": 204, "label": "false"} 206 | {"id": 205, "label": "false"} 207 | {"id": 206, "label": "false"} 208 | {"id": 207, "label": "false"} 209 | {"id": 208, "label": "false"} 210 | {"id": 209, "label": "false"} 211 | {"id": 210, "label": "false"} 212 | {"id": 211, "label": "false"} 213 | {"id": 212, "label": 
"false"} 214 | {"id": 213, "label": "false"} 215 | {"id": 214, "label": "false"} 216 | {"id": 215, "label": "false"} 217 | {"id": 216, "label": "false"} 218 | {"id": 217, "label": "false"} 219 | {"id": 218, "label": "false"} 220 | {"id": 219, "label": "false"} 221 | {"id": 220, "label": "false"} 222 | {"id": 221, "label": "false"} 223 | {"id": 222, "label": "false"} 224 | {"id": 223, "label": "false"} 225 | {"id": 224, "label": "false"} 226 | {"id": 225, "label": "false"} 227 | {"id": 226, "label": "false"} 228 | {"id": 227, "label": "false"} 229 | {"id": 228, "label": "false"} 230 | {"id": 229, "label": "false"} 231 | {"id": 230, "label": "false"} 232 | {"id": 231, "label": "false"} 233 | {"id": 232, "label": "false"} 234 | {"id": 233, "label": "false"} 235 | {"id": 234, "label": "false"} 236 | {"id": 235, "label": "false"} 237 | {"id": 236, "label": "false"} 238 | {"id": 237, "label": "false"} 239 | {"id": 238, "label": "false"} 240 | {"id": 239, "label": "false"} 241 | {"id": 240, "label": "false"} 242 | {"id": 241, "label": "false"} 243 | {"id": 242, "label": "false"} 244 | {"id": 243, "label": "false"} 245 | {"id": 244, "label": "false"} 246 | {"id": 245, "label": "false"} 247 | {"id": 246, "label": "false"} 248 | {"id": 247, "label": "false"} 249 | {"id": 248, "label": "false"} 250 | {"id": 249, "label": "false"} 251 | {"id": 250, "label": "false"} 252 | {"id": 251, "label": "false"} 253 | {"id": 252, "label": "false"} 254 | {"id": 253, "label": "false"} 255 | {"id": 254, "label": "false"} 256 | {"id": 255, "label": "false"} 257 | {"id": 256, "label": "false"} 258 | {"id": 257, "label": "false"} 259 | {"id": 258, "label": "false"} 260 | {"id": 259, "label": "false"} 261 | {"id": 260, "label": "false"} 262 | {"id": 261, "label": "false"} 263 | {"id": 262, "label": "false"} 264 | {"id": 263, "label": "false"} 265 | {"id": 264, "label": "false"} 266 | {"id": 265, "label": "false"} 267 | {"id": 266, "label": "false"} 268 | {"id": 267, "label": "false"} 269 | 
{"id": 268, "label": "false"} 270 | {"id": 269, "label": "false"} 271 | {"id": 270, "label": "false"} 272 | {"id": 271, "label": "false"} 273 | {"id": 272, "label": "false"} 274 | {"id": 273, "label": "false"} 275 | {"id": 274, "label": "false"} 276 | {"id": 275, "label": "false"} 277 | {"id": 276, "label": "false"} 278 | {"id": 277, "label": "false"} 279 | {"id": 278, "label": "false"} 280 | {"id": 279, "label": "false"} 281 | {"id": 280, "label": "false"} 282 | {"id": 281, "label": "false"} 283 | {"id": 282, "label": "false"} 284 | {"id": 283, "label": "false"} 285 | {"id": 284, "label": "false"} 286 | {"id": 285, "label": "false"} 287 | {"id": 286, "label": "false"} 288 | {"id": 287, "label": "false"} 289 | {"id": 288, "label": "false"} 290 | {"id": 289, "label": "false"} 291 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | import modeling 25 | import six 26 | import tensorflow as tf 27 | 28 | 29 | class BertModelTest(tf.test.TestCase): 30 | 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = 
BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | 
ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/assert_less_equal/.*$", 168 | "^.*/dilation_rate$", 169 | "^.*/Tensordot/concat$", 170 | "^.*/Tensordot/concat/axis$", 171 | "^testing/.*$", 172 | ] 173 | 174 | ignore_regexes = [re.compile(x) for x in ignore_strings] 175 | 176 | unreachable = self.get_unreachable_ops(graph, outputs) 177 | filtered_unreachable = [] 178 | for x in unreachable: 179 | do_ignore = False 180 | for r in ignore_regexes: 181 | m = r.match(x.name) 182 | if m is not None: 183 | do_ignore = True 184 | if do_ignore: 185 | continue 186 | filtered_unreachable.append(x) 187 | unreachable = filtered_unreachable 188 | 189 | self.assertEqual( 190 | len(unreachable), 0, "The following ops are unreachable: %s" % 191 | (" ".join([x.name for x in unreachable]))) 192 | 193 | @classmethod 194 | def get_unreachable_ops(cls, graph, outputs): 195 | """Finds all of the tensors in graph that are unreachable from outputs.""" 196 | outputs = cls.flatten_recursive(outputs) 197 | output_to_op = collections.defaultdict(list) 198 | 
op_to_all = collections.defaultdict(list) 199 | assign_out_to_in = collections.defaultdict(list) 200 | 201 | for op in graph.get_operations(): 202 | for x in op.inputs: 203 | op_to_all[op.name].append(x.name) 204 | for y in op.outputs: 205 | output_to_op[y.name].append(op.name) 206 | op_to_all[op.name].append(y.name) 207 | if str(op.type) == "Assign": 208 | for y in op.outputs: 209 | for x in op.inputs: 210 | assign_out_to_in[y.name].append(x.name) 211 | 212 | assign_groups = collections.defaultdict(list) 213 | for out_name in assign_out_to_in.keys(): 214 | name_group = assign_out_to_in[out_name] 215 | for n1 in name_group: 216 | assign_groups[n1].append(out_name) 217 | for n2 in name_group: 218 | if n1 != n2: 219 | assign_groups[n1].append(n2) 220 | 221 | seen_tensors = {} 222 | stack = [x.name for x in outputs] 223 | while stack: 224 | name = stack.pop() 225 | if name in seen_tensors: 226 | continue 227 | seen_tensors[name] = True 228 | 229 | if name in output_to_op: 230 | for op_name in output_to_op[name]: 231 | if op_name in op_to_all: 232 | for input_name in op_to_all[op_name]: 233 | if input_name not in stack: 234 | stack.append(input_name) 235 | 236 | expanded_names = [] 237 | if name in assign_groups: 238 | for assign_name in assign_groups[name]: 239 | expanded_names.append(assign_name) 240 | 241 | for expanded_name in expanded_names: 242 | if expanded_name not in stack: 243 | stack.append(expanded_name) 244 | 245 | unreachable_ops = [] 246 | for op in graph.get_operations(): 247 | is_unreachable = False 248 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 249 | for name in all_names: 250 | if name not in seen_tensors: 251 | is_unreachable = True 252 | if is_unreachable: 253 | unreachable_ops.append(op) 254 | return unreachable_ops 255 | 256 | @classmethod 257 | def flatten_recursive(cls, item): 258 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 259 | output = [] 260 | if isinstance(item, list): 261 | 
output.extend(item) 262 | elif isinstance(item, tuple): 263 | output.extend(list(item)) 264 | elif isinstance(item, dict): 265 | for (_, v) in six.iteritems(item): 266 | output.append(v) 267 | else: 268 | return [item] 269 | 270 | flat_output = [] 271 | for x in output: 272 | flat_output.extend(cls.flatten_recursive(x)) 273 | return flat_output 274 | 275 | 276 | if __name__ == "__main__": 277 | tf.test.main() 278 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/conlleval.py: -------------------------------------------------------------------------------- 1 | # Python version of the evaluation script from CoNLL'00- 2 | # Originates from: https://github.com/spyysalo/conlleval.py 3 | 4 | 5 | # Intentional differences: 6 | # - accept any space as delimiter by default 7 | # - optional file argument (default STDIN) 8 | # - option to set boundary (-b argument) 9 | # - LaTeX output (-l argument) not supported 10 | # - raw tags (-r argument) not supported 11 | 12 | # Added: evaluate() accepts any iterable of lines, so it can be used without reading from a file 13 | 14 | import sys 15 | import re 16 | import codecs 17 | from collections import defaultdict, namedtuple 18 | 19 | ANY_SPACE = '<SPACE>' 20 | 21 | 22 | class FormatError(Exception): 23 | pass 24 | 25 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore') 26 | 27 | 28 | class EvalCounts(object): 29 | def __init__(self): 30 | self.correct_chunk = 0 # number of correctly identified chunks 31 | self.correct_tags = 0 # number of correct chunk tags 32 | self.found_correct = 0 # number of chunks in corpus 33 | self.found_guessed = 0 # number of identified chunks 34 | self.token_counter = 0 # token counter (ignores sentence breaks) 35 | 36 | # counts by type 37 | self.t_correct_chunk = defaultdict(int) 38 | self.t_found_correct = defaultdict(int) 39 | self.t_found_guessed = defaultdict(int) 40 | 41 | 42 | def 
parse_args(argv): 43 | import argparse 44 | parser =
argparse.ArgumentParser( 45 | description='evaluate tagging results using CoNLL criteria', 46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 47 | ) 48 | arg = parser.add_argument 49 | arg('-b', '--boundary', metavar='STR', default='-X-', 50 | help='sentence boundary') 51 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE, 52 | help='character delimiting items in input') 53 | arg('-o', '--otag', metavar='CHAR', default='O', 54 | help='alternative outside tag') 55 | arg('file', nargs='?', default=None) 56 | return parser.parse_args(argv) 57 | 58 | 59 | def parse_tag(t): 60 | m = re.match(r'^([^-]*)-(.*)$', t) 61 | return m.groups() if m else (t, '') 62 | 63 | 64 | def evaluate(iterable, options=None): 65 | if options is None: 66 | options = parse_args([]) # use defaults 67 | 68 | counts = EvalCounts() 69 | num_features = None # number of features per line 70 | in_correct = False # whether the chunk currently being processed is correct so far 71 | last_correct = 'O' # previous chunk tag in corpus 72 | last_correct_type = '' # type of previous chunk tag in corpus 73 | last_guessed = 'O' # previously identified chunk tag 74 | last_guessed_type = '' # type of previously identified chunk tag 75 | 76 | for line in iterable: 77 | line = line.rstrip('\r\n') 78 | 79 | if options.delimiter == ANY_SPACE: 80 | features = line.split() 81 | else: 82 | features = line.split(options.delimiter) 83 | 84 | if num_features is None: 85 | num_features = len(features) 86 | elif num_features != len(features) and len(features) != 0: 87 | raise FormatError('unexpected number of features: %d (%d)' % 88 | (len(features), num_features)) 89 | 90 | if len(features) == 0 or features[0] == options.boundary: 91 | features = [options.boundary, 'O', 'O'] 92 | if len(features) < 3: 93 | raise FormatError('unexpected number of features in line %s' % line) 94 | 95 | guessed, guessed_type = parse_tag(features.pop()) 96 | correct, correct_type = 
parse_tag(features.pop()) 97 | first_item =
features.pop(0) 98 | 99 | if first_item == options.boundary: 100 | guessed = 'O' 101 | 102 | end_correct = end_of_chunk(last_correct, correct, 103 | last_correct_type, correct_type) 104 | end_guessed = end_of_chunk(last_guessed, guessed, 105 | last_guessed_type, guessed_type) 106 | start_correct = start_of_chunk(last_correct, correct, 107 | last_correct_type, correct_type) 108 | start_guessed = start_of_chunk(last_guessed, guessed, 109 | last_guessed_type, guessed_type) 110 | 111 | if in_correct: 112 | if (end_correct and end_guessed and 113 | last_guessed_type == last_correct_type): 114 | in_correct = False 115 | counts.correct_chunk += 1 116 | counts.t_correct_chunk[last_correct_type] += 1 117 | elif (end_correct != end_guessed or guessed_type != correct_type): 118 | in_correct = False 119 | 120 | if start_correct and start_guessed and guessed_type == correct_type: 121 | in_correct = True 122 | 123 | if start_correct: 124 | counts.found_correct += 1 125 | counts.t_found_correct[correct_type] += 1 126 | if start_guessed: 127 | counts.found_guessed += 1 128 | counts.t_found_guessed[guessed_type] += 1 129 | if first_item != options.boundary: 130 | if correct == guessed and guessed_type == correct_type: 131 | counts.correct_tags += 1 132 | counts.token_counter += 1 133 | 134 | last_guessed = guessed 135 | last_correct = correct 136 | last_guessed_type = guessed_type 137 | last_correct_type = correct_type 138 | 139 | if in_correct: 140 | counts.correct_chunk += 1 141 | counts.t_correct_chunk[last_correct_type] += 1 142 | 143 | return counts 144 | 145 | 146 | 147 | def uniq(iterable): 148 | seen = set() 149 | return [i for i in iterable if not (i in seen or seen.add(i))] 150 | 151 | 152 | def calculate_metrics(correct, guessed, total): 153 | tp, fp, fn = correct, guessed-correct, total-correct 154 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp) 155 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn) 156 | f = 0 if p + r == 0 else 2 * p * r / (p + r) 157 | return Metrics(tp, 
fp, fn, p, r, f) 158 | 159 | 160 | def metrics(counts): 161 | c = counts 162 | overall = calculate_metrics( 163 | c.correct_chunk, c.found_guessed, c.found_correct 164 | ) 165 | by_type = {} 166 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)): 167 | by_type[t] = calculate_metrics( 168 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t] 169 | ) 170 | return overall, by_type 171 | 172 | 173 | def report(counts, out=None): 174 | if out is None: 175 | out = sys.stdout 176 | 177 | overall, by_type = metrics(counts) 178 | 179 | c = counts 180 | out.write('processed %d tokens with %d phrases; ' % 181 | (c.token_counter, c.found_correct)) 182 | out.write('found: %d phrases; correct: %d.\n' % 183 | (c.found_guessed, c.correct_chunk)) 184 | 185 | if c.token_counter > 0: 186 | out.write('accuracy: %6.2f%%; ' % 187 | (100.*c.correct_tags/c.token_counter)) 188 | out.write('precision: %6.2f%%; ' % (100.*overall.prec)) 189 | out.write('recall: %6.2f%%; ' % (100.*overall.rec)) 190 | out.write('FB1: %6.2f\n' % (100.*overall.fscore)) 191 | 192 | for i, m in sorted(by_type.items()): 193 | out.write('%17s: ' % i) 194 | out.write('precision: %6.2f%%; ' % (100.*m.prec)) 195 | out.write('recall: %6.2f%%; ' % (100.*m.rec)) 196 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 197 | 198 | 199 | def report_notprint(counts, out=None): 200 | if out is None: 201 | out = sys.stdout 202 | 203 | overall, by_type = metrics(counts) 204 | 205 | c = counts 206 | final_report = [] 207 | line = [] 208 | line.append('processed %d tokens with %d phrases; ' % 209 | (c.token_counter, c.found_correct)) 210 | line.append('found: %d phrases; correct: %d.\n' % 211 | (c.found_guessed, c.correct_chunk)) 212 | final_report.append("".join(line)) 213 | 214 | if c.token_counter > 0: 215 | line = [] 216 | line.append('accuracy: %6.2f%%; ' % 217 | (100.*c.correct_tags/c.token_counter)) 218 | line.append('precision: %6.2f%%; ' % (100.*overall.prec)) 219 | 
line.append('recall: %6.2f%%; ' % (100.*overall.rec)) 220 | line.append('FB1: %6.2f\n' % (100.*overall.fscore)) 221 | final_report.append("".join(line)) 222 | 223 | for i, m in sorted(by_type.items()): 224 | line = [] 225 | line.append('%17s: ' % i) 226 | line.append('precision: %6.2f%%; ' % (100.*m.prec)) 227 | line.append('recall: %6.2f%%; ' % (100.*m.rec)) 228 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 229 | final_report.append("".join(line)) 230 | return final_report 231 | 232 | 233 | def end_of_chunk(prev_tag, tag, prev_type, type_): 234 | # check if a chunk ended between the previous and current word 235 | # arguments: previous and current chunk tags, previous and current types 236 | chunk_end = False 237 | 238 | if prev_tag == 'E': chunk_end = True 239 | if prev_tag == 'S': chunk_end = True 240 | 241 | if prev_tag == 'B' and tag == 'B': chunk_end = True 242 | if prev_tag == 'B' and tag == 'S': chunk_end = True 243 | if prev_tag == 'B' and tag == 'O': chunk_end = True 244 | if prev_tag == 'I' and tag == 'B': chunk_end = True 245 | if prev_tag == 'I' and tag == 'S': chunk_end = True 246 | if prev_tag == 'I' and tag == 'O': chunk_end = True 247 | 248 | if prev_tag != 'O' and prev_tag != '.' 
and prev_type != type_: 249 | chunk_end = True 250 | 251 | # these chunks are assumed to have length 1 252 | if prev_tag == ']': chunk_end = True 253 | if prev_tag == '[': chunk_end = True 254 | 255 | return chunk_end 256 | 257 | 258 | def start_of_chunk(prev_tag, tag, prev_type, type_): 259 | # check if a chunk started between the previous and current word 260 | # arguments: previous and current chunk tags, previous and current types 261 | chunk_start = False 262 | 263 | if tag == 'B': chunk_start = True 264 | if tag == 'S': chunk_start = True 265 | 266 | if prev_tag == 'E' and tag == 'E': chunk_start = True 267 | if prev_tag == 'E' and tag == 'I': chunk_start = True 268 | if prev_tag == 'S' and tag == 'E': chunk_start = True 269 | if prev_tag == 'S' and tag == 'I': chunk_start = True 270 | if prev_tag == 'O' and tag == 'E': chunk_start = True 271 | if prev_tag == 'O' and tag == 'I': chunk_start = True 272 | 273 | if tag != 'O' and tag != '.' and prev_type != type_: 274 | chunk_start = True 275 | 276 | # these chunks are assumed to have length 1 277 | if tag == '[': chunk_start = True 278 | if tag == ']': chunk_start = True 279 | 280 | return chunk_start 281 | 282 | 283 | def return_report(input_file): 284 | with codecs.open(input_file, "r", "utf8") as f: 285 | counts = evaluate(f) 286 | return report_notprint(counts) 287 | 288 | 289 | def main(argv): 290 | args = parse_args(argv[1:]) 291 | 292 | if args.file is None: 293 | counts = evaluate(sys.stdin, args) 294 | else: 295 | with open(args.file) as f: 296 | counts = evaluate(f, args) 297 | report(counts) 298 | 299 | if __name__ == '__main__': 300 | sys.exit(main(sys.argv)) -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND 
DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/multilingual.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | There are two multilingual models currently available. 
We do not plan to release 4 | more single-language models, but we may release `BERT-Large` versions of these 5 | two in the future: 6 | 7 | * **[`BERT-Base, Multilingual Cased (New, recommended)`](https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip)**: 8 | 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters 9 | * **[`BERT-Base, Multilingual Uncased (Orig, not recommended)`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**: 10 | 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters 11 | * **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**: 12 | Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M 13 | parameters 14 | 15 | **The `Multilingual Cased (New)` model also fixes normalization issues in many 16 | languages, so it is recommended for languages with non-Latin alphabets (and is 17 | often better for most languages with Latin alphabets). When using this model, 18 | make sure to pass `--do_lower_case=false` to `run_pretraining.py` and other 19 | scripts.** 20 | 21 | See the [list of languages](#list-of-languages) that the Multilingual model 22 | supports. The Multilingual model does include Chinese (and English), but if your 23 | fine-tuning data is Chinese-only, then the Chinese model will likely produce 24 | better results. 25 | 26 | ## Results 27 | 28 | To evaluate these systems, we use the 29 | [XNLI dataset](https://github.com/facebookresearch/XNLI), which is a 30 | version of [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) where the 31 | dev and test sets have been translated (by humans) into 15 languages. Note that 32 | the training set was *machine* translated (we used the translations provided by 33 | XNLI, not Google NMT).
For clarity, we only report on 6 languages below: 34 | 35 | 36 | 37 | | System | English | Chinese | Spanish | German | Arabic | Urdu | 38 | | --------------------------------- | -------- | -------- | -------- | -------- | -------- | -------- | 39 | | XNLI Baseline - Translate Train | 73.7 | 67.0 | 68.8 | 66.5 | 65.8 | 56.6 | 40 | | XNLI Baseline - Translate Test | 73.7 | 68.3 | 70.7 | 68.7 | 66.8 | 59.3 | 41 | | BERT - Translate Train Cased | **81.9** | **76.6** | **77.8** | **75.9** | **70.7** | 61.6 | 42 | | BERT - Translate Train Uncased | 81.4 | 74.2 | 77.3 | 75.2 | 70.5 | 61.7 | 43 | | BERT - Translate Test Uncased | 81.4 | 70.1 | 74.9 | 74.4 | 70.4 | **62.1** | 44 | | BERT - Zero Shot Uncased | 81.4 | 63.8 | 74.3 | 70.5 | 62.1 | 58.3 | 45 | 46 | 47 | 48 | The first two rows are baselines from the XNLI paper and the last four rows are 49 | our results with BERT. 50 | 51 | **Translate Train** means that the MultiNLI training set was machine translated 52 | from English into the foreign language. So training and evaluation were both 53 | done in the foreign language. Unfortunately, training was done on 54 | machine-translated data, so it is impossible to quantify how much of the lower 55 | accuracy (compared to English) is due to the quality of the machine translation 56 | vs. the quality of the pre-trained model. 57 | 58 | **Translate Test** means that the XNLI test set was machine translated from the 59 | foreign language into English. So training and evaluation were both done in 60 | English. However, test evaluation was done on machine-translated English, so the 61 | accuracy depends on the quality of the machine translation system. 62 | 63 | **Zero Shot** means that the Multilingual BERT system was fine-tuned on English 64 | MultiNLI, and then evaluated on the foreign language XNLI test. In this case, 65 | machine translation was not involved at all in either the pre-training or 66 | fine-tuning.
67 | 68 | Note that the English result is worse than the 84.2 MultiNLI baseline because 69 | this training used Multilingual BERT rather than English-only BERT. This implies 70 | that for high-resource languages, the Multilingual model is somewhat worse than 71 | a single-language model. However, it is not feasible for us to train and 72 | maintain dozens of single-language models. Therefore, if your goal is to maximize 73 | performance with a language other than English or Chinese, you might find it 74 | beneficial to run pre-training for additional steps starting from our 75 | Multilingual model on data from your language of interest. 76 | 77 | Here is a comparison of training Chinese models with the Multilingual 78 | `BERT-Base` and Chinese-only `BERT-Base`: 79 | 80 | System | Chinese 81 | ----------------------- | ------- 82 | XNLI Baseline | 67.0 83 | BERT Multilingual Model | 74.2 84 | BERT Chinese-only Model | 77.2 85 | 86 | Similar to English, the single-language model scores 3.0 points higher than the 87 | Multilingual model (77.2 vs. 74.2). 88 | 89 | ## Fine-tuning Example 90 | 91 | The multilingual model does **not** require any special consideration or API 92 | changes. We did update the implementation of `BasicTokenizer` in 93 | `tokenization.py` to support Chinese character tokenization, so please update if 94 | you forked it. However, we did not change the tokenization API. 95 | 96 | To test the new models, we did modify `run_classifier.py` to add support for the 97 | [XNLI dataset](https://github.com/facebookresearch/XNLI). This is a 15-language 98 | version of MultiNLI where the dev/test sets have been human-translated, and the 99 | training set has been machine-translated. 100 | 101 | To run the fine-tuning code, please download the 102 | [XNLI dev/test set](https://s3.amazonaws.com/xnli/XNLI-1.0.zip) and the 103 | [XNLI machine-translated training set](https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip) 104 | and then unpack both .zip files into some directory `$XNLI_DIR`.
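The download-and-unpack step can also be scripted. The following is a minimal standard-library sketch. It assumes the two S3 URLs above are still reachable and that you simply want the contents of both archives under one `$XNLI_DIR`. It does not check the exact file layout that `XnliProcessor` reads, so you may still need to move files afterwards. The helper names `download` and `unpack_all` are hypothetical and are not part of this repository.

```python
import os
import urllib.request
import zipfile

# The two archives named in the text: the human-translated dev/test sets and
# the machine-translated training set.
XNLI_URLS = [
    "https://s3.amazonaws.com/xnli/XNLI-1.0.zip",
    "https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip",
]


def download(url, dest_dir):
    """Fetches `url` into `dest_dir` unless it is already there; returns the path."""
    os.makedirs(dest_dir, exist_ok=True)
    path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(path):
        urllib.request.urlretrieve(url, path)
    return path


def unpack_all(zip_paths, xnli_dir):
    """Extracts every archive into `xnli_dir`; returns its sorted top-level entries."""
    os.makedirs(xnli_dir, exist_ok=True)
    for zip_path in zip_paths:
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(xnli_dir)
    return sorted(os.listdir(xnli_dir))
```

For example, `unpack_all([download(u, "/tmp/xnli_zips") for u in XNLI_URLS], os.environ["XNLI_DIR"])` downloads both archives to a scratch directory and extracts them into `$XNLI_DIR`. Keeping the `.zip` files in a separate directory keeps `$XNLI_DIR` limited to the extracted data.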
105 | 106 | To run fine-tuning on XNLI, note that the language is hard-coded into `run_classifier.py` 107 | (Chinese by default), so please modify `XnliProcessor` if you want to run on 108 | another language. 109 | 110 | This is a large dataset, so training will take a few hours on a GPU 111 | (or about 30 minutes on a Cloud TPU). To run an experiment quickly for 112 | debugging, just set `num_train_epochs` to a small value like `0.1`. 113 | 114 | ```shell 115 | export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 # or multilingual_L-12_H-768_A-12 116 | export XNLI_DIR=/path/to/xnli 117 | 118 | python run_classifier.py \ 119 | --task_name=XNLI \ 120 | --do_train=true \ 121 | --do_eval=true \ 122 | --data_dir=$XNLI_DIR \ 123 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 124 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 125 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 126 | --max_seq_length=128 \ 127 | --train_batch_size=32 \ 128 | --learning_rate=5e-5 \ 129 | --num_train_epochs=2.0 \ 130 | --output_dir=/tmp/xnli_output/ 131 | ``` 132 | 133 | With the Chinese-only model, the results should look something like this: 134 | 135 | ``` 136 | ***** Eval results ***** 137 | eval_accuracy = 0.774116 138 | eval_loss = 0.83554 139 | global_step = 24543 140 | loss = 0.74603 141 | ``` 142 | 143 | ## Details 144 | 145 | ### Data Source and Sampling 146 | 147 | The languages chosen were the 148 | [top 100 languages with the largest Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias). 149 | The entire Wikipedia dump for each language (excluding user and talk pages) was 150 | taken as the training data for each language. 151 | 152 | However, the size of the Wikipedia for a given language varies greatly, and 153 | therefore low-resource languages may be "under-represented" in terms of the 154 | neural network model (under the assumption that languages are "competing" for 155 | limited model capacity to some extent).
At the same time, we also don't want 156 | to overfit the model by performing thousands of epochs over a tiny Wikipedia 157 | for a particular language. 158 | 159 | To balance these two factors, we performed exponentially smoothed weighting of 160 | the data during pre-training data creation (and WordPiece vocab creation). In 161 | other words, let's say that the probability of a language is *P(L)*, e.g., 162 | *P(English) = 0.21* means that after concatenating all of the Wikipedias 163 | together, 21% of our data is English. We exponentiate each probability by some 164 | factor *S* and then re-normalize, and sample from that distribution. In our case 165 | we use *S=0.7*. So, high-resource languages like English will be under-sampled, 166 | and low-resource languages like Icelandic will be over-sampled. E.g., in the 167 | original distribution English would be sampled 1000x more than Icelandic, but 168 | after smoothing it's only sampled 100x more. 169 | 170 | ### Tokenization 171 | 172 | For tokenization, we use a 110k shared WordPiece vocabulary. The word counts are 173 | weighted the same way as the data, so low-resource languages are upweighted by 174 | some factor. We intentionally do *not* use any marker to denote the input 175 | language (so that zero-shot training can work). 176 | 177 | Because Chinese (and Japanese Kanji and Korean Hanja) does not have whitespace 178 | characters, we add spaces around every character in the 179 | [CJK Unicode range](https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_\(Unicode_block\)) 180 | before applying WordPiece. This means that Chinese is effectively 181 | character-tokenized. Note that the CJK Unicode block only includes 182 | Chinese-origin characters and does *not* include Hangul Korean or 183 | Katakana/Hiragana Japanese, which are tokenized with whitespace+WordPiece like 184 | all other languages. 
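Both the pre-training data and the WordPiece word counts are reweighted with the exponential smoothing described under *Data Source and Sampling*. Each language probability *P(L)* is raised to the power *S*, and the results are re-normalized. Here is a minimal Python sketch of that computation. The function names and the toy data sizes are illustrative assumptions, not code from the BERT release.

```python
import random


def smoothed_language_probs(sizes, s=0.7):
    """Returns P(L)**s / sum over L' of P(L')**s, for each language L.

    `sizes` maps a language name to the amount of data for that language
    (for example, a token count); s=0.7 is the factor quoted in the text.
    """
    total = float(sum(sizes.values()))
    # Raw share of the concatenated corpus, i.e. P(L).
    probs = {lang: n / total for lang, n in sizes.items()}
    # Exponentiate by S, then re-normalize so the values sum to 1.
    weighted = {lang: p ** s for lang, p in probs.items()}
    norm = sum(weighted.values())
    return {lang: w / norm for lang, w in weighted.items()}


def sample_languages(probs, k, seed=0):
    """Draws k language labels from the smoothed distribution."""
    rng = random.Random(seed)
    langs = sorted(probs)
    return rng.choices(langs, weights=[probs[lang] for lang in langs], k=k)


# Hypothetical toy sizes: English has 1000x as much data as Icelandic.
sizes = {"English": 1000, "Icelandic": 1}
probs = smoothed_language_probs(sizes)
print(probs["English"] / probs["Icelandic"])  # 1000 ** 0.7, about 126
print(sample_languages(probs, k=10))
```

With *S=0.7*, the 1000x imbalance in this toy example shrinks to 1000^0.7, or roughly 126x, which is the order of magnitude quoted above. A smaller *S* flattens the distribution further, and *S=1* leaves it unchanged.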
185 | 186 | For all other languages, we apply the 187 | [same recipe as English](https://github.com/google-research/bert#tokenization): 188 | (a) lower casing+accent removal, (b) punctuation splitting, (c) whitespace 189 | tokenization. We understand that accent markers have substantial meaning in some 190 | languages, but felt that the benefits of reducing the effective vocabulary make 191 | up for this. Generally the strong contextual models of BERT should make up for 192 | any ambiguity introduced by stripping accent markers. 193 | 194 | ### List of Languages 195 | 196 | The multilingual model supports the following languages. These languages were 197 | chosen because they are the top 100 languages with the largest Wikipedias: 198 | 199 | * Afrikaans 200 | * Albanian 201 | * Arabic 202 | * Aragonese 203 | * Armenian 204 | * Asturian 205 | * Azerbaijani 206 | * Bashkir 207 | * Basque 208 | * Bavarian 209 | * Belarusian 210 | * Bengali 211 | * Bishnupriya Manipuri 212 | * Bosnian 213 | * Breton 214 | * Bulgarian 215 | * Burmese 216 | * Catalan 217 | * Cebuano 218 | * Chechen 219 | * Chinese (Simplified) 220 | * Chinese (Traditional) 221 | * Chuvash 222 | * Croatian 223 | * Czech 224 | * Danish 225 | * Dutch 226 | * English 227 | * Estonian 228 | * Finnish 229 | * French 230 | * Galician 231 | * Georgian 232 | * German 233 | * Greek 234 | * Gujarati 235 | * Haitian 236 | * Hebrew 237 | * Hindi 238 | * Hungarian 239 | * Icelandic 240 | * Ido 241 | * Indonesian 242 | * Irish 243 | * Italian 244 | * Japanese 245 | * Javanese 246 | * Kannada 247 | * Kazakh 248 | * Kirghiz 249 | * Korean 250 | * Latin 251 | * Latvian 252 | * Lithuanian 253 | * Lombard 254 | * Low Saxon 255 | * Luxembourgish 256 | * Macedonian 257 | * Malagasy 258 | * Malay 259 | * Malayalam 260 | * Marathi 261 | * Minangkabau 262 | * Nepali 263 | * Newar 264 | * Norwegian (Bokmal) 265 | * Norwegian (Nynorsk) 266 | * Occitan 267 | * Persian (Farsi) 268 | * Piedmontese 269 | * Polish 270 | * Portuguese 
271 | * Punjabi 272 | * Romanian 273 | * Russian 274 | * Scots 275 | * Serbian 276 | * Serbo-Croatian 277 | * Sicilian 278 | * Slovak 279 | * Slovenian 280 | * South Azerbaijani 281 | * Spanish 282 | * Sundanese 283 | * Swahili 284 | * Swedish 285 | * Tagalog 286 | * Tajik 287 | * Tamil 288 | * Tatar 289 | * Telugu 290 | * Turkish 291 | * Ukrainian 292 | * Urdu 293 | * Uzbek 294 | * Vietnamese 295 | * Volapük 296 | * Waray-Waray 297 | * Welsh 298 | * West Frisian 299 | * Western Punjabi 300 | * Yoruba 301 | 302 | The **Multilingual Cased (New)** release contains additionally **Thai** and 303 | **Mongolian**, which were not included in the original release. 304 | -------------------------------------------------------------------------------- /baselines/models/bert_mrc/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-trained. If this error is wrong, please " 74 | "just comment out this check."
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts 
a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenization.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input.
193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 
254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenization.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 |
"""Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer`. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `char` is a whitespace character.""" 364 | # \t, \n, and \r are technically control characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `char` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `char` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyway, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /baselines/models/bert_wsc_csl/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-trained. If this error is wrong, please " 74 | "just comment out this check."
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts 
a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenization.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input.
193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 
254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "Chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and are handled 273 | # like all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenization.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 |
"""Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace-separated tokens. This should have 320 | already been passed through `BasicTokenizer`. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `char` is a whitespace character.""" 364 | # \t, \n, and \r are technically control characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `char` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat in ("Cc", "Cf"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `char` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyway, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /baselines/models/bert_ner/tokenization.py: -------------------------------------------------------------------------------- 1 | """Tokenization classes.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import collections 8 | import re 9 | import unicodedata 10 | import six 11 | import tensorflow as tf 12 | 13 | 14 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 15 | """Checks whether the casing config is consistent with the checkpoint name.""" 16 | 17 | # The casing has to be passed in by the user and there is no explicit check 18 | # as to whether it matches the checkpoint. The casing information probably 19 | # should have been stored in the bert_config.json file, but it's not, so 20 | # we have to heuristically detect it to validate.
21 | 22 | if not init_checkpoint: 23 | return 24 | 25 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 26 | if m is None: 27 | return 28 | 29 | model_name = m.group(1) 30 | 31 | lower_models = [ 32 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 33 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 34 | ] 35 | 36 | cased_models = [ 37 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 38 | "multi_cased_L-12_H-768_A-12" 39 | ] 40 | 41 | is_bad_config = False 42 | if model_name in lower_models and not do_lower_case: 43 | is_bad_config = True 44 | actual_flag = "False" 45 | case_name = "lowercased" 46 | opposite_flag = "True" 47 | 48 | if model_name in cased_models and do_lower_case: 49 | is_bad_config = True 50 | actual_flag = "True" 51 | case_name = "cased" 52 | opposite_flag = "False" 53 | 54 | if is_bad_config: 55 | raise ValueError( 56 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 57 | "However, `%s` seems to be a %s model, so you " 58 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 59 | "how the model was pre-trained. If this error is wrong, please " 60 | "just comment out this check."
% (actual_flag, init_checkpoint, 61 | model_name, case_name, opposite_flag)) 62 | 63 | 64 | def convert_to_unicode(text): 65 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 66 | if six.PY3: 67 | if isinstance(text, str): 68 | return text 69 | elif isinstance(text, bytes): 70 | return text.decode("utf-8", "ignore") 71 | else: 72 | raise ValueError("Unsupported string type: %s" % (type(text))) 73 | elif six.PY2: 74 | if isinstance(text, str): 75 | return text.decode("utf-8", "ignore") 76 | elif isinstance(text, unicode): 77 | return text 78 | else: 79 | raise ValueError("Unsupported string type: %s" % (type(text))) 80 | else: 81 | raise ValueError("Not running on Python2 or Python 3?") 82 | 83 | 84 | def printable_text(text): 85 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 86 | 87 | # These functions want `str` for both Python2 and Python3, but in one case 88 | # it's a Unicode string and in the other it's a byte string. 89 | if six.PY3: 90 | if isinstance(text, str): 91 | return text 92 | elif isinstance(text, bytes): 93 | return text.decode("utf-8", "ignore") 94 | else: 95 | raise ValueError("Unsupported string type: %s" % (type(text))) 96 | elif six.PY2: 97 | if isinstance(text, str): 98 | return text 99 | elif isinstance(text, unicode): 100 | return text.encode("utf-8") 101 | else: 102 | raise ValueError("Unsupported string type: %s" % (type(text))) 103 | else: 104 | raise ValueError("Not running on Python2 or Python 3?") 105 | 106 | 107 | def load_vocab(vocab_file): 108 | """Loads a vocabulary file into a dictionary.""" 109 | vocab = collections.OrderedDict() 110 | index = 0 111 | with tf.gfile.GFile(vocab_file, "r") as reader: 112 | while True: 113 | token = convert_to_unicode(reader.readline()) 114 | if not token: 115 | break 116 | token = token.strip() 117 | vocab[token] = index 118 | index += 1 119 | return vocab 120 | 121 | 122 | def convert_by_vocab(vocab, items): 123 | """Converts a sequence of 
[tokens|ids] using the vocab.""" 124 | output = [] 125 | for item in items: 126 | if item in vocab: 127 | output.append(vocab[item]) 128 | else: 129 | output.append(vocab['[UNK]']) 130 | return output 131 | 132 | 133 | def convert_tokens_to_ids(vocab, tokens): 134 | return convert_by_vocab(vocab, tokens) 135 | 136 | 137 | def convert_ids_to_tokens(inv_vocab, ids): 138 | return convert_by_vocab(inv_vocab, ids) 139 | 140 | 141 | def whitespace_tokenize(text): 142 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 143 | text = text.strip() 144 | if not text: 145 | return [] 146 | tokens = text.split() 147 | return tokens 148 | 149 | 150 | class FullTokenizer(object): 151 | """Runs end-to-end tokenization.""" 152 | 153 | def __init__(self, vocab_file, do_lower_case=True): 154 | self.vocab = load_vocab(vocab_file) 155 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 156 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 157 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 158 | 159 | def tokenize(self, text): 160 | split_tokens = [] 161 | for token in self.basic_tokenizer.tokenize(text): 162 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 163 | split_tokens.append(sub_token) 164 | 165 | return split_tokens 166 | 167 | def convert_tokens_to_ids(self, tokens): 168 | return convert_by_vocab(self.vocab, tokens) 169 | 170 | def convert_ids_to_tokens(self, ids): 171 | return convert_by_vocab(self.inv_vocab, ids) 172 | 173 | 174 | class BasicTokenizer(object): 175 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 176 | 177 | def __init__(self, do_lower_case=True): 178 | """Constructs a BasicTokenizer. 179 | 180 | Args: 181 | do_lower_case: Whether to lower case the input.
182 | """ 183 | self.do_lower_case = do_lower_case 184 | 185 | def tokenize(self, text): 186 | """Tokenizes a piece of text.""" 187 | text = convert_to_unicode(text) 188 | text = self._clean_text(text) 189 | 190 | # This was added on November 1st, 2018 for the multilingual and Chinese 191 | # models. This is also applied to the English models now, but it doesn't 192 | # matter since the English models were not trained on any Chinese data 193 | # and generally don't have any Chinese data in them (there are Chinese 194 | # characters in the vocabulary because Wikipedia does have some Chinese 195 | # words in the English Wikipedia.). 196 | text = self._tokenize_chinese_chars(text) 197 | 198 | orig_tokens = whitespace_tokenize(text) 199 | split_tokens = [] 200 | for token in orig_tokens: 201 | if self.do_lower_case: 202 | token = token.lower() 203 | token = self._run_strip_accents(token) 204 | split_tokens.extend(self._run_split_on_punc(token)) 205 | 206 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 207 | return output_tokens 208 | 209 | def _run_strip_accents(self, text): 210 | """Strips accents from a piece of text.""" 211 | text = unicodedata.normalize("NFD", text) 212 | output = [] 213 | for char in text: 214 | cat = unicodedata.category(char) 215 | if cat == "Mn": 216 | continue 217 | output.append(char) 218 | return "".join(output) 219 | 220 | def _run_split_on_punc(self, text): 221 | """Splits punctuation on a piece of text.""" 222 | chars = list(text) 223 | i = 0 224 | start_new_word = True 225 | output = [] 226 | while i < len(chars): 227 | char = chars[i] 228 | if _is_punctuation(char): 229 | output.append([char]) 230 | start_new_word = True 231 | else: 232 | if start_new_word: 233 | output.append([]) 234 | start_new_word = False 235 | output[-1].append(char) 236 | i += 1 237 | 238 | return ["".join(x) for x in output] 239 | 240 | def _tokenize_chinese_chars(self, text): 241 | """Adds whitespace around any CJK character.""" 242 | output = [] 
243 | for char in text: 244 | cp = ord(char) 245 | if self._is_chinese_char(cp): 246 | output.append(" ") 247 | output.append(char) 248 | output.append(" ") 249 | else: 250 | output.append(char) 251 | return "".join(output) 252 | 253 | def _is_chinese_char(self, cp): 254 | """Checks whether CP is the codepoint of a CJK character.""" 255 | # This defines a "Chinese character" as anything in the CJK Unicode block: 256 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 257 | # 258 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 259 | # despite its name. The modern Korean Hangul alphabet is a different block, 260 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 261 | # space-separated words, so they are not treated specially and are handled 262 | # like all of the other languages. 263 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 264 | (cp >= 0x3400 and cp <= 0x4DBF) or # 265 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 266 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 267 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 268 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 269 | (cp >= 0xF900 and cp <= 0xFAFF) or # 270 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 271 | return True 272 | 273 | return False 274 | 275 | def _clean_text(self, text): 276 | """Performs invalid character removal and whitespace cleanup on text.""" 277 | output = [] 278 | for char in text: 279 | cp = ord(char) 280 | if cp == 0 or cp == 0xfffd or _is_control(char): 281 | continue 282 | if _is_whitespace(char): 283 | output.append(" ") 284 | else: 285 | output.append(char) 286 | return "".join(output) 287 | 288 | 289 | class WordpieceTokenizer(object): 290 | """Runs WordPiece tokenization.""" 291 | 292 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 293 | self.vocab = vocab 294 | self.unk_token = unk_token 295 | self.max_input_chars_per_word = max_input_chars_per_word 296 | 297 | def tokenize(self, text): 298 |
"""Tokenizes a piece of text into its word pieces. 299 | 300 | This uses a greedy longest-match-first algorithm to perform tokenization 301 | using the given vocabulary. 302 | 303 | For example: 304 | input = "unaffable" 305 | output = ["un", "##aff", "##able"] 306 | 307 | Args: 308 | text: A single token or whitespace-separated tokens. This should have 309 | already been passed through `BasicTokenizer`. 310 | 311 | Returns: 312 | A list of wordpiece tokens. 313 | """ 314 | 315 | text = convert_to_unicode(text) 316 | 317 | output_tokens = [] 318 | for token in whitespace_tokenize(text): 319 | chars = list(token) 320 | if len(chars) > self.max_input_chars_per_word: 321 | output_tokens.append(self.unk_token) 322 | continue 323 | 324 | is_bad = False 325 | start = 0 326 | sub_tokens = [] 327 | while start < len(chars): 328 | end = len(chars) 329 | cur_substr = None 330 | while start < end: 331 | substr = "".join(chars[start:end]) 332 | if start > 0: 333 | substr = "##" + substr 334 | if substr in self.vocab: 335 | cur_substr = substr 336 | break 337 | end -= 1 338 | if cur_substr is None: 339 | is_bad = True 340 | break 341 | sub_tokens.append(cur_substr) 342 | start = end 343 | 344 | if is_bad: 345 | output_tokens.append(self.unk_token) 346 | else: 347 | output_tokens.extend(sub_tokens) 348 | return output_tokens 349 | 350 | 351 | def _is_whitespace(char): 352 | """Checks whether `char` is a whitespace character.""" 353 | # \t, \n, and \r are technically control characters but we treat them 354 | # as whitespace since they are generally considered as such. 355 | if char == " " or char == "\t" or char == "\n" or char == "\r": 356 | return True 357 | cat = unicodedata.category(char) 358 | if cat == "Zs": 359 | return True 360 | return False 361 | 362 | 363 | def _is_control(char): 364 | """Checks whether `char` is a control character.""" 365 | # These are technically control characters but we count them as whitespace 366 | # characters.
367 | if char == "\t" or char == "\n" or char == "\r": 368 | return False 369 | cat = unicodedata.category(char) 370 | if cat.startswith("C"): 371 | return True 372 | return False 373 | 374 | 375 | def _is_punctuation(char): 376 | """Checks whether `char` is a punctuation character.""" 377 | cp = ord(char) 378 | # We treat all non-letter/number ASCII as punctuation. 379 | # Characters such as "^", "$", and "`" are not in the Unicode 380 | # Punctuation class but we treat them as punctuation anyway, for 381 | # consistency. 382 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 383 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 384 | return True 385 | cat = unicodedata.category(char) 386 | if cat.startswith("P"): 387 | return True 388 | return False 389 | --------------------------------------------------------------------------------
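Both tokenization.py copies listed above implement `WordpieceTokenizer.tokenize` as a greedy longest-match-first search: take the longest prefix found in the vocabulary, prefix every later piece with `##`, and emit a single `[UNK]` for the whole token if any position cannot be matched. The standalone sketch below mirrors that loop. The function name `wordpiece` and the toy vocabulary are hypothetical illustrations, not part of the repository or of any shipped vocab.txt.

```python
def wordpiece(token, vocab, unk_token="[UNK]", max_chars=200):
    """Greedy longest-match-first split of one whitespace-free token."""
    chars = list(token)
    if len(chars) > max_chars:
        return [unk_token]  # overly long tokens are not split at all
    pieces = []
    start = 0
    while start < len(chars):
        end = len(chars)
        match = None
        # Shrink the candidate from the right until it is in the vocab.
        while start < end:
            candidate = "".join(chars[start:end])
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry ##
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No prefix matched at this position: the whole token is [UNK].
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary, for illustration only.
toy_vocab = {"un", "##aff", "##able", "aff", "able", "[UNK]"}
print(wordpiece("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece("unxyz", toy_vocab))      # ['[UNK]']
```

Because the search is greedy, it does not backtrack to try a shorter first piece when a later position fails; "unxyz" becomes one `[UNK]` even though "un" alone matched.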
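The two `tokenization.py` copies differ in `_is_control`. The bert_wsc_csl copy treats only Unicode categories `Cc` and `Cf` as control characters, while the bert_ner copy treats every category starting with `C` (which also covers `Co` private-use, `Cs` surrogate, and `Cn` unassigned codepoints). Because `_clean_text` deletes whatever `_is_control` accepts, the bert_ner copy silently drops, for example, private-use characters that the bert_wsc_csl copy keeps. A small sketch comparing the two rules; the function names are hypothetical.

```python
import unicodedata


def is_control_cc_cf(char):
    # Rule in bert_wsc_csl/tokenization.py: only Cc and Cf are control.
    if char in ("\t", "\n", "\r"):
        return False
    return unicodedata.category(char) in ("Cc", "Cf")


def is_control_any_c(char):
    # Rule in bert_ner/tokenization.py: any "C*" category is control.
    if char in ("\t", "\n", "\r"):
        return False
    return unicodedata.category(char).startswith("C")


samples = {"NUL": "\x00", "ZWSP": "\u200b", "private use": "\ue000"}
for name, ch in samples.items():
    print(name, unicodedata.category(ch), is_control_cc_cf(ch), is_control_any_c(ch))
# NUL Cc True True
# ZWSP Cf True True
# private use Co False True
```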
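The copies also differ in `convert_by_vocab`. The bert_wsc_csl copy indexes the vocabulary directly, so an out-of-vocabulary token raises `KeyError`; the bert_ner copy falls back to the id of `[UNK]`. One caveat follows from the code: the same helper backs `convert_ids_to_tokens`, and `inv_vocab` is keyed by integer ids, so an unknown id still raises `KeyError` because `inv_vocab['[UNK]']` does not exist. A sketch of both behaviours, using a hypothetical toy vocabulary and hypothetical function names:

```python
def convert_by_vocab_strict(vocab, items):
    # bert_wsc_csl behaviour: unknown items raise KeyError.
    return [vocab[item] for item in items]


def convert_by_vocab_unk(vocab, items):
    # bert_ner behaviour: unknown items map to vocab["[UNK]"].
    return [vocab[item] if item in vocab else vocab["[UNK]"] for item in items]


vocab = {"[UNK]": 0, "我": 1, "爱": 2}  # hypothetical toy vocabulary
inv_vocab = {v: k for k, v in vocab.items()}

print(convert_by_vocab_unk(vocab, ["我", "你"]))  # [1, 0]

try:
    convert_by_vocab_strict(vocab, ["我", "你"])
except KeyError as err:
    print("strict lookup failed on", err)

try:
    # Unknown id 99: the fallback looks up inv_vocab["[UNK]"], which is missing.
    convert_by_vocab_unk(inv_vocab, [1, 99])
except KeyError:
    print("id 99 could not be mapped back to a token")
```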
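`validate_case_matches_checkpoint` takes the model name from the directory that directly contains `bert_model.ckpt` and compares it against two hard-coded lists; any name in neither list skips the check entirely. The sketch below reuses the same regular expression to show which name is extracted. The helper name `checkpoint_model_name` and all paths are hypothetical examples, not paths confirmed by the repository.

```python
import re

# Same pattern as validate_case_matches_checkpoint.
CKPT_PATTERN = "^.*?([A-Za-z0-9_-]+)/bert_model.ckpt"


def checkpoint_model_name(init_checkpoint):
    """Return the directory name the casing check keys on, or None."""
    m = re.match(CKPT_PATTERN, init_checkpoint)
    return m.group(1) if m else None


# Hypothetical checkpoint paths for illustration.
print(checkpoint_model_name("/data/chinese_L-12_H-768_A-12/bert_model.ckpt"))
# chinese_L-12_H-768_A-12: in lower_models, so do_lower_case must be True
print(checkpoint_model_name("prev_trained_models/RoBERTa-tiny-clue/bert_model.ckpt"))
# RoBERTa-tiny-clue: in neither list, so no casing check is performed
print(checkpoint_model_name("bert_model.ckpt"))
# None: no directory component, so the function returns early
```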