├── img
│   └── shared_task_QA_code
├── example
│   ├── nlpcc2020_baseline.zip
│   └── cluewsc2020_predict.json
├── baselines
│   ├── models
│   │   ├── bert_wsc_csl
│   │   │   ├── requirements.txt
│   │   │   ├── prev_trained_models
│   │   │   │   └── RoBERTa-tiny-clue
│   │   │   │       └── bert_config.json
│   │   │   ├── __init__.py
│   │   │   ├── tpu
│   │   │   │   ├── run_classifier_inews.sh
│   │   │   │   ├── run_classifier_lcqmc.sh
│   │   │   │   ├── run_classifier_xnli.sh
│   │   │   │   ├── run_classifier_thucnews.sh
│   │   │   │   ├── run_classifier_jdcomment.sh
│   │   │   │   └── run_classifier_tnews.sh
│   │   │   ├── CONTRIBUTING.md
│   │   │   ├── optimization_test.py
│   │   │   ├── .gitignore
│   │   │   ├── run_classifier_csl.sh
│   │   │   ├── run_classifier_wsc.sh
│   │   │   ├── sample_text.txt
│   │   │   ├── run_classifier_clue.sh
│   │   │   ├── tokenization_test.py
│   │   │   ├── optimization.py
│   │   │   ├── tf_metrics.py
│   │   │   ├── wsc_output
│   │   │   │   └── wsc_predict.json
│   │   │   ├── modeling_test.py
│   │   │   ├── conlleval.py
│   │   │   ├── LICENSE
│   │   │   ├── multilingual.md
│   │   │   └── tokenization.py
│   │   ├── bert_mrc
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── run_cmrc.sh
│   │   │   ├── LICENSE
│   │   │   ├── cmrc2018_evaluate.py
│   │   │   ├── optimization.py
│   │   │   └── tokenization.py
│   │   ├── README.md
│   │   └── bert_ner
│   │       ├── prev_trained_models
│   │       │   └── RoBERTa-tiny-clue
│   │       │       └── bert_config.json
│   │       ├── len_count.json
│   │       ├── label2id.json
│   │       ├── readme.md
│   │       ├── run_classifier_tiny.sh
│   │       ├── data-ming.py
│   │       ├── data_processor_seq.py
│   │       ├── predict_sequence_label.py
│   │       ├── optimization.py
│   │       └── tokenization.py
│   └── CLUEdataset
│       └── ner
│           ├── README.md
│           └── cluener_predict.json
├── Schedule4clue.md
├── .gitignore
└── README.md
/img/shared_task_QA_code:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLUEbenchmark/LightLM/HEAD/img/shared_task_QA_code
--------------------------------------------------------------------------------
/example/nlpcc2020_baseline.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CLUEbenchmark/LightLM/HEAD/example/nlpcc2020_baseline.zip
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow.
2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow.
3 |
--------------------------------------------------------------------------------
/baselines/models/bert_mrc/README.md:
--------------------------------------------------------------------------------
1 | # mrc
2 | Run one of the other task scripts first to download the models and data; those download shell scripts are not included here.
3 | 
4 | This will be updated later.
5 |
6 |
7 |
8 | This is copied from
9 | https://github.com/johndpope/CMRC2018-DRCD-BERT
10 |
--------------------------------------------------------------------------------
/baselines/CLUEdataset/ner/README.md:
--------------------------------------------------------------------------------
1 | CLUENER: Fine-Grained Named Entity Recognition
2 | 
3 | The data is annotated with 10 label categories:
4 | address,
5 | book,
6 | company,
7 | game,
8 | government,
9 | movie,
10 | name,
11 | organization,
12 | position,
13 | scene
14 | 
15 | For a detailed description of the data, the baseline models, and their evaluation results, see https://github.com/CLUEbenchmark/CLUENER
16 | 
17 | For technical discussion or questions, please open an issue or PR in the project, or email ChineseGLUE@163.com
18 | 
19 | For state-of-the-art results on the test set, see the leaderboard: www.CLUEbenchmark.com
--------------------------------------------------------------------------------
/baselines/CLUEdataset/ner/cluener_predict.json:
--------------------------------------------------------------------------------
1 | {"id": 0, "label": {"address": {"丹棱县": [[11, 13]]}, "name": {"胡文和": [[41, 43]]}}}
2 | {"id": 1, "label": {"address": {"阿布贾": [[12, 14]]}, "goverment": {"尼日利亚海军": [[0, 5]], "尼日利亚通讯社": [[16, 22]]}}}
3 | {"id": 2, "label": {"game": {"辐射3-Bethesda": [[5, 16]]}}}
4 | {"id": 3, "label": {"scene": {"巴厘岛": [[9, 11]]}}}
5 | {"id": 4, "label": {?????}}
6 | {"id": 5, "label": {?????}}
7 | {"id": 6, "label": {?????}}
--------------------------------------------------------------------------------
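The prediction file above is JSON-lines: each line carries an `id` and a nested `label` map of entity type → mention string → list of `[start, end]` spans, with what appear to be inclusive character offsets (the three-character mention "丹棱县" spans `[11, 13]`). Below is a minimal parsing sketch, not part of the repository, assuming exactly that format:

```
import json

def read_cluener_predictions(path):
    """Parse CLUENER-style prediction lines into flat (id, type, mention, start, end) rows.

    Assumes each line is {"id": ..., "label": {type: {mention: [[start, end], ...]}}}
    with inclusive character offsets, as in cluener_predict.json above.
    """
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            for ent_type, mentions in (obj.get("label") or {}).items():
                for mention, spans in mentions.items():
                    for start, end in spans:
                        rows.append((obj["id"], ent_type, mention, start, end))
    return rows

# Example: read_cluener_predictions("cluener_predict.json")
# -> [(0, "address", "丹棱县", 11, 13), (0, "name", "胡文和", 41, 43), ...]
```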
/baselines/models/README.md:
--------------------------------------------------------------------------------
1 | The baselines/models directory provides prediction scripts for the four tasks of this competition. Usage:
2 | 
3 | ```
4 | Go into the baselines/models directory.
5 | 1. bert_wsc_csl contains the WSC and CSL tasks: bash run_classifier_clue.sh
6 |    - Results are written to the wsc_output and csl_output directories
7 | 2. bert_mrc is the cmrc2018 task baseline: bash run_cmrc.sh
8 |    - Results are written to cmrc2018_predict.json
9 | 3. bert_ner is the CLUENER task baseline: bash run_classifier_tiny.sh
10 |    - Results are written to ner_output
11 | 
12 | ```
13 | Simply load the corresponding model to run prediction. This demo automatically uses "RoBERTa-tiny-clue" for prediction.
14 | 
15 | Submission guide: package the results above as nlpcc-xxx.zip and submit it. Note that the names of the output files must not be changed; keep them as-is, otherwise the system cannot recognize them.
--------------------------------------------------------------------------------
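To make the submission step above concrete, here is a hedged packaging sketch (not part of the repository). Only wsc_output/wsc_predict.json, cmrc2018_predict.json, and the cluener_predict.json file name are confirmed by the tree and dataset example above; the csl output name and all paths are placeholders to adapt:

```
import zipfile

# Prediction files produced by the scripts described in the README.
# Paths and the csl file name are assumptions; adjust them to your run,
# but keep the file names themselves unchanged, as the README requires.
outputs = [
    "bert_wsc_csl/wsc_output/wsc_predict.json",
    "bert_wsc_csl/csl_output/csl_predict.json",   # assumed name, not shown in this snapshot
    "bert_mrc/cmrc2018_predict.json",
    "bert_ner/ner_output/cluener_predict.json",   # assumed name, mirrors the dataset example
]

# Package everything as nlpcc-xxx.zip (replace xxx with your identifier).
with zipfile.ZipFile("nlpcc-xxx.zip", "w", zipfile.ZIP_DEFLATED) as zf:
    for path in outputs:
        zf.write(path, arcname=path.rsplit("/", 1)[-1])  # store only the file name
```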
/baselines/models/bert_ner/prev_trained_models/RoBERTa-tiny-clue/bert_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "directionality": "bidi",
4 | "hidden_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "hidden_size": 312,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 1248,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 12,
11 | "num_hidden_layers": 4,
12 |
13 | "pooler_fc_size": 768,
14 | "pooler_num_attention_heads": 12,
15 | "pooler_num_fc_layers": 3,
16 | "pooler_size_per_head": 128,
17 | "pooler_type": "first_token_transform",
18 | "type_vocab_size": 2,
19 | "vocab_size": 8021
20 | }
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/prev_trained_models/RoBERTa-tiny-clue/bert_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "directionality": "bidi",
4 | "hidden_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "hidden_size": 312,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 1248,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 12,
11 | "num_hidden_layers": 4,
12 |
13 | "pooler_fc_size": 768,
14 | "pooler_num_attention_heads": 12,
15 | "pooler_num_fc_layers": 3,
16 | "pooler_size_per_head": 128,
17 | "pooler_type": "first_token_transform",
18 | "type_vocab_size": 2,
19 | "vocab_size": 8021
20 | }
--------------------------------------------------------------------------------
/baselines/models/bert_mrc/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/baselines/models/bert_ner/len_count.json:
--------------------------------------------------------------------------------
1 | {
2 | "50": 574,
3 | "18": 138,
4 | "22": 239,
5 | "40": 530,
6 | "44": 593,
7 | "48": 644,
8 | "20": 156,
9 | "27": 113,
10 | "30": 181,
11 | "47": 578,
12 | "12": 77,
13 | "23": 105,
14 | "35": 303,
15 | "45": 666,
16 | "15": 96,
17 | "43": 636,
18 | "41": 544,
19 | "36": 319,
20 | "49": 673,
21 | "39": 455,
22 | "19": 124,
23 | "42": 600,
24 | "25": 133,
25 | "34": 248,
26 | "9": 43,
27 | "6": 20,
28 | "29": 133,
29 | "13": 68,
30 | "10": 56,
31 | "16": 121,
32 | "38": 435,
33 | "37": 383,
34 | "17": 171,
35 | "11": 92,
36 | "33": 226,
37 | "46": 561,
38 | "31": 177,
39 | "28": 147,
40 | "32": 219,
41 | "14": 84,
42 | "24": 100,
43 | "21": 151,
44 | "26": 101,
45 | "3": 9,
46 | "8": 30,
47 | "5": 12,
48 | "7": 15,
49 | "4": 9,
50 | "2": 3
51 | }
--------------------------------------------------------------------------------
/baselines/models/bert_ner/label2id.json:
--------------------------------------------------------------------------------
1 | {
2 | "O": 0,
3 | "S_address": 1,
4 | "B_address": 2,
5 | "M_address": 3,
6 | "E_address": 4,
7 | "S_book": 5,
8 | "B_book": 6,
9 | "M_book": 7,
10 | "E_book": 8,
11 | "S_company": 9,
12 | "B_company": 10,
13 | "M_company": 11,
14 | "E_company": 12,
15 | "S_game": 13,
16 | "B_game": 14,
17 | "M_game": 15,
18 | "E_game": 16,
19 | "S_government": 17,
20 | "B_government": 18,
21 | "M_government": 19,
22 | "E_government": 20,
23 | "S_movie": 21,
24 | "B_movie": 22,
25 | "M_movie": 23,
26 | "E_movie": 24,
27 | "S_name": 25,
28 | "B_name": 26,
29 | "M_name": 27,
30 | "E_name": 28,
31 | "S_organization": 29,
32 | "B_organization": 30,
33 | "M_organization": 31,
34 | "E_organization": 32,
35 | "S_position": 33,
36 | "B_position": 34,
37 | "M_position": 35,
38 | "E_position": 36,
39 | "S_scene": 37,
40 | "B_scene": 38,
41 | "M_scene": 39,
42 | "E_scene": 40
43 | }
--------------------------------------------------------------------------------
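label2id.json encodes a character-level S/B/M/E scheme: each of the 10 entity types gets four tags (single-character, begin, middle, end) plus the shared "O". The following is a minimal sketch of how span annotations could map onto these ids, assuming the inclusive offsets used in cluener_predict.json; the repository's actual feature-building code is data_processor_seq.py, whose contents are not included in this snapshot.

```
import json

def spans_to_tag_ids(text, label, label2id):
    """Turn {type: {mention: [[start, end], ...]}} spans into per-character tag ids."""
    tags = ["O"] * len(text)
    for ent_type, mentions in label.items():
        for spans in mentions.values():
            for start, end in spans:            # end is inclusive
                if start == end:
                    tags[start] = "S_" + ent_type
                else:
                    tags[start] = "B_" + ent_type
                    for i in range(start + 1, end):
                        tags[i] = "M_" + ent_type
                    tags[end] = "E_" + ent_type
    return [label2id[t] for t in tags]

# Hypothetical example:
# label2id = json.load(open("label2id.json", encoding="utf-8"))
# spans_to_tag_ids("浙商银行企业信贷部", {"company": {"浙商银行": [[0, 3]]}}, label2id)
# -> [10, 11, 11, 12, 0, 0, 0, 0, 0]   (B_company, M_company, M_company, E_company, then O)
```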
/Schedule4clue.md:
--------------------------------------------------------------------------------
1 | - [x] 3.12 List the key dates for the competition
2 | - [ ] 3.13 Produce a detailed draft of the competition description (Chinese version).
3 | - [ ] Review within the group
4 | - [ ] Consult Dr. Lan, Prof. Yang, Prof. Li and the other advisors for feedback
5 | - [ ] Finalize the leaderboard layout and the corresponding weekly-contest format - before 3.15
6 | - [ ] Evaluation datasets: decide which datasets we use and whether to separate the final-competition data from the regular-test data - settle the plan before 3.15, refine the details afterwards
7 | - [ ] 3.25-4.6 Use publicly available datasets as a warm-up
8 | - [ ] 4.6-5.15 We provide new datasets for the final evaluation [should part of the test set be held back and released only on 5.15?]
9 | - [ ] The dataset choices above need to be well justified
10 | - [ ] Prepare the baselines, together with the corresponding data and evaluation scripts - before 2020.3.20
11 | - [ ] Test the system, run the evaluation, and walk through the whole pipeline - before 2020.3.25, with a buffer until 4.6; testing can be carried out then
12 | 
13 | --- The tasks above are the most urgent.
14 | 
15 | - [ ] List the recommended references for this competition, with optional commentary on selected papers - finish before 3.20
16 | 1. ALBERT [incidentally, would modifying only NSP also be effective?]
17 | 2. ELECTRA
18 | 3. Distilling the Knowledge in a Neural Network (Hinton)
19 | 4. TinyBERT: Distilling BERT for Natural Language Understanding
20 | 5. Attention-Guided Answer Distillation for Machine Reading Comprehension
21 | 6. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks
22 | - [ ] Evaluation criterion: fix a certain model size and make the results as good as possible within it [specify a concrete number for the size]
23 | - [ ]
24 | 
25 | 
26 | 
27 | ------ The items above need to be completed before 3.25
28 | 
29 | - [ ] Write the evaluation documentation and check the submitted results
30 | 
31 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_inews.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
2 | CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
3 | TASK_NAME="inews"
4 | export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/al/bert-base/chinese_L-12_H-768_A-12/
5 | export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/$TASK_NAME
6 | export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME
7 |
8 | python $CURRENT_DIR/../run_classifier.py \
9 | --task_name=$TASK_NAME \
10 | --do_train=true \
11 | --do_eval=true \
12 | --data_dir=$DATA_DIR \
13 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
14 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
15 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
16 | --max_seq_length=512 \
17 | --train_batch_size=16 \
18 | --learning_rate=2e-5 \
19 | --num_train_epochs=8.0 \
20 | --output_dir=$OUTPUT_DIR \
21 | --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://10.1.101.2:8470
22 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_lcqmc.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
2 | CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
3 | TASK_NAME="lcqmc"
4 | export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/bert-base/chinese_L-12_H-768_A-12
5 | export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/hard_$TASK_NAME
6 | export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME
7 |
8 | python $CURRENT_DIR/../run_classifier.py \
9 | --task_name=$TASK_NAME \
10 | --do_train=true \
11 | --do_eval=true \
12 | --data_dir=$DATA_DIR \
13 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
14 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
15 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
16 | --max_seq_length=128 \
17 | --train_batch_size=16 \
18 | --learning_rate=2e-5 \
19 | --num_train_epochs=8.0 \
20 | --output_dir=$OUTPUT_DIR \
21 | --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://172.20.0.2:8470
22 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_xnli.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
2 | CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
3 | TASK_NAME="xnli"
4 | export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/al/bert-base/chinese_L-12_H-768_A-12/
5 | export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/$TASK_NAME
6 | export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME
7 |
8 | python $CURRENT_DIR/../run_classifier.py \
9 | --task_name=$TASK_NAME \
10 | --do_train=true \
11 | --do_eval=true \
12 | --data_dir=$DATA_DIR \
13 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
14 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
15 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
16 | --max_seq_length=512 \
17 | --train_batch_size=16 \
18 | --learning_rate=2e-5 \
19 | --num_train_epochs=8.0 \
20 | --output_dir=$OUTPUT_DIR \
21 | --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://10.1.101.2:8470
22 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_thucnews.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
2 | CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
3 | TASK_NAME="thucnews"
4 | export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/al/bert-base/chinese_L-12_H-768_A-12/
5 | export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/$TASK_NAME
6 | export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME
7 |
8 | python $CURRENT_DIR/../run_classifier.py \
9 | --task_name=$TASK_NAME \
10 | --do_train=true \
11 | --do_eval=true \
12 | --data_dir=$DATA_DIR \
13 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
14 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
15 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
16 | --max_seq_length=512 \
17 | --train_batch_size=16 \
18 | --learning_rate=2e-5 \
19 | --num_train_epochs=8.0 \
20 | --output_dir=$OUTPUT_DIR \
21 | --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://10.1.101.2:8470
22 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_jdcomment.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
2 | CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
3 | TASK_NAME="jdcomment"
4 | export PREV_TRAINED_MODEL_DIR=gs://models_zxw/prev_trained_models/nlp/bert-base/chinese_L-12_H-768_A-12
5 | export DATA_DIR=gs://data_zxw/nlp/chineseGLUEdatasets.v0.0.1/hard_${TASK_NAME}
6 | export OUTPUT_DIR=gs://models_zxw/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME
7 | echo $DATA_DIR
8 | python3 $CURRENT_DIR/../run_classifier.py \
9 | --task_name=$TASK_NAME \
10 | --do_train=true \
11 | --do_eval=true \
12 | --data_dir=$DATA_DIR \
13 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
14 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
15 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
16 | --max_seq_length=128 \
17 | --train_batch_size=32 \
18 | --learning_rate=2e-5 \
19 | --num_train_epochs=3.0 \
20 | --output_dir=$OUTPUT_DIR \
21 | --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://172.18.0.2:8470
22 |
--------------------------------------------------------------------------------
/baselines/models/bert_mrc/run_cmrc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #########################################################################
3 | # File Name: run.sh
4 | # Author: Junyi Li
5 | # Personal page: dukeenglish.github.io
6 | # Created Time: 21:36:42 2020-03-30
7 | #########################################################################
8 | PATH_TO_BERT=../bert_wsc_csl/prev_trained_models/RoBERTa-tiny-clue
9 |
10 | DATA_DIR=../../CLUEdataset/cmrc
11 |
12 | MODEL_DIR=.
13 |
14 |
15 | python run_cmrc2018_drcd_baseline.py \
16 | --vocab_file=${PATH_TO_BERT}/vocab.txt \
17 | --bert_config_file=${PATH_TO_BERT}/bert_config.json \
18 | --init_checkpoint=${PATH_TO_BERT}/bert_model.ckpt \
19 | --do_train=True \
20 | --train_file=${DATA_DIR}/train.json \
21 | --do_predict=True \
22 | --predict_file=${DATA_DIR}/test.json \
23 | --train_batch_size=32 \
24 | --num_train_epochs=2 \
25 | --max_seq_length=512 \
26 | --doc_stride=128 \
27 | --learning_rate=3e-5 \
28 | --save_checkpoints_steps=1000 \
29 | --output_dir=${MODEL_DIR} \
30 | --do_lower_case=False \
31 | --use_tpu=False
32 |
--------------------------------------------------------------------------------
/baselines/models/bert_mrc/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yiming Cui
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | BERT needs to maintain permanent compatibility with the pre-trained model files,
4 | so we do not plan to make any major changes to this library (other than what was
5 | promised in the README). However, we can accept small patches related to
6 | re-factoring and documentation. To submit contributions, there are just a few
7 | small guidelines you need to follow.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 |
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 |
21 | ## Code reviews
22 |
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 |
28 | ## Community Guidelines
29 |
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
32 |
--------------------------------------------------------------------------------
/baselines/models/bert_ner/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # How to Train and Submit for Testing
4 | 
5 | Environment: Python 3 & TensorFlow 1.x, e.g. 1.14.
6 | 
7 | ## Model 1: run RoBERTa-wwm-large with one command
8 | 
9 | nohup bash run_classifier_roberta_wwm_large.sh &
10 | 
11 | 
12 | ## Model 2
13 | 
14 | Step 1: generate the tf_record files.
15 | Adjust the input/output paths of the functions in data_processor_seq.py.
16 | ```
17 | python data_processor_seq.py
18 | ```
19 | 
20 | Step 2: train the NER model.
21 | Adjust the config dict in train_sequence_label.py (model hyperparameters, file paths, etc.).
22 | ```
23 | python train_sequence_label.py
24 | ```
25 | 
26 | Step 3: load the model and run prediction.
27 | Adjust model_path (the path of the saved model) and the prediction file path in predict_sequence_label.py.
28 | ```
29 | python predict_sequence_label.py
30 | ```
31 | 
32 | ### Evaluation
33 | F1-score is the evaluation metric. Adjust the pre and gold file paths in score.py (usable on the dev set); gold labels are not provided during the test phase.
34 | ```
35 | python score.py
36 | ```
37 | 
38 | | Model | Online F1 |
39 | |:-------------:|:-----:|
40 | | bilstm+crf | 70.00 |
41 | | bert-base | 78.82 |
42 | | roberta-wwm-large-ext | **80.42** |
43 | | Human Performance | 63.41 |
44 | 
45 | Evaluation results for each entity type:
46 |
47 |
48 | | Entity | bilstm+crf | bert-base | roberta-wwm-large-ext | Human Performance |
49 | |:-------------:|:-----:|:-----:|:-----:|:-----:|
50 | | Person Name | 74.04 | 88.75 | **89.09** | 74.49 |
51 | | Organization | 75.96 | 79.43 | **82.34** | 65.41 |
52 | | Position | 70.16 | 78.89 | **79.62** | 55.38 |
53 | | Company | 72.27 | 81.42 | **83.02** | 49.32 |
54 | | Address | 45.50 | 60.89 | **62.63** | 43.04 |
55 | | Game | 85.27 | 86.42 | **86.80** | 80.39 |
56 | | Government | 77.25 | 87.03 | **88.17** | 79.27 |
57 | | Scene | 52.42 | 65.10 | **70.49** | 51.85 |
58 | | Book | 67.20 | 73.68 | **74.60** | 71.70 |
59 | | Movie | 78.97 | 85.82 | **87.46** | 63.21 |
60 |
61 | For more detailed evaluation results, please refer to our technical report: https://arxiv.org/abs/2001.04351
62 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import optimization
20 | import tensorflow as tf
21 |
22 |
23 | class OptimizationTest(tf.test.TestCase):
24 |
25 |   def test_adam(self):
26 |     with self.test_session() as sess:
27 |       w = tf.get_variable(
28 |           "w",
29 |           shape=[3],
30 |           initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
31 |       x = tf.constant([0.4, 0.2, -0.5])
32 |       loss = tf.reduce_mean(tf.square(x - w))
33 |       tvars = tf.trainable_variables()
34 |       grads = tf.gradients(loss, tvars)
35 |       global_step = tf.train.get_or_create_global_step()
36 |       optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
37 |       train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
38 |       init_op = tf.group(tf.global_variables_initializer(),
39 |                          tf.local_variables_initializer())
40 |       sess.run(init_op)
41 |       for _ in range(100):
42 |         sess.run(train_op)
43 |       w_np = sess.run(w)
44 |       self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
45 |
46 |
47 | if __name__ == "__main__":
48 |   tf.test.main()
49 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/.gitignore:
--------------------------------------------------------------------------------
1 | # Initially taken from Github's Python gitignore file
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # celery beat schedule file
86 | celerybeat-schedule
87 |
88 | # SageMath parsed files
89 | *.sage.py
90 |
91 | # Environments
92 | .env
93 | .venv
94 | env/
95 | venv/
96 | ENV/
97 | env.bak/
98 | venv.bak/
99 |
100 | # Spyder project settings
101 | .spyderproject
102 | .spyproject
103 |
104 | # Rope project settings
105 | .ropeproject
106 |
107 | # mkdocs documentation
108 | /site
109 |
110 | # mypy
111 | .mypy_cache/
112 | .dmypy.json
113 | dmypy.json
114 |
115 | # Pyre type checker
116 | .pyre/
117 |
--------------------------------------------------------------------------------
/baselines/models/bert_ner/run_classifier_tiny.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # @Author: bo.shi
3 | # @Date: 2019-11-04 09:56:36
4 | # @Last Modified by: bo.shi
5 | # @Last Modified time: 2019-12-05 11:23:30
6 |
7 | TASK_NAME="ner"
8 | MODEL_NAME="RoBERTa-tiny-clue"
9 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
10 | export CUDA_VISIBLE_DEVICES="0"
11 | export PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_models
12 | export ROBERTA_CLUE_DIR=$PRETRAINED_MODELS_DIR/$MODEL_NAME
13 | export CLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset
14 |
15 | # download and unzip dataset
16 | if [ ! -d $CLUE_DATA_DIR ]; then
17 | mkdir -p $CLUE_DATA_DIR
18 | echo "makedir $CLUE_DATA_DIR"
19 | fi
20 | cd $CLUE_DATA_DIR
21 | if [ ! -d $TASK_NAME ]; then
22 | mkdir $TASK_NAME
23 | echo "makedir $CLUE_DATA_DIR/$TASK_NAME"
24 | fi
25 | cd $TASK_NAME
26 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then
27 | rm *
28 | wget https://storage.googleapis.com/cluebenchmark/tasks/cluener_public.zip
29 | unzip cluener_public.zip
30 | rm cluener_public.zip
31 | else
32 | echo "data exists"
33 | fi
34 | echo "Finish download dataset."
35 |
36 | # download model
37 | if [ ! -d $ROBERTA_CLUE_DIR ]; then
38 | mkdir -p $ROBERTA_CLUE_DIR
39 | echo "makedir $ROBERTA_CLUE_DIR"
40 | fi
41 | cd $ROBERTA_CLUE_DIR
42 | if [ ! -f "bert_config.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "bert_model.ckpt.index" ] || [ ! -f "bert_model.ckpt.meta" ] || [ ! -f "bert_model.ckpt.data-00000-of-00001" ]; then
43 | rm *
44 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny-clue.zip
45 | unzip RoBERTa-tiny-clue.zip
46 | rm RoBERTa-tiny-clue.zip
47 | else
48 | echo "model exists"
49 | fi
50 | echo "Finish download model."
51 |
52 | # run task
53 | cd $CURRENT_DIR
54 | echo "Start running..."
55 |
56 | python run_classifier_ner.py \
57 | --task_name=$TASK_NAME \
58 | --do_train=False \
59 | --do_predict=True \
60 | --data_dir=$CLUE_DATA_DIR/$TASK_NAME \
61 | --vocab_file=$ROBERTA_CLUE_DIR/vocab.txt \
62 | --bert_config_file=$ROBERTA_CLUE_DIR/bert_config.json \
63 | --init_checkpoint=$ROBERTA_CLUE_DIR/bert_model.ckpt \
64 | --max_seq_length=128 \
65 | --train_batch_size=32 \
66 | --learning_rate=2e-5 \
67 | --num_train_epochs=4.0 \
68 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output
69 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # mac
2 | .DS_Store
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/baselines/models/bert_ner/data-ming.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # coding:utf8
3 | """
4 | @author: Cong Yu
5 | @time: 2020-01-09 18:51
6 | """
7 | import re
8 | import json
9 |
10 |
11 | def prepare_label():
12 |     text = """
13 |     地址(address): 544
14 |     书名(book): 258
15 |     公司(company): 479
16 |     游戏(game): 281
17 |     政府(government): 262
18 |     电影(movie): 307
19 |     姓名(name): 710
20 |     组织机构(organization): 515
21 |     职位(position): 573
22 |     景点(scene): 288
23 |     """
24 | 
25 |     a = re.findall(r"\((.*?)\)", text.strip())
26 |     print(a)
27 |     label2id = {"O": 0}
28 |     index = 1
29 |     for i in a:
30 |         label2id["S_" + i] = index
31 |         label2id["B_" + i] = index + 1
32 |         label2id["M_" + i] = index + 2
33 |         label2id["E_" + i] = index + 3
34 |         index += 4
35 | 
36 |     open("label2id.json", "w").write(json.dumps(label2id, ensure_ascii=False, indent=2))
37 | 
38 | 
39 | def prepare_len_count():
40 |     len_count = {}
41 | 
42 |     for line in open("data/thuctc_train.json"):
43 |         if line.strip():
44 |             _ = json.loads(line.strip())
45 |             len_ = len(_["text"])
46 |             if len_count.get(len_):
47 |                 len_count[len_] += 1
48 |             else:
49 |                 len_count[len_] = 1
50 | 
51 |     for line in open("data/thuctc_valid.json"):
52 |         if line.strip():
53 |             _ = json.loads(line.strip())
54 |             len_ = len(_["text"])
55 |             if len_count.get(len_):
56 |                 len_count[len_] += 1
57 |             else:
58 |                 len_count[len_] = 1
59 | 
60 |     print("len_count", json.dumps(len_count, indent=2))
61 |     open("len_count.json", "w").write(json.dumps(len_count, indent=2))
62 | 
63 | 
64 | def label_count(path):
65 |     labels = ['address', 'book', 'company', 'game', 'government', 'movie', 'name', 'organization', 'position', 'scene']
66 |     label2desc = {
67 |         "address": "地址",
68 |         "book": "书名",
69 |         "company": "公司",
70 |         "game": "游戏",
71 |         "government": "政府",
72 |         "movie": "电影",
73 |         "name": "姓名",
74 |         "organization": "组织机构",
75 |         "position": "职位",
76 |         "scene": "景点"
77 |     }
78 |     label_count_dict = {i: 0 for i in labels}
79 |     for line in open(path):
80 |         if line.strip():
81 |             _ = json.loads(line.strip())
82 |             for k, v in _["label"].items():
83 |                 label_count_dict[k] += len(v)
84 |     for k, v in label_count_dict.items():
85 |         print("{}({}):{}".format(label2desc[k], k, v))
86 |     print("\n")
87 |
88 |
89 | # prepare_label()
90 | # label_count("data/thuctc_train.json")
91 | # label_count("data/thuctc_valid.json")
92 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tpu/run_classifier_tnews.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
2 | CURRENT_TIME=$(date "+%Y%m%d-%H%M%S")
3 | TASK_NAME="tnews"
4 |
5 | GS="gs" # change it to yours
6 | TPU_IP="1.1.1.1" # change it to yours
7 | # please create folder
8 | export PREV_TRAINED_MODEL_DIR=$GS/prev_trained_models/nlp/bert-base/chinese_L-12_H-768_A-12
9 | export DATA_DIR=$GS/nlp/chineseGLUEdatasets.v0.0.1/${TASK_NAME}
10 | export OUTPUT_DIR=$GS/fine_tuning_models/nlp/bert-base/chinese_L-12_H-768_A-12/tpu/$TASK_NAME/$CURRENT_TIME
11 |
12 |
13 | MODEL_NAME="chinese_L-12_H-768_A-12"
14 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
15 | export CUDA_VISIBLE_DEVICES="0"
16 | export BERT_PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_model
17 | export BERT_BASE_DIR=$BERT_PRETRAINED_MODELS_DIR/$MODEL_NAME
18 | export GLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset
19 |
20 | # download and unzip dataset
21 | if [ ! -d $GLUE_DATA_DIR ]; then
22 | mkdir -p $GLUE_DATA_DIR
23 | echo "makedir $GLUE_DATA_DIR"
24 | fi
25 | cd $GLUE_DATA_DIR
26 | if [ ! -d $TASK_NAME ]; then
27 | mkdir $TASK_NAME
28 | echo "makedir $GLUE_DATA_DIR/$TASK_NAME"
29 | fi
30 | cd $TASK_NAME
31 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then
32 | rm *
33 | wget https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip
34 | unzip tnews_public.zip
35 | rm tnews_public.zip
36 | else
37 | echo "data exists"
38 | fi
39 | echo "Finish download dataset."
40 |
41 | # download model
42 | if [ ! -d $BERT_PRETRAINED_MODELS_DIR ]; then
43 | mkdir -p $BERT_PRETRAINED_MODELS_DIR
44 | echo "makedir $BERT_PRETRAINED_MODELS_DIR"
45 | fi
46 | cd $BERT_PRETRAINED_MODELS_DIR
47 | if [ ! -d $MODEL_NAME ]; then
48 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
49 | unzip chinese_L-12_H-768_A-12.zip
50 | rm chinese_L-12_H-768_A-12.zip
51 | else
52 | cd $MODEL_NAME
53 | if [ ! -f "bert_config.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "bert_model.ckpt.index" ] || [ ! -f "bert_model.ckpt.meta" ] || [ ! -f "bert_model.ckpt.data-00000-of-00001" ]; then
54 | cd ..
55 | rm -rf $MODEL_NAME
56 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
57 | unzip chinese_L-12_H-768_A-12.zip
58 | rm chinese_L-12_H-768_A-12.zip
59 | else
60 | echo "model exists"
61 | fi
62 | fi
63 | echo "Finish download model."
64 |
65 | # upload model and data
66 | gsutil -m cp $BERT_PRETRAINED_MODELS_DIR/* $PREV_TRAINED_MODEL_DIR
67 |
68 | gsutil -m cp $GLUE_DATA_DIR/$TASK_NAME/* $DATA_DIR
69 |
70 | cd $CURRENT_DIR
71 | echo "Start running..."
72 | python $CURRENT_DIR/../run_classifier.py \
73 | --task_name=$TASK_NAME \
74 | --do_train=true \
75 | --do_eval=true \
76 | --data_dir=$DATA_DIR \
77 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
78 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
79 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
80 | --max_seq_length=128 \
81 | --train_batch_size=32 \
82 | --learning_rate=2e-5 \
83 | --num_train_epochs=3.0 \
84 | --output_dir=$OUTPUT_DIR \
85 | --num_tpu_cores=8 --use_tpu=True --tpu_name=grpc://$TPU_IP:8470
86 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/run_classifier_csl.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # @Author: Li Yudong
3 | # @Date: 2019-11-28
4 | # @Last Modified by: bo.shi
5 | # @Last Modified time: 2019-12-05 11:00:57
6 |
7 | TASK_NAME="csl"
8 | MODEL_NAME="chinese_L-12_H-768_A-12"
9 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
10 | export CUDA_VISIBLE_DEVICES="0"
11 | export BERT_PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_model
12 | export BERT_BASE_DIR=$BERT_PRETRAINED_MODELS_DIR/$MODEL_NAME
13 | export GLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset
14 |
15 | # download and unzip dataset
16 | if [ ! -d $GLUE_DATA_DIR ]; then
17 | mkdir -p $GLUE_DATA_DIR
18 | echo "makedir $GLUE_DATA_DIR"
19 | fi
20 | cd $GLUE_DATA_DIR
21 | if [ ! -d $TASK_NAME ]; then
22 | mkdir $TASK_NAME
23 | echo "makedir $GLUE_DATA_DIR/$TASK_NAME"
24 | fi
25 | cd $TASK_NAME
26 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then
27 | rm *
28 | wget https://storage.googleapis.com/cluebenchmark/tasks/csl_public.zip
29 | unzip csl_public.zip
30 | rm csl_public.zip
31 | else
32 | echo "data exists"
33 | fi
34 | echo "Finish download dataset."
35 |
36 | # download model
37 | if [ ! -d $BERT_PRETRAINED_MODELS_DIR ]; then
38 | mkdir -p $BERT_PRETRAINED_MODELS_DIR
39 | echo "makedir $BERT_PRETRAINED_MODELS_DIR"
40 | fi
41 | cd $BERT_PRETRAINED_MODELS_DIR
42 | if [ ! -d $MODEL_NAME ]; then
43 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
44 | unzip chinese_L-12_H-768_A-12.zip
45 | rm chinese_L-12_H-768_A-12.zip
46 | else
47 | cd $MODEL_NAME
48 | if [ ! -f "bert_config.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "bert_model.ckpt.index" ] || [ ! -f "bert_model.ckpt.meta" ] || [ ! -f "bert_model.ckpt.data-00000-of-00001" ]; then
49 | cd ..
50 | rm -rf $MODEL_NAME
51 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
52 | unzip chinese_L-12_H-768_A-12.zip
53 | rm chinese_L-12_H-768_A-12.zip
54 | else
55 | echo "model exists"
56 | fi
57 | fi
58 | echo "Finish download model."
59 |
60 | # run task
61 | cd $CURRENT_DIR
62 | echo "Start running..."
63 | if [ $# == 0 ]; then
64 | python run_classifier.py \
65 | --task_name=$TASK_NAME \
66 | --do_train=true \
67 | --do_eval=true \
68 | --data_dir=$GLUE_DATA_DIR/$TASK_NAME \
69 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
70 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
71 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
72 | --max_seq_length=128 \
73 | --train_batch_size=32 \
74 | --learning_rate=2e-5 \
75 | --num_train_epochs=3.0 \
76 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output/
77 | elif [ $1 == "predict" ]; then
78 | echo "Start predict..."
79 | python run_classifier.py \
80 | --task_name=$TASK_NAME \
81 | --do_train=false \
82 | --do_eval=false \
83 | --do_predict=true \
84 | --data_dir=$GLUE_DATA_DIR/$TASK_NAME \
85 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
86 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
87 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
88 | --max_seq_length=128 \
89 | --train_batch_size=32 \
90 | --learning_rate=2e-5 \
91 | --num_train_epochs=3.0 \
92 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output/
93 | fi
94 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/run_classifier_wsc.sh:
--------------------------------------------------------------------------------
1 | # @Author: bo.shi
2 | # @Date: 2019-12-01 22:28:41
3 | # @Last Modified by: bo.shi
4 | # @Last Modified time: 2019-12-05 11:01:11
5 | #!/usr/bin/env bash
6 |
7 | TASK_NAME="cluewsc2020"
8 | MODEL_NAME="chinese_L-12_H-768_A-12"
9 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
10 | export CUDA_VISIBLE_DEVICES="0"
11 | export BERT_PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_model
12 | export BERT_BASE_DIR=$BERT_PRETRAINED_MODELS_DIR/$MODEL_NAME
13 | export GLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset
14 |
15 | # download and unzip dataset
16 | if [ ! -d $GLUE_DATA_DIR ]; then
17 | mkdir -p $GLUE_DATA_DIR
18 | echo "makedir $GLUE_DATA_DIR"
19 | fi
20 | cd $GLUE_DATA_DIR
21 | if [ ! -d $TASK_NAME ]; then
22 | mkdir $TASK_NAME
23 | echo "makedir $GLUE_DATA_DIR/$TASK_NAME"
24 | fi
25 | cd $TASK_NAME
26 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then
27 | rm *
28 | wget https://storage.googleapis.com/cluebenchmark/tasks/cluewsc2020_public.zip
29 | unzip cluewsc2020_public.zip
30 | rm cluewsc2020_public.zip
31 | else
32 | echo "data exists"
33 | fi
34 | echo "Finish download dataset."
35 |
36 | # download model
37 | if [ ! -d $BERT_PRETRAINED_MODELS_DIR ]; then
38 | mkdir -p $BERT_PRETRAINED_MODELS_DIR
39 | echo "makedir $BERT_PRETRAINED_MODELS_DIR"
40 | fi
41 | cd $BERT_PRETRAINED_MODELS_DIR
42 | if [ ! -d $MODEL_NAME ]; then
43 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
44 | unzip chinese_L-12_H-768_A-12.zip
45 | rm chinese_L-12_H-768_A-12.zip
46 | else
47 | cd $MODEL_NAME
48 | if [ ! -f "bert_config.json" ] || [ ! -f "vocab.txt" ] || [ ! -f "bert_model.ckpt.index" ] || [ ! -f "bert_model.ckpt.meta" ] || [ ! -f "bert_model.ckpt.data-00000-of-00001" ]; then
49 | cd ..
50 | rm -rf $MODEL_NAME
51 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
52 | unzip chinese_L-12_H-768_A-12.zip
53 | rm chinese_L-12_H-768_A-12.zip
54 | else
55 | echo "model exists"
56 | fi
57 | fi
58 | echo "Finish download model."
59 |
60 | # run task
61 | cd $CURRENT_DIR
62 | echo "Start running..."
63 | if [ $# == 0 ]; then
64 | python run_classifier.py \
65 | --task_name=$TASK_NAME \
66 | --do_train=true \
67 | --do_eval=true \
68 | --data_dir=$GLUE_DATA_DIR/$TASK_NAME \
69 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
70 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
71 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
72 | --max_seq_length=128 \
73 | --train_batch_size=32 \
74 | --learning_rate=2e-5 \
75 | --num_train_epochs=3.0 \
76 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output/
77 | elif [ $1 == "predict" ]; then
78 | echo "Start predict..."
79 | python run_classifier.py \
80 | --task_name=$TASK_NAME \
81 | --do_train=false \
82 | --do_eval=false \
83 | --do_predict=true \
84 | --data_dir=$GLUE_DATA_DIR/$TASK_NAME \
85 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
86 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
87 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
88 | --max_seq_length=128 \
89 | --train_batch_size=32 \
90 | --learning_rate=2e-5 \
91 | --num_train_epochs=3.0 \
92 | --output_dir=$CURRENT_DIR/${TASK_NAME}_output/
93 | fi
94 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/sample_text.txt:
--------------------------------------------------------------------------------
1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
2 | Text should be one-sentence-per-line, with empty lines between documents.
3 | This sample text is public domain and was randomly selected from Project Guttenberg.
4 |
5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them.
8 | "Cass" Beard had risen early that morning, but not with a view to discovery.
9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets.
10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency.
11 | This was nearly opposite.
12 | Mr. Cassius crossed the highway, and stopped suddenly.
13 | Something glittered in the nearest red pool before him.
14 | Gold, surely!
15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring.
16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass."
17 | Like most of his fellow gold-seekers, Cass was superstitious.
18 |
19 | The fountain of classic wisdom, Hypatia herself.
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.
21 | From my youth I felt in me a soul above the matter-entangled herd.
22 | She revealed to me the glorious fact, that I am a spark of Divinity itself.
23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's.
24 | There is a philosophic pleasure in opening one's treasures to the modest young.
25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide;
27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind.
28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now.
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert;
30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts.
31 | At last they reached the quay at the opposite end of the street;
32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers.
33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him.
34 |
--------------------------------------------------------------------------------
/baselines/models/bert_mrc/cmrc2018_evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Evaluation script for CMRC 2018
4 | version: v5 - special
5 | Note:
6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets
7 | v5: formatted output, add usage description
8 | v4: fixed segmentation issues
9 | '''
10 | from __future__ import print_function
11 | from collections import Counter, OrderedDict
12 | import string
13 | import re
14 | import argparse
15 | import json
16 | import sys
17 | reload(sys)
18 | sys.setdefaultencoding('utf8')
19 | import nltk
20 | import pdb
21 |
22 | # split Chinese with English
23 | def mixed_segmentation(in_str, rm_punc=False):
24 |     in_str = str(in_str).decode('utf-8').lower().strip()
25 |     segs_out = []
26 |     temp_str = ""
27 |     sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
28 |                ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
29 |                '「','」','(',')','-','~','『','』']
30 |     for char in in_str:
31 |         if rm_punc and char in sp_char:
32 |             continue
33 |         if re.search(ur'[\u4e00-\u9fa5]', char) or char in sp_char:
34 |             if temp_str != "":
35 |                 ss = nltk.word_tokenize(temp_str)
36 |                 segs_out.extend(ss)
37 |                 temp_str = ""
38 |             segs_out.append(char)
39 |         else:
40 |             temp_str += char
41 | 
42 |     #handling last part
43 |     if temp_str != "":
44 |         ss = nltk.word_tokenize(temp_str)
45 |         segs_out.extend(ss)
46 | 
47 |     return segs_out
48 | 
49 | 
50 | # remove punctuation
51 | def remove_punctuation(in_str):
52 |     in_str = str(in_str).decode('utf-8').lower().strip()
53 |     sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
54 |                ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
55 |                '「','」','(',')','-','~','『','』']
56 |     out_segs = []
57 |     for char in in_str:
58 |         if char in sp_char:
59 |             continue
60 |         else:
61 |             out_segs.append(char)
62 |     return ''.join(out_segs)
63 | 
64 | 
65 | # find longest common string
66 | def find_lcs(s1, s2):
67 |     m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)]
68 |     mmax = 0
69 |     p = 0
70 |     for i in range(len(s1)):
71 |         for j in range(len(s2)):
72 |             if s1[i] == s2[j]:
73 |                 m[i+1][j+1] = m[i][j]+1
74 |                 if m[i+1][j+1] > mmax:
75 |                     mmax=m[i+1][j+1]
76 |                     p=i+1
77 |     return s1[p-mmax:p], mmax
78 | 
79 | #
80 | def evaluate(ground_truth_file, prediction_file):
81 |     f1 = 0
82 |     em = 0
83 |     total_count = 0
84 |     skip_count = 0
85 |     for instance in ground_truth_file["data"]:
86 |         #context_id = instance['context_id'].strip()
87 |         #context_text = instance['context_text'].strip()
88 |         for para in instance["paragraphs"]:
89 |             for qas in para['qas']:
90 |                 total_count += 1
91 |                 query_id = qas['id'].strip()
92 |                 query_text = qas['question'].strip()
93 |                 answers = [x["text"] for x in qas['answers']]
94 | 
95 |                 if query_id not in prediction_file:
96 |                     sys.stderr.write('Unanswered question: {}\n'.format(query_id))
97 |                     skip_count += 1
98 |                     continue
99 | 
100 |                 prediction = str(prediction_file[query_id]).decode('utf-8')
101 |                 f1 += calc_f1_score(answers, prediction)
102 |                 em += calc_em_score(answers, prediction)
103 | 
104 |     f1_score = 100.0 * f1 / total_count
105 |     em_score = 100.0 * em / total_count
106 |     return f1_score, em_score, total_count, skip_count
107 | 
108 | 
109 | def calc_f1_score(answers, prediction):
110 |     f1_scores = []
111 |     for ans in answers:
112 |         ans_segs = mixed_segmentation(ans, rm_punc=True)
113 |         prediction_segs = mixed_segmentation(prediction, rm_punc=True)
114 |         lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
115 |         if lcs_len == 0:
116 |             f1_scores.append(0)
117 |             continue
118 |         precision = 1.0*lcs_len/len(prediction_segs)
119 |         recall = 1.0*lcs_len/len(ans_segs)
120 |         f1 = (2*precision*recall)/(precision+recall)
121 |         f1_scores.append(f1)
122 |     return max(f1_scores)
123 | 
124 | 
125 | def calc_em_score(answers, prediction):
126 |     em = 0
127 |     for ans in answers:
128 |         ans_ = remove_punctuation(ans)
129 |         prediction_ = remove_punctuation(prediction)
130 |         if ans_ == prediction_:
131 |             em = 1
132 |             break
133 |     return em
134 | 
135 | if __name__ == '__main__':
136 |     parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018')
137 |     parser.add_argument('dataset_file', help='Official dataset file')
138 |     parser.add_argument('prediction_file', help='Your prediction File')
139 |     args = parser.parse_args()
140 |     ground_truth_file = json.load(open(args.dataset_file, 'rb'))
141 |     prediction_file = json.load(open(args.prediction_file, 'rb'))
142 |     F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
143 |     AVG = (EM+F1)*0.5
144 |     output_result = OrderedDict()
145 |     output_result['AVERAGE'] = '%.3f' % AVG
146 |     output_result['F1'] = '%.3f' % F1
147 |     output_result['EM'] = '%.3f' % EM
148 |     output_result['TOTAL'] = TOTAL
149 |     output_result['SKIP'] = SKIP
150 |     output_result['FILE'] = args.prediction_file
151 |     print(json.dumps(output_result))
152 |
153 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/run_classifier_clue.sh:
--------------------------------------------------------------------------------
1 | # @Author: bo.shi
2 | # @Date: 2020-03-15 16:11:00
3 | # @Last Modified by: bo.shi
4 | # @Last Modified time: 2020-03-17 13:06:02
5 | #!/usr/bin/env bash
6 | CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
7 | CLUE_DATA_DIR=$CURRENT_DIR/../../CLUEdataset
8 | CLUE_PREV_TRAINED_MODEL_DIR=$CURRENT_DIR/prev_trained_models
9 |
10 | download_data(){
11 | TASK_NAME=$1
12 | if [ ! -d $CLUE_DATA_DIR ]; then
13 | mkdir -p $CLUE_DATA_DIR
14 | echo "makedir $CLUE_DATA_DIR"
15 | fi
16 | cd $CLUE_DATA_DIR
17 | if [ ! -d ${TASK_NAME} ]; then
18 | mkdir $TASK_NAME
19 | echo "make dataset dir $CLUE_DATA_DIR/$TASK_NAME"
20 | fi
21 | cd $TASK_NAME
22 | if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then
23 | rm *
24 | if [ $TASK_NAME = "wsc" ];then
25 | wget https://storage.googleapis.com/cluebenchmark/tasks/clue${TASK_NAME}2020_public.zip
26 | unzip clue${TASK_NAME}2020_public.zip
27 | rm clue${TASK_NAME}2020_public.zip
28 | else
29 | wget https://storage.googleapis.com/cluebenchmark/tasks/${TASK_NAME}_public.zip
30 | unzip ${TASK_NAME}_public.zip
31 | rm ${TASK_NAME}_public.zip
32 | fi
33 | else
34 | echo "data exists"
35 | fi
36 | echo "Finish download dataset."
37 | }
38 |
39 | download_model(){
40 | MODEL_NAME=$1
41 | if [ ! -d $CLUE_PREV_TRAINED_MODEL_DIR ]; then
42 | mkdir -p $CLUE_PREV_TRAINED_MODEL_DIR
43 | echo "make prev_trained_model dir $CLUE_PREV_TRAINED_MODEL_DIR"
44 | fi
45 | cd $CLUE_PREV_TRAINED_MODEL_DIR
46 |   if [ ! -d $MODEL_NAME ]; then
47 |     mkdir -p $MODEL_NAME
48 |   fi
49 |   cd $MODEL_NAME
50 |   rm -f ./*
51 | if [ "$MODEL_NAME" = "RoBERTa-tiny-clue" ]; then
52 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny-clue.zip
53 | unzip RoBERTa-tiny-clue.zip
54 | rm RoBERTa-tiny-clue.zip
55 | elif [ "$MODEL_NAME" = "RoBERTa-tiny-pair" ]; then
56 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny-pair.zip
57 | unzip RoBERTa-tiny-pair.zip
58 | rm RoBERTa-tiny-pair.zip
59 | elif [ "$MODEL_NAME" = "RoBERTa-tiny3L768-clue" ]; then
60 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny3L768-clue.zip
61 | unzip RoBERTa-tiny3L768-clue.zip
62 | rm RoBERTa-tiny3L768-clue.zip
63 | elif [ "$MODEL_NAME" = "RoBERTa-tiny3L312-clue" ]; then
64 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-tiny3L312-clue.zip
65 | unzip RoBERTa-tiny3L312-clue.zip
66 | rm RoBERTa-tiny3L312-clue.zip
67 | elif [ "$MODEL_NAME" = "RoBERTa-large-clue" ]; then
68 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-large-clue.zip
69 | unzip RoBERTa-large-clue.zip
70 | rm RoBERTa-large-clue.zip
71 | elif [ "$MODEL_NAME" = "RoBERTa-large-pair" ]; then
72 | wget -c https://storage.googleapis.com/cluebenchmark/pretrained_models/RoBERTa-large-pair.zip
73 | unzip RoBERTa-large-pair.zip
74 | rm RoBERTa-large-pair.zip
75 | else
76 | echo "unknown model_name, choose from [ RoBERTa-tiny-clue , RoBERTa-tiny-pair , RoBERTa-tiny3L768-clue , RoBERTa-tiny3L312-clue , RoBERTa-large-clue , RoBERTa-large-pair]"
77 | fi
78 |
79 | }
80 |
81 | run_task() {
82 | TASK_NAME=$1
83 | MODEL_NAME=$2
84 | download_data $TASK_NAME
85 | #if [ ! -d $CLUE_PREV_TRAINED_MODEL_DIR/$MODEL_NAME ]; then
86 | # download_model $MODEL_NAME
87 | #fi
88 | DATA_DIR=$CLUE_DATA_DIR/${TASK_NAME}
89 | PREV_TRAINED_MODEL_DIR=$CLUE_PREV_TRAINED_MODEL_DIR/$MODEL_NAME
90 | MAX_SEQ_LENGTH=$3
91 | TRAIN_BATCH_SIZE=$4
92 | LEARNING_RATE=$5
93 | NUM_TRAIN_EPOCHS=$6
94 | SAVE_CHECKPOINTS_STEPS=$7
95 | # TPU_IP=$8
96 | OUTPUT_DIR=$CURRENT_DIR/${TASK_NAME}_output/
97 | COMMON_ARGS="
98 | --task_name=$TASK_NAME \
99 | --data_dir=$DATA_DIR \
100 | --vocab_file=$PREV_TRAINED_MODEL_DIR/vocab.txt \
101 | --bert_config_file=$PREV_TRAINED_MODEL_DIR/bert_config.json \
102 | --init_checkpoint=$PREV_TRAINED_MODEL_DIR/bert_model.ckpt \
103 | --max_seq_length=$MAX_SEQ_LENGTH \
104 | --train_batch_size=$TRAIN_BATCH_SIZE \
105 | --learning_rate=$LEARNING_RATE \
106 | --num_train_epochs=$NUM_TRAIN_EPOCHS \
107 | --save_checkpoints_steps=$SAVE_CHECKPOINTS_STEPS \
108 | --output_dir=$OUTPUT_DIR \
109 | --keep_checkpoint_max=0 \
110 | "
111 | cd $CURRENT_DIR
112 | echo "Start running..."
113 | python run_classifier.py \
114 | $COMMON_ARGS \
115 | --do_train=true \
116 | --do_eval=false \
117 | --do_predict=false
118 |
119 | echo "Start predict..."
120 | python run_classifier.py \
121 | $COMMON_ARGS \
122 | --do_train=false \
123 | --do_eval=true \
124 | --do_predict=true
125 | }
126 |
127 |
128 | run_task csl RoBERTa-tiny-clue 128 16 1e-5 0 100
129 | run_task cluewsc2020 RoBERTa-tiny-clue 128 16 1e-5 0 10
130 | #run_task csl RoBERTa-tiny-clue 128 16 1e-5 5 100
131 | #run_task wsc RoBERTa-tiny-clue 128 16 1e-5 10 10
132 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tokenization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import tempfile
21 | import tokenization
22 | import six
23 | import tensorflow as tf
24 |
25 |
26 | class TokenizationTest(tf.test.TestCase):
27 |
28 | def test_full_tokenizer(self):
29 | vocab_tokens = [
30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
31 | "##ing", ","
32 | ]
33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
34 | if six.PY2:
35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
36 | else:
37 | vocab_writer.write("".join(
38 | [x + "\n" for x in vocab_tokens]).encode("utf-8"))
39 |
40 | vocab_file = vocab_writer.name
41 |
42 | tokenizer = tokenization.FullTokenizer(vocab_file)
43 | os.unlink(vocab_file)
44 |
45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
47 |
48 | self.assertAllEqual(
49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
50 |
51 | def test_chinese(self):
52 | tokenizer = tokenization.BasicTokenizer()
53 |
54 | self.assertAllEqual(
55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"),
56 | [u"ah", u"\u535A", u"\u63A8", u"zz"])
57 |
58 | def test_basic_tokenizer_lower(self):
59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
60 |
61 | self.assertAllEqual(
62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
63 | ["hello", "!", "how", "are", "you", "?"])
64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
65 |
66 | def test_basic_tokenizer_no_lower(self):
67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
68 |
69 | self.assertAllEqual(
70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
71 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
72 |
73 | def test_wordpiece_tokenizer(self):
74 | vocab_tokens = [
75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
76 | "##ing"
77 | ]
78 |
79 | vocab = {}
80 | for (i, token) in enumerate(vocab_tokens):
81 | vocab[token] = i
82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
83 |
84 | self.assertAllEqual(tokenizer.tokenize(""), [])
85 |
86 | self.assertAllEqual(
87 | tokenizer.tokenize("unwanted running"),
88 | ["un", "##want", "##ed", "runn", "##ing"])
89 |
90 | self.assertAllEqual(
91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
92 |
93 | def test_convert_tokens_to_ids(self):
94 | vocab_tokens = [
95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
96 | "##ing"
97 | ]
98 |
99 | vocab = {}
100 | for (i, token) in enumerate(vocab_tokens):
101 | vocab[token] = i
102 |
103 | self.assertAllEqual(
104 | tokenization.convert_tokens_to_ids(
105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
106 |
107 | def test_is_whitespace(self):
108 | self.assertTrue(tokenization._is_whitespace(u" "))
109 | self.assertTrue(tokenization._is_whitespace(u"\t"))
110 | self.assertTrue(tokenization._is_whitespace(u"\r"))
111 | self.assertTrue(tokenization._is_whitespace(u"\n"))
112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
113 |
114 | self.assertFalse(tokenization._is_whitespace(u"A"))
115 | self.assertFalse(tokenization._is_whitespace(u"-"))
116 |
117 | def test_is_control(self):
118 | self.assertTrue(tokenization._is_control(u"\u0005"))
119 |
120 | self.assertFalse(tokenization._is_control(u"A"))
121 | self.assertFalse(tokenization._is_control(u" "))
122 | self.assertFalse(tokenization._is_control(u"\t"))
123 | self.assertFalse(tokenization._is_control(u"\r"))
124 | self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
125 |
126 | def test_is_punctuation(self):
127 | self.assertTrue(tokenization._is_punctuation(u"-"))
128 | self.assertTrue(tokenization._is_punctuation(u"$"))
129 | self.assertTrue(tokenization._is_punctuation(u"`"))
130 | self.assertTrue(tokenization._is_punctuation(u"."))
131 |
132 | self.assertFalse(tokenization._is_punctuation(u"A"))
133 | self.assertFalse(tokenization._is_punctuation(u" "))
134 |
135 |
136 | if __name__ == "__main__":
137 | tf.test.main()
138 |
--------------------------------------------------------------------------------
/baselines/models/bert_ner/data_processor_seq.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # coding:utf8
3 | """
4 | @author: Cong Yu
5 | @time: 2019-12-07 17:03
6 | """
7 | import json
8 | import tokenization
9 | import collections
10 | import tensorflow as tf
11 |
12 |
13 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
14 | """Truncates a sequence pair in place to the maximum length."""
15 |
16 | # This is a simple heuristic which will always truncate the longer sequence
17 | # one token at a time. This makes more sense than truncating an equal percent
18 | # of tokens from each, since if one sequence is very short then each token
19 | # that's truncated likely contains more information than a longer sequence.
20 | while True:
21 | total_length = len(tokens_a) + len(tokens_b)
22 | if total_length <= max_length:
23 | break
24 | if len(tokens_a) > len(tokens_b):
25 | tokens_a.pop()
26 | else:
27 | tokens_b.pop()
28 |
29 |
30 | def process_one_example(tokenizer, label2id, text, label, max_seq_len=128):
31 | # textlist = text.split(' ')
32 | # labellist = label.split(' ')
33 | textlist = list(text)
34 | labellist = list(label)
35 | tokens = []
36 | labels = []
37 | for i, word in enumerate(textlist):
38 | token = tokenizer.tokenize(word)
39 | tokens.extend(token)
40 | label_1 = labellist[i]
41 | for m in range(len(token)):
42 | if m == 0:
43 | labels.append(label_1)
44 | else:
45 | print("some unknown token...")
46 | labels.append(labels[0])
47 |     # tokens = tokenizer.tokenize(example.text); the "-2" below leaves room for the leading [CLS] and trailing [SEP] markers
48 | if len(tokens) >= max_seq_len - 1:
49 | tokens = tokens[0:(max_seq_len - 2)]
50 | labels = labels[0:(max_seq_len - 2)]
51 | ntokens = []
52 | segment_ids = []
53 | label_ids = []
54 |     ntokens.append("[CLS]")  # [CLS] marks the start of the sequence
55 | segment_ids.append(0)
56 |     # [CLS]/[SEP] could be given their own labels or mapped to a fixed one; they never change and take almost no part in training, so their label stays constant
57 | label_ids.append(0) # label2id["[CLS]"]
58 | for i, token in enumerate(tokens):
59 | ntokens.append(token)
60 | segment_ids.append(0)
61 | label_ids.append(label2id[labels[i]])
62 | ntokens.append("[SEP]")
63 | segment_ids.append(0)
64 | # append("O") or append("[SEP]") not sure!
65 | label_ids.append(0) # label2id["[SEP]"]
66 | input_ids = tokenizer.convert_tokens_to_ids(ntokens)
67 | input_mask = [1] * len(input_ids)
68 | while len(input_ids) < max_seq_len:
69 | input_ids.append(0)
70 | input_mask.append(0)
71 | segment_ids.append(0)
72 | label_ids.append(0)
73 | ntokens.append("**NULL**")
74 | assert len(input_ids) == max_seq_len
75 | assert len(input_mask) == max_seq_len
76 | assert len(segment_ids) == max_seq_len
77 | assert len(label_ids) == max_seq_len
78 |
79 | feature = (input_ids, input_mask, segment_ids, label_ids)
80 | return feature
81 |
82 |
83 | def prepare_tf_record_data(tokenizer, max_seq_len, label2id, path, out_path):
84 | """
85 |     Generate tf.record training data for the sequence labeling task.
86 | """
87 | writer = tf.python_io.TFRecordWriter(out_path)
88 | example_count = 0
89 |
90 | for line in open(path):
91 | if not line.strip():
92 | continue
93 | _ = json.loads(line.strip())
94 | len_ = len(_["text"])
95 | labels = ["O"] * len_
96 | for k, v in _["label"].items():
97 | for kk, vv in v.items():
98 | for vvv in vv:
99 | span = vvv
100 | s = span[0]
101 | e = span[1] + 1
102 | # print(s, e)
103 | if e - s == 1:
104 | labels[s] = "S_" + k
105 | else:
106 | labels[s] = "B_" + k
107 | for i in range(s + 1, e - 1):
108 | labels[i] = "M_" + k
109 | labels[e - 1] = "E_" + k
110 | # print()
111 | # feature = process_one_example(tokenizer, label2id, row[column_name_x1], row[column_name_y],
112 | # max_seq_len=max_seq_len)
113 | feature = process_one_example(tokenizer, label2id, list(_["text"]), labels,
114 | max_seq_len=max_seq_len)
115 |
116 | def create_int_feature(values):
117 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
118 | return f
119 |
120 | features = collections.OrderedDict()
121 |         # sequence labeling task
122 | features["input_ids"] = create_int_feature(feature[0])
123 | features["input_mask"] = create_int_feature(feature[1])
124 | features["segment_ids"] = create_int_feature(feature[2])
125 | features["label_ids"] = create_int_feature(feature[3])
126 | if example_count < 5:
127 | print("*** Example ***")
128 | print(_["text"])
129 | print(_["label"])
130 | print("input_ids: %s" % " ".join([str(x) for x in feature[0]]))
131 | print("input_mask: %s" % " ".join([str(x) for x in feature[1]]))
132 | print("segment_ids: %s" % " ".join([str(x) for x in feature[2]]))
133 | print("label: %s " % " ".join([str(x) for x in feature[3]]))
134 |
135 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
136 | writer.write(tf_example.SerializeToString())
137 | example_count += 1
138 |
139 | # if example_count == 20:
140 | # break
141 | if example_count % 3000 == 0:
142 | print(example_count)
143 | print("total example:", example_count)
144 | writer.close()
145 |
146 |
147 | if __name__ == "__main__":
148 | vocab_file = "./vocab.txt"
149 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file)
150 | label2id = json.loads(open("label2id.json").read())
151 |
152 | max_seq_len = 64
153 |
154 | prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_train.json",
155 | out_path="data/train.tf_record")
156 | prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_valid.json",
157 | out_path="data/dev.tf_record")
158 |
--------------------------------------------------------------------------------
/baselines/models/bert_ner/predict_sequence_label.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # coding:utf8
3 | """
4 | @author: Cong Yu
5 | @time: 2019-12-07 20:51
6 | """
7 | import os
8 | import re
9 | import json
10 | import tensorflow as tf
11 | import tokenization
12 |
13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
14 |
15 | vocab_file = "./vocab.txt"
16 | tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file)
17 | label2id = json.loads(open("./label2id.json").read())
18 | id2label = [k for k, v in label2id.items()]
19 |
20 |
21 | def process_one_example_p(tokenizer, text, max_seq_len=128):
22 | textlist = list(text)
23 | tokens = []
24 | # labels = []
25 | for i, word in enumerate(textlist):
26 | token = tokenizer.tokenize(word)
27 | # print(token)
28 | tokens.extend(token)
29 | if len(tokens) >= max_seq_len - 1:
30 | tokens = tokens[0:(max_seq_len - 2)]
31 | # labels = labels[0:(max_seq_len - 2)]
32 | ntokens = []
33 | segment_ids = []
34 | label_ids = []
35 |     ntokens.append("[CLS]")  # [CLS] marks the start of the sequence
36 | segment_ids.append(0)
37 | for i, token in enumerate(tokens):
38 | ntokens.append(token)
39 | segment_ids.append(0)
40 | # label_ids.append(label2id[labels[i]])
41 | ntokens.append("[SEP]")
42 | segment_ids.append(0)
43 | input_ids = tokenizer.convert_tokens_to_ids(ntokens)
44 | input_mask = [1] * len(input_ids)
45 | while len(input_ids) < max_seq_len:
46 | input_ids.append(0)
47 | input_mask.append(0)
48 | segment_ids.append(0)
49 | label_ids.append(0)
50 | ntokens.append("**NULL**")
51 | assert len(input_ids) == max_seq_len
52 | assert len(input_mask) == max_seq_len
53 | assert len(segment_ids) == max_seq_len
54 |
55 | feature = (input_ids, input_mask, segment_ids)
56 | return feature
57 |
58 |
59 | def load_model(model_folder):
60 | # We retrieve our checkpoint fullpath
61 | try:
62 | checkpoint = tf.train.get_checkpoint_state(model_folder)
63 | input_checkpoint = checkpoint.model_checkpoint_path
64 | print("[INFO] input_checkpoint:", input_checkpoint)
65 | except Exception as e:
66 | input_checkpoint = model_folder
67 | print("[INFO] Model folder", model_folder, repr(e))
68 |
69 | # We clear devices to allow TensorFlow to control on which device it will load operations
70 | clear_devices = True
71 | tf.reset_default_graph()
72 | # We import the meta graph and retrieve a Saver
73 | saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
74 |
75 | # We start a session and restore the graph weights
76 | sess_ = tf.Session()
77 | saver.restore(sess_, input_checkpoint)
78 |
79 | # opts = sess_.graph.get_operations()
80 | # for v in opts:
81 | # print(v.name)
82 | return sess_
83 |
84 |
85 | model_path = "./ner_bert_base/"
86 | sess = load_model(model_path)
87 | input_ids = sess.graph.get_tensor_by_name("input_ids:0")
88 | input_mask = sess.graph.get_tensor_by_name("input_mask:0") # is_training
89 | segment_ids = sess.graph.get_tensor_by_name("segment_ids:0") # fc/dense/Relu cnn_block/Reshape
90 | keep_prob = sess.graph.get_tensor_by_name("keep_prob:0")
91 | p = sess.graph.get_tensor_by_name("loss/ReverseSequence_1:0")
92 |
93 |
94 | def predict(text):
95 | data = [text]
96 |     # batch the texts one by one; each text is truncated to at most 62 characters before prediction
97 | features = []
98 | for i in data:
99 | feature = process_one_example_p(tokenizer_, i, max_seq_len=64)
100 | features.append(feature)
101 | feed = {input_ids: [feature[0] for feature in features],
102 | input_mask: [feature[1] for feature in features],
103 | segment_ids: [feature[2] for feature in features],
104 | keep_prob: 1.0
105 | }
106 |
107 | [probs] = sess.run([p], feed)
108 | result = []
109 | for index, prob in enumerate(probs):
110 | for v in prob[1:len(data[index]) + 1]:
111 | result.append(id2label[int(v)])
112 | print(result)
113 | labels = {}
114 | start = None
115 | index = 0
116 | for w, t in zip("".join(data), result):
117 | if re.search("^[BS]", t):
118 | if start is not None:
119 | label = result[index - 1][2:]
120 | if labels.get(label):
121 | te_ = text[start:index]
122 | # print(te_, labels)
123 | labels[label][te_] = [[start, index - 1]]
124 | else:
125 | te_ = text[start:index]
126 | # print(te_, labels)
127 | labels[label] = {te_: [[start, index - 1]]}
128 | start = index
129 | # print(start)
130 | if re.search("^O", t):
131 | if start is not None:
132 | # print(start)
133 | label = result[index - 1][2:]
134 | if labels.get(label):
135 | te_ = text[start:index]
136 | # print(te_, labels)
137 | labels[label][te_] = [[start, index - 1]]
138 | else:
139 | te_ = text[start:index]
140 | # print(te_, labels)
141 | labels[label] = {te_: [[start, index - 1]]}
142 | # else:
143 | # print(start, labels)
144 | start = None
145 | index += 1
146 | if start is not None:
147 | # print(start)
148 | label = result[start][2:]
149 | if labels.get(label):
150 | te_ = text[start:index]
151 | # print(te_, labels)
152 | labels[label][te_] = [[start, index - 1]]
153 | else:
154 | te_ = text[start:index]
155 | # print(te_, labels)
156 | labels[label] = {te_: [[start, index - 1]]}
157 | # print(labels)
158 | return labels
159 |
160 |
161 | def submit(path):
162 | data = []
163 | for line in open(path):
164 | if not line.strip():
165 | continue
166 | _ = json.loads(line.strip())
167 | res = predict(_["text"])
168 | data.append(json.dumps({"label": res}, ensure_ascii=False))
169 | open("ner_predict.json", "w").write("\n".join(data))
170 |
171 |
172 | if __name__ == "__main__":
173 | text_ = "梅塔利斯在乌克兰联赛、杯赛及联盟杯中保持9场不败,状态相当出色;"
174 | res_ = predict(text_)
175 | print(res_)
176 |
177 | submit("data/thuctc_valid.json")
178 |
--------------------------------------------------------------------------------
/baselines/models/bert_ner/optimization.py:
--------------------------------------------------------------------------------
1 | """Functions and classes related to optimization (weight updates)."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import re
8 | import tensorflow as tf
9 |
10 |
11 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
12 | """Creates an optimizer training op."""
13 | global_step = tf.train.get_or_create_global_step()
14 |
15 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
16 |
17 | # Implements linear decay of the learning rate.
18 | learning_rate = tf.train.polynomial_decay(
19 | learning_rate,
20 | global_step,
21 | num_train_steps,
22 | end_learning_rate=0.0,
23 | power=1.0,
24 | cycle=False)
25 |
26 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
27 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
28 | if num_warmup_steps:
29 | global_steps_int = tf.cast(global_step, tf.int32)
30 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
31 |
32 | global_steps_float = tf.cast(global_steps_int, tf.float32)
33 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
34 |
35 | warmup_percent_done = global_steps_float / warmup_steps_float
36 | warmup_learning_rate = init_lr * warmup_percent_done
37 |
38 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
39 | learning_rate = (
40 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
41 |
42 | # It is recommended that you use this optimizer for fine tuning, since this
43 | # is how the model was trained (note that the Adam m/v variables are NOT
44 | # loaded from init_checkpoint.)
45 | optimizer = AdamWeightDecayOptimizer(
46 | learning_rate=learning_rate,
47 | weight_decay_rate=0.01,
48 | beta_1=0.9,
49 | beta_2=0.999,
50 | epsilon=1e-6,
51 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
52 |
53 | if use_tpu:
54 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
55 |
56 | tvars = tf.trainable_variables()
57 | grads = tf.gradients(loss, tvars)
58 |
59 | # This is how the model was pre-trained.
60 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
61 |
62 | train_op = optimizer.apply_gradients(
63 | zip(grads, tvars), global_step=global_step)
64 |
65 | # Normally the global step update is done inside of `apply_gradients`.
66 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
67 | # a different optimizer, you should probably take this line out.
68 | new_global_step = global_step + 1
69 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
70 | return train_op, learning_rate
71 |
72 |
73 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
74 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
75 |
76 | def __init__(self,
77 | learning_rate,
78 | weight_decay_rate=0.0,
79 | beta_1=0.9,
80 | beta_2=0.999,
81 | epsilon=1e-6,
82 | exclude_from_weight_decay=None,
83 | name="AdamWeightDecayOptimizer"):
84 |     """Constructs an AdamWeightDecayOptimizer."""
85 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
86 |
87 | self.learning_rate = learning_rate
88 | self.weight_decay_rate = weight_decay_rate
89 | self.beta_1 = beta_1
90 | self.beta_2 = beta_2
91 | self.epsilon = epsilon
92 | self.exclude_from_weight_decay = exclude_from_weight_decay
93 |
94 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
95 | """See base class."""
96 | assignments = []
97 | for (grad, param) in grads_and_vars:
98 | if grad is None or param is None:
99 | continue
100 |
101 | param_name = self._get_variable_name(param.name)
102 |
103 | m = tf.get_variable(
104 | name=param_name + "/adam_m",
105 | shape=param.shape.as_list(),
106 | dtype=tf.float32,
107 | trainable=False,
108 | initializer=tf.zeros_initializer())
109 | v = tf.get_variable(
110 | name=param_name + "/adam_v",
111 | shape=param.shape.as_list(),
112 | dtype=tf.float32,
113 | trainable=False,
114 | initializer=tf.zeros_initializer())
115 |
116 | # Standard Adam update.
117 | next_m = (
118 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
119 | next_v = (
120 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
121 | tf.square(grad)))
122 |
123 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
124 |
125 | # Just adding the square of the weights to the loss function is *not*
126 | # the correct way of using L2 regularization/weight decay with Adam,
127 | # since that will interact with the m and v parameters in strange ways.
128 | #
129 |       # Instead we want to decay the weights in a manner that doesn't interact
130 | # with the m/v parameters. This is equivalent to adding the square
131 | # of the weights to the loss with plain (non-momentum) SGD.
132 | if self._do_use_weight_decay(param_name):
133 | update += self.weight_decay_rate * param
134 |
135 | update_with_lr = self.learning_rate * update
136 |
137 | next_param = param - update_with_lr
138 |
139 | assignments.extend(
140 | [param.assign(next_param),
141 | m.assign(next_m),
142 | v.assign(next_v)])
143 | return tf.group(*assignments, name=name)
144 |
145 | def _do_use_weight_decay(self, param_name):
146 | """Whether to use L2 weight decay for `param_name`."""
147 | if not self.weight_decay_rate:
148 | return False
149 | if self.exclude_from_weight_decay:
150 | for r in self.exclude_from_weight_decay:
151 | if re.search(r, param_name) is not None:
152 | return False
153 | return True
154 |
155 | def _get_variable_name(self, param_name):
156 | """Get the variable name from the tensor name."""
157 | m = re.match("^(.*):\\d+$", param_name)
158 | if m is not None:
159 | param_name = m.group(1)
160 | return param_name
161 |
--------------------------------------------------------------------------------
/baselines/models/bert_mrc/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 | beta_2=0.999,
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 | # Normally the global step update is done inside of `apply_gradients`.
80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
81 | # a different optimizer, you should probably take this line out.
82 | new_global_step = global_step + 1
83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
84 | return train_op
85 |
86 |
87 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
88 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
89 |
90 | def __init__(self,
91 | learning_rate,
92 | weight_decay_rate=0.0,
93 | beta_1=0.9,
94 | beta_2=0.999,
95 | epsilon=1e-6,
96 | exclude_from_weight_decay=None,
97 | name="AdamWeightDecayOptimizer"):
98 |     """Constructs an AdamWeightDecayOptimizer."""
99 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
100 |
101 | self.learning_rate = learning_rate
102 | self.weight_decay_rate = weight_decay_rate
103 | self.beta_1 = beta_1
104 | self.beta_2 = beta_2
105 | self.epsilon = epsilon
106 | self.exclude_from_weight_decay = exclude_from_weight_decay
107 |
108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
109 | """See base class."""
110 | assignments = []
111 | for (grad, param) in grads_and_vars:
112 | if grad is None or param is None:
113 | continue
114 |
115 | param_name = self._get_variable_name(param.name)
116 |
117 | m = tf.get_variable(
118 | name=param_name + "/adam_m",
119 | shape=param.shape.as_list(),
120 | dtype=tf.float32,
121 | trainable=False,
122 | initializer=tf.zeros_initializer())
123 | v = tf.get_variable(
124 | name=param_name + "/adam_v",
125 | shape=param.shape.as_list(),
126 | dtype=tf.float32,
127 | trainable=False,
128 | initializer=tf.zeros_initializer())
129 |
130 | # Standard Adam update.
131 | next_m = (
132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
133 | next_v = (
134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
135 | tf.square(grad)))
136 |
137 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
138 |
139 | # Just adding the square of the weights to the loss function is *not*
140 | # the correct way of using L2 regularization/weight decay with Adam,
141 | # since that will interact with the m and v parameters in strange ways.
142 | #
143 |       # Instead we want to decay the weights in a manner that doesn't interact
144 | # with the m/v parameters. This is equivalent to adding the square
145 | # of the weights to the loss with plain (non-momentum) SGD.
146 | if self._do_use_weight_decay(param_name):
147 | update += self.weight_decay_rate * param
148 |
149 | update_with_lr = self.learning_rate * update
150 |
151 | next_param = param - update_with_lr
152 |
153 | assignments.extend(
154 | [param.assign(next_param),
155 | m.assign(next_m),
156 | v.assign(next_v)])
157 | return tf.group(*assignments, name=name)
158 |
159 | def _do_use_weight_decay(self, param_name):
160 | """Whether to use L2 weight decay for `param_name`."""
161 | if not self.weight_decay_rate:
162 | return False
163 | if self.exclude_from_weight_decay:
164 | for r in self.exclude_from_weight_decay:
165 | if re.search(r, param_name) is not None:
166 | return False
167 | return True
168 |
169 | def _get_variable_name(self, param_name):
170 | """Get the variable name from the tensor name."""
171 | m = re.match("^(.*):\\d+$", param_name)
172 | if m is not None:
173 | param_name = m.group(1)
174 | return param_name
175 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 | beta_2=0.999,
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 | # Normally the global step update is done inside of `apply_gradients`.
80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
81 | # a different optimizer, you should probably take this line out.
82 | new_global_step = global_step + 1
83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
84 | return train_op
85 |
86 |
87 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
88 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
89 |
90 | def __init__(self,
91 | learning_rate,
92 | weight_decay_rate=0.0,
93 | beta_1=0.9,
94 | beta_2=0.999,
95 | epsilon=1e-6,
96 | exclude_from_weight_decay=None,
97 | name="AdamWeightDecayOptimizer"):
98 |     """Constructs an AdamWeightDecayOptimizer."""
99 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
100 |
101 | self.learning_rate = learning_rate
102 | self.weight_decay_rate = weight_decay_rate
103 | self.beta_1 = beta_1
104 | self.beta_2 = beta_2
105 | self.epsilon = epsilon
106 | self.exclude_from_weight_decay = exclude_from_weight_decay
107 |
108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
109 | """See base class."""
110 | assignments = []
111 | for (grad, param) in grads_and_vars:
112 | if grad is None or param is None:
113 | continue
114 |
115 | param_name = self._get_variable_name(param.name)
116 |
117 | m = tf.get_variable(
118 | name=param_name + "/adam_m",
119 | shape=param.shape.as_list(),
120 | dtype=tf.float32,
121 | trainable=False,
122 | initializer=tf.zeros_initializer())
123 | v = tf.get_variable(
124 | name=param_name + "/adam_v",
125 | shape=param.shape.as_list(),
126 | dtype=tf.float32,
127 | trainable=False,
128 | initializer=tf.zeros_initializer())
129 |
130 | # Standard Adam update.
131 | next_m = (
132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
133 | next_v = (
134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
135 | tf.square(grad)))
136 |
137 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
138 |
139 | # Just adding the square of the weights to the loss function is *not*
140 | # the correct way of using L2 regularization/weight decay with Adam,
141 | # since that will interact with the m and v parameters in strange ways.
142 | #
143 |       # Instead we want to decay the weights in a manner that doesn't interact
144 | # with the m/v parameters. This is equivalent to adding the square
145 | # of the weights to the loss with plain (non-momentum) SGD.
146 | if self._do_use_weight_decay(param_name):
147 | update += self.weight_decay_rate * param
148 |
149 | update_with_lr = self.learning_rate * update
150 |
151 | next_param = param - update_with_lr
152 |
153 | assignments.extend(
154 | [param.assign(next_param),
155 | m.assign(next_m),
156 | v.assign(next_v)])
157 | return tf.group(*assignments, name=name)
158 |
159 | def _do_use_weight_decay(self, param_name):
160 | """Whether to use L2 weight decay for `param_name`."""
161 | if not self.weight_decay_rate:
162 | return False
163 | if self.exclude_from_weight_decay:
164 | for r in self.exclude_from_weight_decay:
165 | if re.search(r, param_name) is not None:
166 | return False
167 | return True
168 |
169 | def _get_variable_name(self, param_name):
170 | """Get the variable name from the tensor name."""
171 | m = re.match("^(.*):\\d+$", param_name)
172 | if m is not None:
173 | param_name = m.group(1)
174 | return param_name
175 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LightLM 高性能小模型测评
2 |
3 | NLPCC 2020 测评任务 Shared Tasks in NLPCC 2020.
4 |
5 | 高性能小模型测评 Task 1 - Light Pre-Training Chinese Language Model for NLP Task
6 |
7 | 现已可以提交到排行榜,提交样例见 example 目录。
8 |
9 | 注册地址:
10 |
11 | https://www.cluebenchmarks.com/NLPCC.html
12 |
13 | NLPCC 2020 网站:
14 |
15 | http://tcci.ccf.org.cn/conference/2020/cfpt.php
16 |
17 | 任务guideline:
18 |
19 | http://tcci.ccf.org.cn/conference/2020/dldoc/taskgline01.pdf
20 |
21 | 比赛交流群:
22 |
23 | 
24 |
25 | ## 任务介绍 Task overview
26 |
27 | **Task 1 - 高性能小模型测评 Light Pre-Training Chinese Language Model for NLP Task**
28 |
29 | 这个任务的目标是训练一个和正常大小的语言模型效果相似的轻量级的语言模型。每个提交上来的模型都会在多个不同的下游NLP任务上评估性能。我们将会综合考虑模型参数数量,模型准确率以及模型推理时间,这些将一起作为模型的评估标准。
30 |
31 | The goal of this task is to train a lightweight language model that is still as powerful as normal-sized models. Each submitted model will be tested on several different downstream NLP tasks. We will take the number of parameters, accuracy, and inference time together as the metrics to measure the performance of a model. To address the lack of large Chinese corpora, we will provide a large Chinese corpus for this task and release it to all researchers later.
32 |
33 | 为了满足很多参赛者对中文语料缺乏的情况,我们提供了目前为止最大中文语料库作为这个任务的补充资源。这些语料将会在之后公布给大家。
34 |
35 | ## **如何报名 How to Participate**
36 |
37 | 任务注册方式:
38 |
39 | (1.1) 访问[www.CLUEbenchmark.com](http://www.CLUEbenchmark.com), 右上角点击【注册】并登录。
40 |
41 | (1.2) 进入【NLPCC测评】tab页,选中【注册栏】后进行比赛注册,并点击提交。
42 |
43 | Registration online with the following steps:
44 |
45 | (1.1) Visit www.CLUEbenchmark.com, and click the button 【注册】 at the top right corner of the page. After that, please log in.
46 |
47 | (1.2) After selecting 【NLPCC测评】 in the top navigation bar, register for our task under 【比赛注册】 and click submit.
48 |
49 | # 本次测评的基线模型、代码和一键运行脚本
50 | baselines/models目录下提供了本次比赛中四个任务的预测脚本,使用方法如下:
51 | ```
52 | 可以进入baselines/models目录下
53 | 1. bert_wsc_csl目录下为WSC和CSL任务: bash run_classifier_clue.sh
54 | - 结果生成在wsc_output和csl_output目录下
55 | 2. bert_mrc 为 cmrc2018 任务baseline: bash run_cmrc.sh
56 | - 结果生成为cmrc2018_predict.json
57 | 3. bert_ner 为 CLUENER任务 baseline: bash run_ner.sh
58 | - 结果生成在ner_output下
59 |
60 | ```
61 | 只要加载相应的模型,即可进行预测。本demo中会自动使用"RoBERTa-tiny-clue"进行预测
62 |
63 | 提交指导:将以上的结果,打包为nlpcc-xxx.zip 进行提交。注意各个输出文件的名称是不能变的,要保留,不然系统无法识别。提交示例请见example。
64 |
65 | CLUENER2020、WSC、CSL: 见CLUEPretrainedModels
66 |
67 | CMRC 2018: 见CLUE
68 |
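A minimal packaging sketch for the submission step above, assuming the prediction files keep the names produced by the baseline scripts in this repo (`wsc_output/wsc_predict.json`, `cmrc2018_predict.json`, `ner_predict.json`); the CSL file under `csl_output/` and the final archive name `nlpcc-xxx.zip` are placeholders here, so check `example/nlpcc2020_baseline.zip` for the exact names the system expects:

```python
# Packaging sketch (assumption: file paths follow the baseline outputs named above;
# "nlpcc-submission.zip" stands in for the required "nlpcc-xxx.zip" archive name).
import zipfile

prediction_files = [
    "wsc_output/wsc_predict.json",  # WSC predictions
    "cmrc2018_predict.json",        # CMRC 2018 predictions
    "ner_predict.json",             # CLUENER predictions
    # add the CSL prediction file produced under csl_output/ here
]

with zipfile.ZipFile("nlpcc-submission.zip", "w", zipfile.ZIP_DEFLATED) as zf:
    for path in prediction_files:
        # keep only the file name inside the archive; file names must not be changed
        zf.write(path, arcname=path.split("/")[-1])
```
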
69 | # 任务描述 Dataset Description
70 |
71 | ##### 1. CLUENER2020 细粒度命名实体识别 [详情](https://github.com/CLUEbenchmark/CLUENER2020) [NER数据下载](http://www.cluebenchmark.com/introduce.html)
72 |
73 | ```
74 | 训练集:10748
75 | 验证集:1343
76 |
77 | 按照不同标签类别统计,训练集数据分布如下(注:一条数据中出现的所有实体都进行标注,如果一条数据出现两个地址(address)实体,那么统计地址(address)类别数据的时候,算两条数据):
78 | 【训练集】标签数据分布如下:
79 | 地址(address):2829
80 | 书名(book):1131
81 | 公司(company):2897
82 | 游戏(game):2325
83 | 政府(government):1797
84 | 电影(movie):1109
85 | 姓名(name):3661
86 | 组织机构(organization):3075
87 | 职位(position):3052
88 | 景点(scene):1462
89 |
90 | 【验证集】标签数据分布如下:
91 | 地址(address):364
92 | 书名(book):152
93 | 公司(company):366
94 | 游戏(game):287
95 | 政府(government):244
96 | 电影(movie):150
97 | 姓名(name):451
98 | 组织机构(organization):344
99 | 职位(position):425
100 | 景点(scene):199
101 | ```
102 |
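For reference, a small sketch of how one CLUENER record's span annotations expand into character-level tags, mirroring the B_/M_/E_/S_/O scheme used by `baselines/models/bert_ner/data_processor_seq.py` (the record below is abridged for illustration):

```python
# Sketch: expand CLUENER span labels into character-level B_/M_/E_/S_/O tags,
# following the same logic as data_processor_seq.py (spans are inclusive).
import json

line = ('{"text": "浙商银行企业信贷部叶老桂博士则从另一个角度对五道门槛进行了解读。", '
        '"label": {"name": {"叶老桂": [[9, 11]]}, "company": {"浙商银行": [[0, 3]]}}}')
record = json.loads(line)

tags = ["O"] * len(record["text"])
for ent_type, mentions in record["label"].items():
    for spans in mentions.values():
        for start, end in spans:
            if start == end:                  # single-character entity
                tags[start] = "S_" + ent_type
            else:
                tags[start] = "B_" + ent_type
                for i in range(start + 1, end):
                    tags[i] = "M_" + ent_type
                tags[end] = "E_" + ent_type

print(list(zip(record["text"], tags))[:12])
```
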
103 | ##### 2. CLUEWSC2020: WSC Winograd模式挑战中文版,新版2020-03-25发布 CLUEWSC2020数据集下载
104 |
105 | Winograd Scheme Challenge(WSC)是一类代词消歧的任务。新版与原CLUE项目WSC内容不同
106 |
107 | 即判断句子中的代词指代的是哪个名词。题目以真假判别的方式出现,如:
108 |
109 | 句子:这时候放在床上枕头旁边的手机响了,我感到奇怪,因为欠费已被停机两个月,现在它突然响了。需要判断“它”指代的是“床”、“枕头”,还是“手机”?
110 |
111 | 数据来源:数据由CLUE benchmark提供,从中国现当代作家文学作品中抽取,再经语言专家人工挑选、标注。
112 |
113 | 数据形式:
114 |
115 | {"target":
116 | {"span2_index": 37,
117 | "span1_index": 5,
118 | "span1_text": "床",
119 | "span2_text": "它"},
120 | "idx": 261,
121 | "label": "false",
122 | "text": "这时候放在床上枕头旁边的手机响了,我感到奇怪,因为欠费已被停机两个月,现在它突然响了。"}
123 | "true"表示代词确实是指代span1_text中的名词的,"false"代表不是。
124 |
125 | 数据集大小:
126 | - 训练集:1244
127 | - 开发集:304
128 |
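A small sketch of writing CLUEWSC2020 predictions in the submission format, one JSON object per line as in `example/cluewsc2020_predict.json`; `my_model` is a hypothetical classifier, and the id-per-line-order and JSON Lines test file layout are assumptions:

```python
# Sketch: dump WSC predictions as {"id": ..., "label": "true"/"false"} per line,
# matching example/cluewsc2020_predict.json.
# Assumptions: test.json has one example per line; ids follow the line order;
# `my_model` is a placeholder callable returning True/False for one example.
import json

def write_wsc_predictions(test_path, out_path, my_model):
    rows = []
    with open(test_path, encoding="utf-8") as f:
        for idx, raw in enumerate(f):
            example = json.loads(raw)
            label = "true" if my_model(example) else "false"
            rows.append(json.dumps({"id": idx, "label": label}, ensure_ascii=False))
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(rows))
```
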
129 | ##### 3. CSL 论文关键词识别 Keyword Recognition [详情](https://github.com/CLUEbenchmark/CLUE) [CSL数据集下载](https://storage.googleapis.com/cluebenchmark/tasks/csl_public.zip)
130 |
131 | [中文科技文献数据集(CSL)](https://github.com/P01son6415/chinese-scientific-literature-dataset)取自中文论文摘要及其关键词,论文选自部分中文社会科学和自然科学核心期刊。 使用tf-idf生成伪造关键词与论文真实关键词混合,构造摘要-关键词对,任务目标是根据摘要判断关键词是否全部为真实关键词。
132 |
133 | ```
134 | 数据量:训练集(20,000),验证集(3,000),测试集(3,000)
135 | 例子:
136 | {"id": 1, "abst": "为解决传统均匀FFT波束形成算法引起的3维声呐成像分辨率降低的问题,该文提出分区域FFT波束形成算法.远场条件下,以保证成像分辨率为约束条件,以划分数量最少为目标,采用遗传算法作为优化手段将成像区域划分为多个区域.在每个区域内选取一个波束方向,获得每一个接收阵元收到该方向回波时的解调输出,以此为原始数据在该区域内进行传统均匀FFT波束形成.对FFT计算过程进行优化,降低新算法的计算量,使其满足3维成像声呐实时性的要求.仿真与实验结果表明,采用分区域FFT波束形成算法的成像分辨率较传统均匀FFT波束形成算法有显著提高,且满足实时性要求.", "keyword": ["水声学", "FFT", "波束形成", "3维成像声呐"], "label": "1"}
137 | 每一条数据有四个属性,从前往后分别是 数据ID,论文摘要,关键词,真假标签。
138 |
139 | ```
140 |
141 |
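One plausible way to feed a CSL record to a sentence-pair classifier is to use the candidate keywords as one segment and the abstract as the other; this is only an illustration, and the baseline's own CSL processor in `run_classifier.py` may build its inputs differently:

```python
# Sketch: build a text pair from a CSL record (keywords vs. abstract).
# Illustrative choice only; the baseline's processor may differ.
import json

record = json.loads('{"id": 1, "abst": "该文提出分区域FFT波束形成算法。", '
                    '"keyword": ["水声学", "FFT", "波束形成", "3维成像声呐"], "label": "1"}')
text_a = " ".join(record["keyword"])  # candidate keywords (fake ones mixed in by tf-idf)
text_b = record["abst"]               # paper abstract (abridged here)
label = record["label"]               # "1" in the example above marks genuine keywords
print(text_a, "||", label)
```
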
142 | ##### 4.CMRC2018 简体中文阅读理解任务 Reading Comprehension for Simplified Chinese [详情](https://github.com/CLUEbenchmark/CLUE) [CMRC2018数据集下载](https://storage.googleapis.com/cluebenchmark/tasks/cmrc2018_public.zip)
143 |
144 | https://hfl-rc.github.io/cmrc2018/
145 |
146 | ```
147 | 数据量:训练集(短文数2,403,问题数10,142),试验集(短文数256,问题数1,002),开发集(短文数848,问题数3,219)
148 | 例子:
149 | {
150 | "version": "1.0",
151 | "data": [
152 | {
153 | "title": "傻钱策略",
154 | "context_id": "TRIAL_0",
155 | "context_text": "工商协进会报告,12月消费者信心上升到78.1,明显高于11月的72。另据《华尔街日报》报道,2013年是1995年以来美国股市表现最好的一年。这一年里,投资美国股市的明智做法是追着“傻钱”跑。所谓的“傻钱”策略,其实就是买入并持有美国股票这样的普通组合。这个策略要比对冲基金和其它专业投资者使用的更为复杂的投资方法效果好得多。",
156 | "qas":[
157 | {
158 | "query_id": "TRIAL_0_QUERY_0",
159 | "query_text": "什么是傻钱策略?",
160 | "answers": [
161 | "所谓的“傻钱”策略,其实就是买入并持有美国股票这样的普通组合",
162 | "其实就是买入并持有美国股票这样的普通组合",
163 | "买入并持有美国股票这样的普通组合"
164 | ]
165 | },
166 | {
167 | "query_id": "TRIAL_0_QUERY_1",
168 | "query_text": "12月的消费者信心指数是多少?",
169 | "answers": [
170 | "78.1",
171 | "78.1",
172 | "78.1"
173 | ]
174 | },
175 | {
176 | "query_id": "TRIAL_0_QUERY_2",
177 | "query_text": "消费者信心指数由什么机构发布?",
178 | "answers": [
179 | "工商协进会",
180 | "工商协进会",
181 | "工商协进会"
182 | ]
183 | }
184 | ]
185 | }
186 | ]
187 | }
188 | ```
189 |
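Judging from how `evaluate()` in `baselines/models/bert_mrc/cmrc2018_evaluate.py` indexes `prediction_file[query_id]`, the prediction file is a single JSON object mapping each query_id to its predicted answer string. A sketch of writing such a file and scoring it with the official evaluator (the file paths are placeholders):

```python
# Sketch: CMRC 2018 prediction file = {query_id: predicted answer string}.
# Paths below are placeholders; the evaluator prints AVERAGE/F1/EM/TOTAL/SKIP as JSON.
import json
import subprocess

predictions = {
    "TRIAL_0_QUERY_0": "所谓的“傻钱”策略,其实就是买入并持有美国股票这样的普通组合",
    "TRIAL_0_QUERY_1": "78.1",
    "TRIAL_0_QUERY_2": "工商协进会",
}
with open("cmrc2018_predict.json", "w", encoding="utf-8") as f:
    json.dump(predictions, f, ensure_ascii=False)

subprocess.run(["python", "cmrc2018_evaluate.py", "dev.json", "cmrc2018_predict.json"])
```
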
190 | ## 双周赛奖励
191 |
192 | 为了鼓励大家的参与,我们决定设立双周赛奖励,时间定在 4.19、5.3、5.17这三个节点。每只队伍只能领取不超过两次双周赛奖励。
193 | 奖励:
194 |
195 | 1. 1000元人民币整
196 | 2. CLUE周赛证书(可以在线访问)
197 | 3. 我们提供200G预训练语料给周赛获奖队伍
198 |
199 | 获奖要求:
200 |
201 | 1. 在4.18, 5.3, 5.16 三个时间节点早上8点前处于第一名的队伍
202 | 2. 需要在第二天的时候做一次关于小模型训练的分享,可以有选择性的分享一些比赛相关的技巧
203 |
204 | NOTE:
205 |
206 | 1. 以上两点都是获奖的必要条件
207 | 2. 如果第一名重复,则顺延
208 |
209 | ## Reference
210 |
211 | Sources we referred to for the information on this page:
212 |
213 | [1] http://tcci.ccf.org.cn/conference/2019/taskdata.php
214 |
215 | ## NOTE
216 |
217 | Any question, please contact us via CLUEbenchmark@163.com or just open an issue.
218 |
219 | ## 结果:
220 |
221 | 首先感谢大家本次的参与,我们本次比赛的评定规则如下:
222 | 1. 由于测速涉及硬件条件的限制,我们将以参赛同学自己测定的速度作为模型是否符合规范的依据。
223 | 2. 最终分数由我们实际测算的速度和参数量按照任务指南中的公式计算得出,分数可能存在误差,误差以不影响名次为可接受范围。
224 | 3. 具体名次如下:
225 | 第一名:Huawei Cloud & Noah's Ark lab,得分:0.777126778
226 | 第二名:Tencent Oteam,得分:0.768969257
227 | 第三名:Xiaomi AI Lab,得分:0.758543871
228 | 结果公示三天,如果有问题可以和我们联系。
229 | 本次比赛是CLUE第一次正式举办的对外比赛,不免有很多疏漏之处,也非常感谢得到来自各个学校、公司的各位同学大神的指导和帮助。
230 | 本次比赛的其他信息会在最终的评估报告中有详细的解释。
231 | CLUE保留对结果的最终解释权。关于奖励发放的事宜会在后续和各位获奖者沟通。
232 | 恭喜这些同学,也非常感谢各位的支持。
233 |
234 | ## **关键日期 Important dates**
235 |
236 | - [x] 2020/03/10:announcement of shared tasks and call for participation;
237 | - [x] 2020/03/10:registration open;
238 | - [x] 2020/03/25:release of detailed task guidelines & training data;
239 | - [x] 2020/04/06:release of baseline and start the competition
240 |
241 | - [x] 2020/05/01:registration deadline;
242 | - [x] 2020/05/15:release of test data;
243 | - [x] 2020/05/20:participants’ results submission deadline;
244 | - [x] 2020/05/30:evaluation results release and call for system reports and conference paper;
245 | - [ ] 2020/06/30:conference paper submission deadline (only for shared tasks);
246 | - [ ] 2020/07/30:conference paper accept/reject notification;
247 | - [ ] 2020/08/10:camera-ready paper submission deadline;
248 |
249 |
250 |
251 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tf_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Multiclass
3 | from:
4 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py
5 |
6 | """
7 |
8 | __author__ = "Guillaume Genthial"
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix
13 |
14 |
15 | def precision(labels, predictions, num_classes, pos_indices=None,
16 | weights=None, average='micro'):
17 | """Multi-class precision metric for Tensorflow
18 | Parameters
19 | ----------
20 | labels : Tensor of tf.int32 or tf.int64
21 | The true labels
22 | predictions : Tensor of tf.int32 or tf.int64
23 | The predictions, same shape as labels
24 | num_classes : int
25 | The number of classes
26 | pos_indices : list of int, optional
27 | The indices of the positive classes, default is all
28 | weights : Tensor of tf.int32, optional
29 | Mask, must be of compatible shape with labels
30 | average : str, optional
31 | 'micro': counts the total number of true positives, false
32 | positives, and false negatives for the classes in
33 | `pos_indices` and infer the metric from it.
34 | 'macro': will compute the metric separately for each class in
35 | `pos_indices` and average. Will not account for class
36 | imbalance.
37 | 'weighted': will compute the metric separately for each class in
38 | `pos_indices` and perform a weighted average by the total
39 | number of true labels for each class.
40 | Returns
41 | -------
42 | tuple of (scalar float Tensor, update_op)
43 | """
44 | cm, op = _streaming_confusion_matrix(
45 | labels, predictions, num_classes, weights)
46 | pr, _, _ = metrics_from_confusion_matrix(
47 | cm, pos_indices, average=average)
48 | op, _, _ = metrics_from_confusion_matrix(
49 | op, pos_indices, average=average)
50 | return (pr, op)
51 |
52 |
53 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None,
54 | average='micro'):
55 | """Multi-class recall metric for Tensorflow
56 | Parameters
57 | ----------
58 | labels : Tensor of tf.int32 or tf.int64
59 | The true labels
60 | predictions : Tensor of tf.int32 or tf.int64
61 | The predictions, same shape as labels
62 | num_classes : int
63 | The number of classes
64 | pos_indices : list of int, optional
65 | The indices of the positive classes, default is all
66 | weights : Tensor of tf.int32, optional
67 | Mask, must be of compatible shape with labels
68 | average : str, optional
69 | 'micro': counts the total number of true positives, false
70 | positives, and false negatives for the classes in
71 | `pos_indices` and infer the metric from it.
72 | 'macro': will compute the metric separately for each class in
73 | `pos_indices` and average. Will not account for class
74 | imbalance.
75 | 'weighted': will compute the metric separately for each class in
76 | `pos_indices` and perform a weighted average by the total
77 | number of true labels for each class.
78 | Returns
79 | -------
80 | tuple of (scalar float Tensor, update_op)
81 | """
82 | cm, op = _streaming_confusion_matrix(
83 | labels, predictions, num_classes, weights)
84 | _, re, _ = metrics_from_confusion_matrix(
85 | cm, pos_indices, average=average)
86 | _, op, _ = metrics_from_confusion_matrix(
87 | op, pos_indices, average=average)
88 | return (re, op)
89 |
90 |
91 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None,
92 | average='micro'):
93 | return fbeta(labels, predictions, num_classes, pos_indices, weights,
94 | average)
95 |
96 |
97 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None,
98 | average='micro', beta=1):
99 | """Multi-class fbeta metric for Tensorflow
100 | Parameters
101 | ----------
102 | labels : Tensor of tf.int32 or tf.int64
103 | The true labels
104 | predictions : Tensor of tf.int32 or tf.int64
105 | The predictions, same shape as labels
106 | num_classes : int
107 | The number of classes
108 | pos_indices : list of int, optional
109 | The indices of the positive classes, default is all
110 | weights : Tensor of tf.int32, optional
111 | Mask, must be of compatible shape with labels
112 | average : str, optional
113 | 'micro': counts the total number of true positives, false
114 | positives, and false negatives for the classes in
115 | `pos_indices` and infer the metric from it.
116 | 'macro': will compute the metric separately for each class in
117 | `pos_indices` and average. Will not account for class
118 | imbalance.
119 | 'weighted': will compute the metric separately for each class in
120 | `pos_indices` and perform a weighted average by the total
121 | number of true labels for each class.
122 | beta : int, optional
123 | Weight of precision in harmonic mean
124 | Returns
125 | -------
126 | tuple of (scalar float Tensor, update_op)
127 | """
128 | cm, op = _streaming_confusion_matrix(
129 | labels, predictions, num_classes, weights)
130 | _, _, fbeta = metrics_from_confusion_matrix(
131 | cm, pos_indices, average=average, beta=beta)
132 | _, _, op = metrics_from_confusion_matrix(
133 | op, pos_indices, average=average, beta=beta)
134 | return (fbeta, op)
135 |
136 |
137 | def safe_div(numerator, denominator):
138 | """Safe division, return 0 if denominator is 0"""
139 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator)
140 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype)
141 | denominator_is_zero = tf.equal(denominator, zeros)
142 | return tf.where(denominator_is_zero, zeros, numerator / denominator)
143 |
144 |
145 | def pr_re_fbeta(cm, pos_indices, beta=1):
146 | """Uses a confusion matrix to compute precision, recall and fbeta"""
147 | num_classes = cm.shape[0]
148 | neg_indices = [i for i in range(num_classes) if i not in pos_indices]
149 | cm_mask = np.ones([num_classes, num_classes])
150 | cm_mask[neg_indices, neg_indices] = 0
151 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask))
152 |
153 | cm_mask = np.ones([num_classes, num_classes])
154 | cm_mask[:, neg_indices] = 0
155 | tot_pred = tf.reduce_sum(cm * cm_mask)
156 |
157 | cm_mask = np.ones([num_classes, num_classes])
158 | cm_mask[neg_indices, :] = 0
159 | tot_gold = tf.reduce_sum(cm * cm_mask)
160 |
161 | pr = safe_div(diag_sum, tot_pred)
162 | re = safe_div(diag_sum, tot_gold)
163 | fbeta = safe_div((1. + beta**2) * pr * re, beta**2 * pr + re)
164 |
165 | return pr, re, fbeta
166 |
167 |
168 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro',
169 | beta=1):
170 | """Precision, Recall and F1 from the confusion matrix
171 | Parameters
172 | ----------
173 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes)
174 | The streaming confusion matrix.
175 | pos_indices : list of int, optional
176 | The indices of the positive classes
177 | beta : int, optional
178 | Weight of precision in harmonic mean
179 | average : str, optional
180 | 'micro', 'macro' or 'weighted'
181 | """
182 | num_classes = cm.shape[0]
183 | if pos_indices is None:
184 | pos_indices = [i for i in range(num_classes)]
185 |
186 | if average == 'micro':
187 | return pr_re_fbeta(cm, pos_indices, beta)
188 | elif average in {'macro', 'weighted'}:
189 | precisions, recalls, fbetas, n_golds = [], [], [], []
190 | for idx in pos_indices:
191 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta)
192 | precisions.append(pr)
193 | recalls.append(re)
194 | fbetas.append(fbeta)
195 | cm_mask = np.zeros([num_classes, num_classes])
196 | cm_mask[idx, :] = 1
197 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask)))
198 |
199 | if average == 'macro':
200 | pr = tf.reduce_mean(precisions)
201 | re = tf.reduce_mean(recalls)
202 | fbeta = tf.reduce_mean(fbetas)
203 | return pr, re, fbeta
204 | if average == 'weighted':
205 | n_gold = tf.reduce_sum(n_golds)
206 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds))
207 | pr = safe_div(pr_sum, n_gold)
208 | re_sum = sum(r * n for r, n in zip(recalls, n_golds))
209 | re = safe_div(re_sum, n_gold)
210 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds))
211 | fbeta = safe_div(fbeta_sum, n_gold)
212 | return pr, re, fbeta
213 |
214 | else:
215 | raise NotImplementedError()
--------------------------------------------------------------------------------
/example/cluewsc2020_predict.json:
--------------------------------------------------------------------------------
1 | {"id": 0, "label": "false"}
2 | {"id": 1, "label": "false"}
3 | {"id": 2, "label": "false"}
4 | {"id": 3, "label": "false"}
5 | {"id": 4, "label": "false"}
6 | {"id": 5, "label": "false"}
7 | {"id": 6, "label": "false"}
8 | {"id": 7, "label": "false"}
9 | {"id": 8, "label": "false"}
10 | {"id": 9, "label": "false"}
11 | {"id": 10, "label": "false"}
12 | {"id": 11, "label": "false"}
13 | {"id": 12, "label": "false"}
14 | {"id": 13, "label": "false"}
15 | {"id": 14, "label": "false"}
16 | {"id": 15, "label": "false"}
17 | {"id": 16, "label": "false"}
18 | {"id": 17, "label": "false"}
19 | {"id": 18, "label": "false"}
20 | {"id": 19, "label": "false"}
21 | {"id": 20, "label": "false"}
22 | {"id": 21, "label": "false"}
23 | {"id": 22, "label": "false"}
24 | {"id": 23, "label": "false"}
25 | {"id": 24, "label": "false"}
26 | {"id": 25, "label": "false"}
27 | {"id": 26, "label": "false"}
28 | {"id": 27, "label": "false"}
29 | {"id": 28, "label": "false"}
30 | {"id": 29, "label": "false"}
31 | {"id": 30, "label": "false"}
32 | {"id": 31, "label": "false"}
33 | {"id": 32, "label": "false"}
34 | {"id": 33, "label": "false"}
35 | {"id": 34, "label": "false"}
36 | {"id": 35, "label": "false"}
37 | {"id": 36, "label": "false"}
38 | {"id": 37, "label": "false"}
39 | {"id": 38, "label": "false"}
40 | {"id": 39, "label": "false"}
41 | {"id": 40, "label": "false"}
42 | {"id": 41, "label": "false"}
43 | {"id": 42, "label": "false"}
44 | {"id": 43, "label": "false"}
45 | {"id": 44, "label": "false"}
46 | {"id": 45, "label": "false"}
47 | {"id": 46, "label": "false"}
48 | {"id": 47, "label": "false"}
49 | {"id": 48, "label": "false"}
50 | {"id": 49, "label": "false"}
51 | {"id": 50, "label": "false"}
52 | {"id": 51, "label": "false"}
53 | {"id": 52, "label": "false"}
54 | {"id": 53, "label": "false"}
55 | {"id": 54, "label": "false"}
56 | {"id": 55, "label": "false"}
57 | {"id": 56, "label": "false"}
58 | {"id": 57, "label": "false"}
59 | {"id": 58, "label": "false"}
60 | {"id": 59, "label": "false"}
61 | {"id": 60, "label": "false"}
62 | {"id": 61, "label": "false"}
63 | {"id": 62, "label": "false"}
64 | {"id": 63, "label": "false"}
65 | {"id": 64, "label": "false"}
66 | {"id": 65, "label": "false"}
67 | {"id": 66, "label": "false"}
68 | {"id": 67, "label": "false"}
69 | {"id": 68, "label": "false"}
70 | {"id": 69, "label": "false"}
71 | {"id": 70, "label": "false"}
72 | {"id": 71, "label": "false"}
73 | {"id": 72, "label": "false"}
74 | {"id": 73, "label": "false"}
75 | {"id": 74, "label": "false"}
76 | {"id": 75, "label": "false"}
77 | {"id": 76, "label": "false"}
78 | {"id": 77, "label": "false"}
79 | {"id": 78, "label": "false"}
80 | {"id": 79, "label": "false"}
81 | {"id": 80, "label": "false"}
82 | {"id": 81, "label": "false"}
83 | {"id": 82, "label": "false"}
84 | {"id": 83, "label": "false"}
85 | {"id": 84, "label": "false"}
86 | {"id": 85, "label": "false"}
87 | {"id": 86, "label": "false"}
88 | {"id": 87, "label": "false"}
89 | {"id": 88, "label": "false"}
90 | {"id": 89, "label": "false"}
91 | {"id": 90, "label": "false"}
92 | {"id": 91, "label": "false"}
93 | {"id": 92, "label": "false"}
94 | {"id": 93, "label": "false"}
95 | {"id": 94, "label": "false"}
96 | {"id": 95, "label": "false"}
97 | {"id": 96, "label": "false"}
98 | {"id": 97, "label": "false"}
99 | {"id": 98, "label": "false"}
100 | {"id": 99, "label": "false"}
101 | {"id": 100, "label": "false"}
102 | {"id": 101, "label": "false"}
103 | {"id": 102, "label": "false"}
104 | {"id": 103, "label": "false"}
105 | {"id": 104, "label": "false"}
106 | {"id": 105, "label": "false"}
107 | {"id": 106, "label": "false"}
108 | {"id": 107, "label": "false"}
109 | {"id": 108, "label": "false"}
110 | {"id": 109, "label": "false"}
111 | {"id": 110, "label": "false"}
112 | {"id": 111, "label": "false"}
113 | {"id": 112, "label": "false"}
114 | {"id": 113, "label": "false"}
115 | {"id": 114, "label": "false"}
116 | {"id": 115, "label": "false"}
117 | {"id": 116, "label": "false"}
118 | {"id": 117, "label": "false"}
119 | {"id": 118, "label": "false"}
120 | {"id": 119, "label": "false"}
121 | {"id": 120, "label": "false"}
122 | {"id": 121, "label": "false"}
123 | {"id": 122, "label": "false"}
124 | {"id": 123, "label": "false"}
125 | {"id": 124, "label": "false"}
126 | {"id": 125, "label": "false"}
127 | {"id": 126, "label": "false"}
128 | {"id": 127, "label": "false"}
129 | {"id": 128, "label": "false"}
130 | {"id": 129, "label": "false"}
131 | {"id": 130, "label": "false"}
132 | {"id": 131, "label": "false"}
133 | {"id": 132, "label": "false"}
134 | {"id": 133, "label": "false"}
135 | {"id": 134, "label": "false"}
136 | {"id": 135, "label": "false"}
137 | {"id": 136, "label": "false"}
138 | {"id": 137, "label": "false"}
139 | {"id": 138, "label": "false"}
140 | {"id": 139, "label": "false"}
141 | {"id": 140, "label": "false"}
142 | {"id": 141, "label": "false"}
143 | {"id": 142, "label": "false"}
144 | {"id": 143, "label": "false"}
145 | {"id": 144, "label": "false"}
146 | {"id": 145, "label": "false"}
147 | {"id": 146, "label": "false"}
148 | {"id": 147, "label": "false"}
149 | {"id": 148, "label": "false"}
150 | {"id": 149, "label": "false"}
151 | {"id": 150, "label": "false"}
152 | {"id": 151, "label": "false"}
153 | {"id": 152, "label": "false"}
154 | {"id": 153, "label": "false"}
155 | {"id": 154, "label": "false"}
156 | {"id": 155, "label": "false"}
157 | {"id": 156, "label": "false"}
158 | {"id": 157, "label": "false"}
159 | {"id": 158, "label": "false"}
160 | {"id": 159, "label": "false"}
161 | {"id": 160, "label": "false"}
162 | {"id": 161, "label": "false"}
163 | {"id": 162, "label": "false"}
164 | {"id": 163, "label": "false"}
165 | {"id": 164, "label": "false"}
166 | {"id": 165, "label": "false"}
167 | {"id": 166, "label": "false"}
168 | {"id": 167, "label": "false"}
169 | {"id": 168, "label": "false"}
170 | {"id": 169, "label": "false"}
171 | {"id": 170, "label": "false"}
172 | {"id": 171, "label": "false"}
173 | {"id": 172, "label": "false"}
174 | {"id": 173, "label": "false"}
175 | {"id": 174, "label": "false"}
176 | {"id": 175, "label": "false"}
177 | {"id": 176, "label": "false"}
178 | {"id": 177, "label": "false"}
179 | {"id": 178, "label": "false"}
180 | {"id": 179, "label": "false"}
181 | {"id": 180, "label": "false"}
182 | {"id": 181, "label": "false"}
183 | {"id": 182, "label": "false"}
184 | {"id": 183, "label": "false"}
185 | {"id": 184, "label": "false"}
186 | {"id": 185, "label": "false"}
187 | {"id": 186, "label": "false"}
188 | {"id": 187, "label": "true"}
189 | {"id": 188, "label": "true"}
190 | {"id": 189, "label": "false"}
191 | {"id": 190, "label": "false"}
192 | {"id": 191, "label": "false"}
193 | {"id": 192, "label": "false"}
194 | {"id": 193, "label": "false"}
195 | {"id": 194, "label": "false"}
196 | {"id": 195, "label": "false"}
197 | {"id": 196, "label": "false"}
198 | {"id": 197, "label": "false"}
199 | {"id": 198, "label": "false"}
200 | {"id": 199, "label": "false"}
201 | {"id": 200, "label": "false"}
202 | {"id": 201, "label": "false"}
203 | {"id": 202, "label": "false"}
204 | {"id": 203, "label": "false"}
205 | {"id": 204, "label": "false"}
206 | {"id": 205, "label": "false"}
207 | {"id": 206, "label": "false"}
208 | {"id": 207, "label": "false"}
209 | {"id": 208, "label": "false"}
210 | {"id": 209, "label": "false"}
211 | {"id": 210, "label": "false"}
212 | {"id": 211, "label": "false"}
213 | {"id": 212, "label": "false"}
214 | {"id": 213, "label": "false"}
215 | {"id": 214, "label": "false"}
216 | {"id": 215, "label": "false"}
217 | {"id": 216, "label": "false"}
218 | {"id": 217, "label": "false"}
219 | {"id": 218, "label": "false"}
220 | {"id": 219, "label": "false"}
221 | {"id": 220, "label": "false"}
222 | {"id": 221, "label": "false"}
223 | {"id": 222, "label": "false"}
224 | {"id": 223, "label": "false"}
225 | {"id": 224, "label": "false"}
226 | {"id": 225, "label": "false"}
227 | {"id": 226, "label": "false"}
228 | {"id": 227, "label": "false"}
229 | {"id": 228, "label": "false"}
230 | {"id": 229, "label": "false"}
231 | {"id": 230, "label": "false"}
232 | {"id": 231, "label": "false"}
233 | {"id": 232, "label": "false"}
234 | {"id": 233, "label": "false"}
235 | {"id": 234, "label": "false"}
236 | {"id": 235, "label": "false"}
237 | {"id": 236, "label": "false"}
238 | {"id": 237, "label": "false"}
239 | {"id": 238, "label": "false"}
240 | {"id": 239, "label": "false"}
241 | {"id": 240, "label": "false"}
242 | {"id": 241, "label": "false"}
243 | {"id": 242, "label": "false"}
244 | {"id": 243, "label": "false"}
245 | {"id": 244, "label": "false"}
246 | {"id": 245, "label": "false"}
247 | {"id": 246, "label": "false"}
248 | {"id": 247, "label": "false"}
249 | {"id": 248, "label": "false"}
250 | {"id": 249, "label": "false"}
251 | {"id": 250, "label": "false"}
252 | {"id": 251, "label": "false"}
253 | {"id": 252, "label": "false"}
254 | {"id": 253, "label": "false"}
255 | {"id": 254, "label": "false"}
256 | {"id": 255, "label": "false"}
257 | {"id": 256, "label": "false"}
258 | {"id": 257, "label": "false"}
259 | {"id": 258, "label": "false"}
260 | {"id": 259, "label": "false"}
261 | {"id": 260, "label": "false"}
262 | {"id": 261, "label": "false"}
263 | {"id": 262, "label": "false"}
264 | {"id": 263, "label": "false"}
265 | {"id": 264, "label": "false"}
266 | {"id": 265, "label": "false"}
267 | {"id": 266, "label": "false"}
268 | {"id": 267, "label": "false"}
269 | {"id": 268, "label": "false"}
270 | {"id": 269, "label": "false"}
271 | {"id": 270, "label": "false"}
272 | {"id": 271, "label": "false"}
273 | {"id": 272, "label": "false"}
274 | {"id": 273, "label": "false"}
275 | {"id": 274, "label": "false"}
276 | {"id": 275, "label": "false"}
277 | {"id": 276, "label": "false"}
278 | {"id": 277, "label": "false"}
279 | {"id": 278, "label": "false"}
280 | {"id": 279, "label": "false"}
281 | {"id": 280, "label": "false"}
282 | {"id": 281, "label": "false"}
283 | {"id": 282, "label": "false"}
284 | {"id": 283, "label": "false"}
285 | {"id": 284, "label": "false"}
286 | {"id": 285, "label": "false"}
287 | {"id": 286, "label": "false"}
288 | {"id": 287, "label": "false"}
289 | {"id": 288, "label": "false"}
290 | {"id": 289, "label": "false"}
291 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/wsc_output/wsc_predict.json:
--------------------------------------------------------------------------------
1 | {"id": 0, "label": "false"}
2 | {"id": 1, "label": "false"}
3 | {"id": 2, "label": "false"}
4 | {"id": 3, "label": "false"}
5 | {"id": 4, "label": "false"}
6 | {"id": 5, "label": "false"}
7 | {"id": 6, "label": "false"}
8 | {"id": 7, "label": "false"}
9 | {"id": 8, "label": "false"}
10 | {"id": 9, "label": "false"}
11 | {"id": 10, "label": "false"}
12 | {"id": 11, "label": "false"}
13 | {"id": 12, "label": "false"}
14 | {"id": 13, "label": "false"}
15 | {"id": 14, "label": "false"}
16 | {"id": 15, "label": "false"}
17 | {"id": 16, "label": "false"}
18 | {"id": 17, "label": "false"}
19 | {"id": 18, "label": "false"}
20 | {"id": 19, "label": "false"}
21 | {"id": 20, "label": "false"}
22 | {"id": 21, "label": "false"}
23 | {"id": 22, "label": "false"}
24 | {"id": 23, "label": "false"}
25 | {"id": 24, "label": "false"}
26 | {"id": 25, "label": "false"}
27 | {"id": 26, "label": "false"}
28 | {"id": 27, "label": "false"}
29 | {"id": 28, "label": "false"}
30 | {"id": 29, "label": "false"}
31 | {"id": 30, "label": "false"}
32 | {"id": 31, "label": "false"}
33 | {"id": 32, "label": "false"}
34 | {"id": 33, "label": "false"}
35 | {"id": 34, "label": "false"}
36 | {"id": 35, "label": "false"}
37 | {"id": 36, "label": "false"}
38 | {"id": 37, "label": "false"}
39 | {"id": 38, "label": "false"}
40 | {"id": 39, "label": "false"}
41 | {"id": 40, "label": "false"}
42 | {"id": 41, "label": "false"}
43 | {"id": 42, "label": "false"}
44 | {"id": 43, "label": "false"}
45 | {"id": 44, "label": "false"}
46 | {"id": 45, "label": "false"}
47 | {"id": 46, "label": "false"}
48 | {"id": 47, "label": "false"}
49 | {"id": 48, "label": "false"}
50 | {"id": 49, "label": "false"}
51 | {"id": 50, "label": "false"}
52 | {"id": 51, "label": "false"}
53 | {"id": 52, "label": "false"}
54 | {"id": 53, "label": "false"}
55 | {"id": 54, "label": "false"}
56 | {"id": 55, "label": "false"}
57 | {"id": 56, "label": "false"}
58 | {"id": 57, "label": "false"}
59 | {"id": 58, "label": "false"}
60 | {"id": 59, "label": "false"}
61 | {"id": 60, "label": "false"}
62 | {"id": 61, "label": "false"}
63 | {"id": 62, "label": "false"}
64 | {"id": 63, "label": "false"}
65 | {"id": 64, "label": "false"}
66 | {"id": 65, "label": "false"}
67 | {"id": 66, "label": "false"}
68 | {"id": 67, "label": "false"}
69 | {"id": 68, "label": "false"}
70 | {"id": 69, "label": "false"}
71 | {"id": 70, "label": "false"}
72 | {"id": 71, "label": "false"}
73 | {"id": 72, "label": "false"}
74 | {"id": 73, "label": "false"}
75 | {"id": 74, "label": "false"}
76 | {"id": 75, "label": "false"}
77 | {"id": 76, "label": "false"}
78 | {"id": 77, "label": "false"}
79 | {"id": 78, "label": "false"}
80 | {"id": 79, "label": "false"}
81 | {"id": 80, "label": "false"}
82 | {"id": 81, "label": "false"}
83 | {"id": 82, "label": "false"}
84 | {"id": 83, "label": "false"}
85 | {"id": 84, "label": "false"}
86 | {"id": 85, "label": "false"}
87 | {"id": 86, "label": "false"}
88 | {"id": 87, "label": "false"}
89 | {"id": 88, "label": "false"}
90 | {"id": 89, "label": "false"}
91 | {"id": 90, "label": "false"}
92 | {"id": 91, "label": "false"}
93 | {"id": 92, "label": "false"}
94 | {"id": 93, "label": "false"}
95 | {"id": 94, "label": "false"}
96 | {"id": 95, "label": "false"}
97 | {"id": 96, "label": "false"}
98 | {"id": 97, "label": "false"}
99 | {"id": 98, "label": "false"}
100 | {"id": 99, "label": "false"}
101 | {"id": 100, "label": "false"}
102 | {"id": 101, "label": "false"}
103 | {"id": 102, "label": "false"}
104 | {"id": 103, "label": "false"}
105 | {"id": 104, "label": "false"}
106 | {"id": 105, "label": "false"}
107 | {"id": 106, "label": "false"}
108 | {"id": 107, "label": "false"}
109 | {"id": 108, "label": "false"}
110 | {"id": 109, "label": "false"}
111 | {"id": 110, "label": "false"}
112 | {"id": 111, "label": "false"}
113 | {"id": 112, "label": "false"}
114 | {"id": 113, "label": "false"}
115 | {"id": 114, "label": "false"}
116 | {"id": 115, "label": "false"}
117 | {"id": 116, "label": "false"}
118 | {"id": 117, "label": "false"}
119 | {"id": 118, "label": "false"}
120 | {"id": 119, "label": "false"}
121 | {"id": 120, "label": "false"}
122 | {"id": 121, "label": "false"}
123 | {"id": 122, "label": "false"}
124 | {"id": 123, "label": "false"}
125 | {"id": 124, "label": "false"}
126 | {"id": 125, "label": "false"}
127 | {"id": 126, "label": "false"}
128 | {"id": 127, "label": "false"}
129 | {"id": 128, "label": "false"}
130 | {"id": 129, "label": "false"}
131 | {"id": 130, "label": "false"}
132 | {"id": 131, "label": "false"}
133 | {"id": 132, "label": "false"}
134 | {"id": 133, "label": "false"}
135 | {"id": 134, "label": "false"}
136 | {"id": 135, "label": "false"}
137 | {"id": 136, "label": "false"}
138 | {"id": 137, "label": "false"}
139 | {"id": 138, "label": "false"}
140 | {"id": 139, "label": "false"}
141 | {"id": 140, "label": "false"}
142 | {"id": 141, "label": "false"}
143 | {"id": 142, "label": "false"}
144 | {"id": 143, "label": "false"}
145 | {"id": 144, "label": "false"}
146 | {"id": 145, "label": "false"}
147 | {"id": 146, "label": "false"}
148 | {"id": 147, "label": "false"}
149 | {"id": 148, "label": "false"}
150 | {"id": 149, "label": "false"}
151 | {"id": 150, "label": "false"}
152 | {"id": 151, "label": "false"}
153 | {"id": 152, "label": "false"}
154 | {"id": 153, "label": "false"}
155 | {"id": 154, "label": "false"}
156 | {"id": 155, "label": "false"}
157 | {"id": 156, "label": "false"}
158 | {"id": 157, "label": "false"}
159 | {"id": 158, "label": "false"}
160 | {"id": 159, "label": "false"}
161 | {"id": 160, "label": "false"}
162 | {"id": 161, "label": "false"}
163 | {"id": 162, "label": "false"}
164 | {"id": 163, "label": "false"}
165 | {"id": 164, "label": "false"}
166 | {"id": 165, "label": "false"}
167 | {"id": 166, "label": "false"}
168 | {"id": 167, "label": "false"}
169 | {"id": 168, "label": "false"}
170 | {"id": 169, "label": "false"}
171 | {"id": 170, "label": "false"}
172 | {"id": 171, "label": "false"}
173 | {"id": 172, "label": "false"}
174 | {"id": 173, "label": "false"}
175 | {"id": 174, "label": "false"}
176 | {"id": 175, "label": "false"}
177 | {"id": 176, "label": "false"}
178 | {"id": 177, "label": "false"}
179 | {"id": 178, "label": "false"}
180 | {"id": 179, "label": "false"}
181 | {"id": 180, "label": "false"}
182 | {"id": 181, "label": "false"}
183 | {"id": 182, "label": "false"}
184 | {"id": 183, "label": "false"}
185 | {"id": 184, "label": "false"}
186 | {"id": 185, "label": "false"}
187 | {"id": 186, "label": "false"}
188 | {"id": 187, "label": "true"}
189 | {"id": 188, "label": "true"}
190 | {"id": 189, "label": "false"}
191 | {"id": 190, "label": "false"}
192 | {"id": 191, "label": "false"}
193 | {"id": 192, "label": "false"}
194 | {"id": 193, "label": "false"}
195 | {"id": 194, "label": "false"}
196 | {"id": 195, "label": "false"}
197 | {"id": 196, "label": "false"}
198 | {"id": 197, "label": "false"}
199 | {"id": 198, "label": "false"}
200 | {"id": 199, "label": "false"}
201 | {"id": 200, "label": "false"}
202 | {"id": 201, "label": "false"}
203 | {"id": 202, "label": "false"}
204 | {"id": 203, "label": "false"}
205 | {"id": 204, "label": "false"}
206 | {"id": 205, "label": "false"}
207 | {"id": 206, "label": "false"}
208 | {"id": 207, "label": "false"}
209 | {"id": 208, "label": "false"}
210 | {"id": 209, "label": "false"}
211 | {"id": 210, "label": "false"}
212 | {"id": 211, "label": "false"}
213 | {"id": 212, "label": "false"}
214 | {"id": 213, "label": "false"}
215 | {"id": 214, "label": "false"}
216 | {"id": 215, "label": "false"}
217 | {"id": 216, "label": "false"}
218 | {"id": 217, "label": "false"}
219 | {"id": 218, "label": "false"}
220 | {"id": 219, "label": "false"}
221 | {"id": 220, "label": "false"}
222 | {"id": 221, "label": "false"}
223 | {"id": 222, "label": "false"}
224 | {"id": 223, "label": "false"}
225 | {"id": 224, "label": "false"}
226 | {"id": 225, "label": "false"}
227 | {"id": 226, "label": "false"}
228 | {"id": 227, "label": "false"}
229 | {"id": 228, "label": "false"}
230 | {"id": 229, "label": "false"}
231 | {"id": 230, "label": "false"}
232 | {"id": 231, "label": "false"}
233 | {"id": 232, "label": "false"}
234 | {"id": 233, "label": "false"}
235 | {"id": 234, "label": "false"}
236 | {"id": 235, "label": "false"}
237 | {"id": 236, "label": "false"}
238 | {"id": 237, "label": "false"}
239 | {"id": 238, "label": "false"}
240 | {"id": 239, "label": "false"}
241 | {"id": 240, "label": "false"}
242 | {"id": 241, "label": "false"}
243 | {"id": 242, "label": "false"}
244 | {"id": 243, "label": "false"}
245 | {"id": 244, "label": "false"}
246 | {"id": 245, "label": "false"}
247 | {"id": 246, "label": "false"}
248 | {"id": 247, "label": "false"}
249 | {"id": 248, "label": "false"}
250 | {"id": 249, "label": "false"}
251 | {"id": 250, "label": "false"}
252 | {"id": 251, "label": "false"}
253 | {"id": 252, "label": "false"}
254 | {"id": 253, "label": "false"}
255 | {"id": 254, "label": "false"}
256 | {"id": 255, "label": "false"}
257 | {"id": 256, "label": "false"}
258 | {"id": 257, "label": "false"}
259 | {"id": 258, "label": "false"}
260 | {"id": 259, "label": "false"}
261 | {"id": 260, "label": "false"}
262 | {"id": 261, "label": "false"}
263 | {"id": 262, "label": "false"}
264 | {"id": 263, "label": "false"}
265 | {"id": 264, "label": "false"}
266 | {"id": 265, "label": "false"}
267 | {"id": 266, "label": "false"}
268 | {"id": 267, "label": "false"}
269 | {"id": 268, "label": "false"}
270 | {"id": 269, "label": "false"}
271 | {"id": 270, "label": "false"}
272 | {"id": 271, "label": "false"}
273 | {"id": 272, "label": "false"}
274 | {"id": 273, "label": "false"}
275 | {"id": 274, "label": "false"}
276 | {"id": 275, "label": "false"}
277 | {"id": 276, "label": "false"}
278 | {"id": 277, "label": "false"}
279 | {"id": 278, "label": "false"}
280 | {"id": 279, "label": "false"}
281 | {"id": 280, "label": "false"}
282 | {"id": 281, "label": "false"}
283 | {"id": 282, "label": "false"}
284 | {"id": 283, "label": "false"}
285 | {"id": 284, "label": "false"}
286 | {"id": 285, "label": "false"}
287 | {"id": 286, "label": "false"}
288 | {"id": 287, "label": "false"}
289 | {"id": 288, "label": "false"}
290 | {"id": 289, "label": "false"}
291 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/modeling_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import json
21 | import random
22 | import re
23 |
24 | import modeling
25 | import six
26 | import tensorflow as tf
27 |
28 |
29 | class BertModelTest(tf.test.TestCase):
30 |
31 | class BertModelTester(object):
32 |
33 | def __init__(self,
34 | parent,
35 | batch_size=13,
36 | seq_length=7,
37 | is_training=True,
38 | use_input_mask=True,
39 | use_token_type_ids=True,
40 | vocab_size=99,
41 | hidden_size=32,
42 | num_hidden_layers=5,
43 | num_attention_heads=4,
44 | intermediate_size=37,
45 | hidden_act="gelu",
46 | hidden_dropout_prob=0.1,
47 | attention_probs_dropout_prob=0.1,
48 | max_position_embeddings=512,
49 | type_vocab_size=16,
50 | initializer_range=0.02,
51 | scope=None):
52 | self.parent = parent
53 | self.batch_size = batch_size
54 | self.seq_length = seq_length
55 | self.is_training = is_training
56 | self.use_input_mask = use_input_mask
57 | self.use_token_type_ids = use_token_type_ids
58 | self.vocab_size = vocab_size
59 | self.hidden_size = hidden_size
60 | self.num_hidden_layers = num_hidden_layers
61 | self.num_attention_heads = num_attention_heads
62 | self.intermediate_size = intermediate_size
63 | self.hidden_act = hidden_act
64 | self.hidden_dropout_prob = hidden_dropout_prob
65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
66 | self.max_position_embeddings = max_position_embeddings
67 | self.type_vocab_size = type_vocab_size
68 | self.initializer_range = initializer_range
69 | self.scope = scope
70 |
71 | def create_model(self):
72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
73 | self.vocab_size)
74 |
75 | input_mask = None
76 | if self.use_input_mask:
77 | input_mask = BertModelTest.ids_tensor(
78 | [self.batch_size, self.seq_length], vocab_size=2)
79 |
80 | token_type_ids = None
81 | if self.use_token_type_ids:
82 | token_type_ids = BertModelTest.ids_tensor(
83 | [self.batch_size, self.seq_length], self.type_vocab_size)
84 |
85 | config = modeling.BertConfig(
86 | vocab_size=self.vocab_size,
87 | hidden_size=self.hidden_size,
88 | num_hidden_layers=self.num_hidden_layers,
89 | num_attention_heads=self.num_attention_heads,
90 | intermediate_size=self.intermediate_size,
91 | hidden_act=self.hidden_act,
92 | hidden_dropout_prob=self.hidden_dropout_prob,
93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob,
94 | max_position_embeddings=self.max_position_embeddings,
95 | type_vocab_size=self.type_vocab_size,
96 | initializer_range=self.initializer_range)
97 |
98 | model = modeling.BertModel(
99 | config=config,
100 | is_training=self.is_training,
101 | input_ids=input_ids,
102 | input_mask=input_mask,
103 | token_type_ids=token_type_ids,
104 | scope=self.scope)
105 |
106 | outputs = {
107 | "embedding_output": model.get_embedding_output(),
108 | "sequence_output": model.get_sequence_output(),
109 | "pooled_output": model.get_pooled_output(),
110 | "all_encoder_layers": model.get_all_encoder_layers(),
111 | }
112 | return outputs
113 |
114 | def check_output(self, result):
115 | self.parent.assertAllEqual(
116 | result["embedding_output"].shape,
117 | [self.batch_size, self.seq_length, self.hidden_size])
118 |
119 | self.parent.assertAllEqual(
120 | result["sequence_output"].shape,
121 | [self.batch_size, self.seq_length, self.hidden_size])
122 |
123 | self.parent.assertAllEqual(result["pooled_output"].shape,
124 | [self.batch_size, self.hidden_size])
125 |
126 | def test_default(self):
127 | self.run_tester(BertModelTest.BertModelTester(self))
128 |
129 | def test_config_to_json_string(self):
130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37)
131 | obj = json.loads(config.to_json_string())
132 | self.assertEqual(obj["vocab_size"], 99)
133 | self.assertEqual(obj["hidden_size"], 37)
134 |
135 | def run_tester(self, tester):
136 | with self.test_session() as sess:
137 | ops = tester.create_model()
138 | init_op = tf.group(tf.global_variables_initializer(),
139 | tf.local_variables_initializer())
140 | sess.run(init_op)
141 | output_result = sess.run(ops)
142 | tester.check_output(output_result)
143 |
144 | self.assert_all_tensors_reachable(sess, [init_op, ops])
145 |
146 | @classmethod
147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
148 |     """Creates a random int32 tensor of the given shape with values in [0, vocab_size)."""
149 | if rng is None:
150 | rng = random.Random()
151 |
152 | total_dims = 1
153 | for dim in shape:
154 | total_dims *= dim
155 |
156 | values = []
157 | for _ in range(total_dims):
158 | values.append(rng.randint(0, vocab_size - 1))
159 |
160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
161 |
162 | def assert_all_tensors_reachable(self, sess, outputs):
163 | """Checks that all the tensors in the graph are reachable from outputs."""
164 | graph = sess.graph
165 |
166 | ignore_strings = [
167 | "^.*/assert_less_equal/.*$",
168 | "^.*/dilation_rate$",
169 | "^.*/Tensordot/concat$",
170 | "^.*/Tensordot/concat/axis$",
171 | "^testing/.*$",
172 | ]
173 |
174 | ignore_regexes = [re.compile(x) for x in ignore_strings]
175 |
176 | unreachable = self.get_unreachable_ops(graph, outputs)
177 | filtered_unreachable = []
178 | for x in unreachable:
179 | do_ignore = False
180 | for r in ignore_regexes:
181 | m = r.match(x.name)
182 | if m is not None:
183 | do_ignore = True
184 | if do_ignore:
185 | continue
186 | filtered_unreachable.append(x)
187 | unreachable = filtered_unreachable
188 |
189 | self.assertEqual(
190 | len(unreachable), 0, "The following ops are unreachable: %s" %
191 | (" ".join([x.name for x in unreachable])))
192 |
193 | @classmethod
194 | def get_unreachable_ops(cls, graph, outputs):
195 | """Finds all of the tensors in graph that are unreachable from outputs."""
196 | outputs = cls.flatten_recursive(outputs)
197 | output_to_op = collections.defaultdict(list)
198 | op_to_all = collections.defaultdict(list)
199 | assign_out_to_in = collections.defaultdict(list)
200 |
201 | for op in graph.get_operations():
202 | for x in op.inputs:
203 | op_to_all[op.name].append(x.name)
204 | for y in op.outputs:
205 | output_to_op[y.name].append(op.name)
206 | op_to_all[op.name].append(y.name)
207 | if str(op.type) == "Assign":
208 | for y in op.outputs:
209 | for x in op.inputs:
210 | assign_out_to_in[y.name].append(x.name)
211 |
212 | assign_groups = collections.defaultdict(list)
213 | for out_name in assign_out_to_in.keys():
214 | name_group = assign_out_to_in[out_name]
215 | for n1 in name_group:
216 | assign_groups[n1].append(out_name)
217 | for n2 in name_group:
218 | if n1 != n2:
219 | assign_groups[n1].append(n2)
220 |
221 | seen_tensors = {}
222 | stack = [x.name for x in outputs]
223 | while stack:
224 | name = stack.pop()
225 | if name in seen_tensors:
226 | continue
227 | seen_tensors[name] = True
228 |
229 | if name in output_to_op:
230 | for op_name in output_to_op[name]:
231 | if op_name in op_to_all:
232 | for input_name in op_to_all[op_name]:
233 | if input_name not in stack:
234 | stack.append(input_name)
235 |
236 | expanded_names = []
237 | if name in assign_groups:
238 | for assign_name in assign_groups[name]:
239 | expanded_names.append(assign_name)
240 |
241 | for expanded_name in expanded_names:
242 | if expanded_name not in stack:
243 | stack.append(expanded_name)
244 |
245 | unreachable_ops = []
246 | for op in graph.get_operations():
247 | is_unreachable = False
248 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
249 | for name in all_names:
250 | if name not in seen_tensors:
251 | is_unreachable = True
252 | if is_unreachable:
253 | unreachable_ops.append(op)
254 | return unreachable_ops
255 |
256 | @classmethod
257 | def flatten_recursive(cls, item):
258 | """Flattens (potentially nested) a tuple/dictionary/list to a list."""
259 | output = []
260 | if isinstance(item, list):
261 | output.extend(item)
262 | elif isinstance(item, tuple):
263 | output.extend(list(item))
264 | elif isinstance(item, dict):
265 | for (_, v) in six.iteritems(item):
266 | output.append(v)
267 | else:
268 | return [item]
269 |
270 | flat_output = []
271 | for x in output:
272 | flat_output.extend(cls.flatten_recursive(x))
273 | return flat_output
274 |
275 |
276 | if __name__ == "__main__":
277 | tf.test.main()
278 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/conlleval.py:
--------------------------------------------------------------------------------
1 | # Python version of the evaluation script from CoNLL'00-
2 | # Originates from: https://github.com/spyysalo/conlleval.py
3 |
4 |
5 | # Intentional differences:
6 | # - accept any space as delimiter by default
7 | # - optional file argument (default STDIN)
8 | # - option to set boundary (-b argument)
9 | # - LaTeX output (-l argument) not supported
10 | # - raw tags (-r argument) not supported
11 |
12 | # Added function evaluate(predicted_label, ori_label), which does not read from a file
13 |
14 | import sys
15 | import re
16 | import codecs
17 | from collections import defaultdict, namedtuple
18 |
19 | ANY_SPACE = ''
20 |
21 |
22 | class FormatError(Exception):
23 | pass
24 |
25 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore')
26 |
27 |
28 | class EvalCounts(object):
29 | def __init__(self):
30 | self.correct_chunk = 0 # number of correctly identified chunks
31 | self.correct_tags = 0 # number of correct chunk tags
32 | self.found_correct = 0 # number of chunks in corpus
33 | self.found_guessed = 0 # number of identified chunks
34 | self.token_counter = 0 # token counter (ignores sentence breaks)
35 |
36 | # counts by type
37 | self.t_correct_chunk = defaultdict(int)
38 | self.t_found_correct = defaultdict(int)
39 | self.t_found_guessed = defaultdict(int)
40 |
41 |
42 | def parse_args(argv):
43 | import argparse
44 | parser = argparse.ArgumentParser(
45 | description='evaluate tagging results using CoNLL criteria',
46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
47 | )
48 | arg = parser.add_argument
49 | arg('-b', '--boundary', metavar='STR', default='-X-',
50 | help='sentence boundary')
51 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE,
52 | help='character delimiting items in input')
53 | arg('-o', '--otag', metavar='CHAR', default='O',
54 | help='alternative outside tag')
55 | arg('file', nargs='?', default=None)
56 | return parser.parse_args(argv)
57 |
58 |
59 | def parse_tag(t):
60 | m = re.match(r'^([^-]*)-(.*)$', t)
61 | return m.groups() if m else (t, '')
62 |
63 |
64 | def evaluate(iterable, options=None):
65 | if options is None:
66 | options = parse_args([]) # use defaults
67 |
68 | counts = EvalCounts()
69 | num_features = None # number of features per line
70 |     in_correct = False        # currently processed chunk is correct so far
71 |     last_correct = 'O'        # previous chunk tag in corpus
72 |     last_correct_type = ''    # type of previous chunk tag in corpus
73 |     last_guessed = 'O'        # previously identified chunk tag
74 |     last_guessed_type = ''    # type of previously identified chunk tag
75 |
76 | for line in iterable:
77 | line = line.rstrip('\r\n')
78 |
79 | if options.delimiter == ANY_SPACE:
80 | features = line.split()
81 | else:
82 | features = line.split(options.delimiter)
83 |
84 | if num_features is None:
85 | num_features = len(features)
86 | elif num_features != len(features) and len(features) != 0:
87 | raise FormatError('unexpected number of features: %d (%d)' %
88 | (len(features), num_features))
89 |
90 | if len(features) == 0 or features[0] == options.boundary:
91 | features = [options.boundary, 'O', 'O']
92 | if len(features) < 3:
93 | raise FormatError('unexpected number of features in line %s' % line)
94 |
95 | guessed, guessed_type = parse_tag(features.pop())
96 | correct, correct_type = parse_tag(features.pop())
97 | first_item = features.pop(0)
98 |
99 | if first_item == options.boundary:
100 | guessed = 'O'
101 |
102 | end_correct = end_of_chunk(last_correct, correct,
103 | last_correct_type, correct_type)
104 | end_guessed = end_of_chunk(last_guessed, guessed,
105 | last_guessed_type, guessed_type)
106 | start_correct = start_of_chunk(last_correct, correct,
107 | last_correct_type, correct_type)
108 | start_guessed = start_of_chunk(last_guessed, guessed,
109 | last_guessed_type, guessed_type)
110 |
111 | if in_correct:
112 | if (end_correct and end_guessed and
113 | last_guessed_type == last_correct_type):
114 | in_correct = False
115 | counts.correct_chunk += 1
116 | counts.t_correct_chunk[last_correct_type] += 1
117 | elif (end_correct != end_guessed or guessed_type != correct_type):
118 | in_correct = False
119 |
120 | if start_correct and start_guessed and guessed_type == correct_type:
121 | in_correct = True
122 |
123 | if start_correct:
124 | counts.found_correct += 1
125 | counts.t_found_correct[correct_type] += 1
126 | if start_guessed:
127 | counts.found_guessed += 1
128 | counts.t_found_guessed[guessed_type] += 1
129 | if first_item != options.boundary:
130 | if correct == guessed and guessed_type == correct_type:
131 | counts.correct_tags += 1
132 | counts.token_counter += 1
133 |
134 | last_guessed = guessed
135 | last_correct = correct
136 | last_guessed_type = guessed_type
137 | last_correct_type = correct_type
138 |
139 | if in_correct:
140 | counts.correct_chunk += 1
141 | counts.t_correct_chunk[last_correct_type] += 1
142 |
143 | return counts
144 |
145 |
146 |
147 | def uniq(iterable):
148 | seen = set()
149 | return [i for i in iterable if not (i in seen or seen.add(i))]
150 |
151 |
152 | def calculate_metrics(correct, guessed, total):
153 | tp, fp, fn = correct, guessed-correct, total-correct
154 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp)
155 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn)
156 | f = 0 if p + r == 0 else 2 * p * r / (p + r)
157 | return Metrics(tp, fp, fn, p, r, f)
158 |
159 |
160 | def metrics(counts):
161 | c = counts
162 | overall = calculate_metrics(
163 | c.correct_chunk, c.found_guessed, c.found_correct
164 | )
165 | by_type = {}
166 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)):
167 | by_type[t] = calculate_metrics(
168 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t]
169 | )
170 | return overall, by_type
171 |
172 |
173 | def report(counts, out=None):
174 | if out is None:
175 | out = sys.stdout
176 |
177 | overall, by_type = metrics(counts)
178 |
179 | c = counts
180 | out.write('processed %d tokens with %d phrases; ' %
181 | (c.token_counter, c.found_correct))
182 | out.write('found: %d phrases; correct: %d.\n' %
183 | (c.found_guessed, c.correct_chunk))
184 |
185 | if c.token_counter > 0:
186 | out.write('accuracy: %6.2f%%; ' %
187 | (100.*c.correct_tags/c.token_counter))
188 | out.write('precision: %6.2f%%; ' % (100.*overall.prec))
189 | out.write('recall: %6.2f%%; ' % (100.*overall.rec))
190 | out.write('FB1: %6.2f\n' % (100.*overall.fscore))
191 |
192 | for i, m in sorted(by_type.items()):
193 | out.write('%17s: ' % i)
194 | out.write('precision: %6.2f%%; ' % (100.*m.prec))
195 | out.write('recall: %6.2f%%; ' % (100.*m.rec))
196 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
197 |
198 |
199 | def report_notprint(counts, out=None):
200 | if out is None:
201 | out = sys.stdout
202 |
203 | overall, by_type = metrics(counts)
204 |
205 | c = counts
206 | final_report = []
207 | line = []
208 | line.append('processed %d tokens with %d phrases; ' %
209 | (c.token_counter, c.found_correct))
210 | line.append('found: %d phrases; correct: %d.\n' %
211 | (c.found_guessed, c.correct_chunk))
212 | final_report.append("".join(line))
213 |
214 | if c.token_counter > 0:
215 | line = []
216 | line.append('accuracy: %6.2f%%; ' %
217 | (100.*c.correct_tags/c.token_counter))
218 | line.append('precision: %6.2f%%; ' % (100.*overall.prec))
219 | line.append('recall: %6.2f%%; ' % (100.*overall.rec))
220 | line.append('FB1: %6.2f\n' % (100.*overall.fscore))
221 | final_report.append("".join(line))
222 |
223 | for i, m in sorted(by_type.items()):
224 | line = []
225 | line.append('%17s: ' % i)
226 | line.append('precision: %6.2f%%; ' % (100.*m.prec))
227 | line.append('recall: %6.2f%%; ' % (100.*m.rec))
228 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
229 | final_report.append("".join(line))
230 | return final_report
231 |
232 |
233 | def end_of_chunk(prev_tag, tag, prev_type, type_):
234 | # check if a chunk ended between the previous and current word
235 | # arguments: previous and current chunk tags, previous and current types
236 | chunk_end = False
237 |
238 | if prev_tag == 'E': chunk_end = True
239 | if prev_tag == 'S': chunk_end = True
240 |
241 | if prev_tag == 'B' and tag == 'B': chunk_end = True
242 | if prev_tag == 'B' and tag == 'S': chunk_end = True
243 | if prev_tag == 'B' and tag == 'O': chunk_end = True
244 | if prev_tag == 'I' and tag == 'B': chunk_end = True
245 | if prev_tag == 'I' and tag == 'S': chunk_end = True
246 | if prev_tag == 'I' and tag == 'O': chunk_end = True
247 |
248 | if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
249 | chunk_end = True
250 |
251 | # these chunks are assumed to have length 1
252 | if prev_tag == ']': chunk_end = True
253 | if prev_tag == '[': chunk_end = True
254 |
255 | return chunk_end
256 |
257 |
258 | def start_of_chunk(prev_tag, tag, prev_type, type_):
259 | # check if a chunk started between the previous and current word
260 | # arguments: previous and current chunk tags, previous and current types
261 | chunk_start = False
262 |
263 | if tag == 'B': chunk_start = True
264 | if tag == 'S': chunk_start = True
265 |
266 | if prev_tag == 'E' and tag == 'E': chunk_start = True
267 | if prev_tag == 'E' and tag == 'I': chunk_start = True
268 | if prev_tag == 'S' and tag == 'E': chunk_start = True
269 | if prev_tag == 'S' and tag == 'I': chunk_start = True
270 | if prev_tag == 'O' and tag == 'E': chunk_start = True
271 | if prev_tag == 'O' and tag == 'I': chunk_start = True
272 |
273 | if tag != 'O' and tag != '.' and prev_type != type_:
274 | chunk_start = True
275 |
276 | # these chunks are assumed to have length 1
277 | if tag == '[': chunk_start = True
278 | if tag == ']': chunk_start = True
279 |
280 | return chunk_start
281 |
282 |
283 | def return_report(input_file):
284 | with codecs.open(input_file, "r", "utf8") as f:
285 | counts = evaluate(f)
286 | return report_notprint(counts)
287 |
288 |
289 | def main(argv):
290 | args = parse_args(argv[1:])
291 |
292 | if args.file is None:
293 | counts = evaluate(sys.stdin, args)
294 | else:
295 | with open(args.file) as f:
296 | counts = evaluate(f, args)
297 | report(counts)
298 |
299 | if __name__ == '__main__':
300 | sys.exit(main(sys.argv))
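Below is a hypothetical mini-example of how `evaluate` and `report` from this file can be used (the tokens and tags are invented). Each non-blank input line is `token gold_tag predicted_tag`, with blank lines marking sentence boundaries.

```python
# Hypothetical mini-example for conlleval.py; the tokens and tags are invented.
from conlleval import evaluate, report

lines = [
    "EU B-ORG B-ORG",
    "rejects O O",
    "German B-MISC B-MISC",
    "call O O",
    "",                    # sentence boundary
    "Peter B-PER B-PER",
    "Blackburn I-PER O",   # the PER chunk is truncated in the prediction
]

counts = evaluate(lines)   # default options: any whitespace as delimiter
report(counts)             # prints overall and per-type precision/recall/FB1
```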
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/multilingual.md:
--------------------------------------------------------------------------------
1 | ## Models
2 |
3 | There are two multilingual models currently available. We do not plan to release
4 | more single-language models, but we may release `BERT-Large` versions of these
5 | two in the future:
6 |
7 | * **[`BERT-Base, Multilingual Cased (New, recommended)`](https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip)**:
8 | 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
9 | * **[`BERT-Base, Multilingual Uncased (Orig, not recommended)`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**:
10 | 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
11 | * **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**:
12 | Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M
13 | parameters
14 |
15 | **The `Multilingual Cased (New)` model also fixes normalization issues in many
16 | languages, so it is recommended in languages with non-Latin alphabets (and is
17 | often better for most languages with Latin alphabets). When using this model,
18 | make sure to pass `--do_lower_case=false` to `run_pretraining.py` and other
19 | scripts.**
20 |
21 | See the [list of languages](#list-of-languages) that the Multilingual model
22 | supports. The Multilingual model does include Chinese (and English), but if your
23 | fine-tuning data is Chinese-only, then the Chinese model will likely produce
24 | better results.
25 |
26 | ## Results
27 |
28 | To evaluate these systems, we use the
29 | [XNLI dataset](https://github.com/facebookresearch/XNLI), which is a
30 | version of [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) where the
31 | dev and test sets have been translated (by humans) into 15 languages. Note that
32 | the training set was *machine* translated (we used the translations provided by
33 | XNLI, not Google NMT). For clarity, we only report on 6 languages below:
34 |
35 |
36 |
37 | | System | English | Chinese | Spanish | German | Arabic | Urdu |
38 | | --------------------------------- | -------- | -------- | -------- | -------- | -------- | -------- |
39 | | XNLI Baseline - Translate Train | 73.7 | 67.0 | 68.8 | 66.5 | 65.8 | 56.6 |
40 | | XNLI Baseline - Translate Test | 73.7 | 68.3 | 70.7 | 68.7 | 66.8 | 59.3 |
41 | | BERT - Translate Train Cased | **81.9** | **76.6** | **77.8** | **75.9** | **70.7** | 61.6 |
42 | | BERT - Translate Train Uncased | 81.4 | 74.2 | 77.3 | 75.2 | 70.5 | 61.7 |
43 | | BERT - Translate Test Uncased | 81.4 | 70.1 | 74.9 | 74.4 | 70.4 | **62.1** |
44 | | BERT - Zero Shot Uncased | 81.4 | 63.8 | 74.3 | 70.5 | 62.1 | 58.3 |
45 |
46 |
47 |
48 | The first two rows are baselines from the XNLI paper and the last three rows are
49 | our results with BERT.
50 |
51 | **Translate Train** means that the MultiNLI training set was machine translated
52 | from English into the foreign language. So training and evaluation were both
53 | done in the foreign language. Unfortunately, training was done on
54 | machine-translated data, so it is impossible to quantify how much of the lower
55 | accuracy (compared to English) is due to the quality of the machine translation
56 | vs. the quality of the pre-trained model.
57 |
58 | **Translate Test** means that the XNLI test set was machine translated from the
59 | foreign language into English. So training and evaluation were both done on
60 | English. However, test evaluation was done on machine-translated English, so the
61 | accuracy depends on the quality of the machine translation system.
62 |
63 | **Zero Shot** means that the Multilingual BERT system was fine-tuned on English
64 | MultiNLI, and then evaluated on the foreign language XNLI test. In this case,
65 | machine translation was not involved at all in either the pre-training or
66 | fine-tuning.
67 |
68 | Note that the English result is worse than the 84.2 MultiNLI baseline because
69 | this training used Multilingual BERT rather than English-only BERT. This implies
70 | that for high-resource languages, the Multilingual model is somewhat worse than
71 | a single-language model. However, it is not feasible for us to train and
72 | maintain dozens of single-language models. Therefore, if your goal is to maximize
73 | performance with a language other than English or Chinese, you might find it
74 | beneficial to run pre-training for additional steps starting from our
75 | Multilingual model on data from your language of interest.
76 |
77 | Here is a comparison of training Chinese models with the Multilingual
78 | `BERT-Base` and Chinese-only `BERT-Base`:
79 |
80 | System | Chinese
81 | ----------------------- | -------
82 | XNLI Baseline | 67.0
83 | BERT Multilingual Model | 74.2
84 | BERT Chinese-only Model | 77.2
85 |
86 | Similar to English, the single-language model does 3% better than the
87 | Multilingual model.
88 |
89 | ## Fine-tuning Example
90 |
91 | The multilingual model does **not** require any special consideration or API
92 | changes. We did update the implementation of `BasicTokenizer` in
93 | `tokenization.py` to support Chinese character tokenization, so please update if
94 | you forked it. However, we did not change the tokenization API.
95 |
96 | To test the new models, we did modify `run_classifier.py` to add support for the
97 | [XNLI dataset](https://github.com/facebookresearch/XNLI). This is a 15-language
98 | version of MultiNLI where the dev/test sets have been human-translated, and the
99 | training set has been machine-translated.
100 |
101 | To run the fine-tuning code, please download the
102 | [XNLI dev/test set](https://s3.amazonaws.com/xnli/XNLI-1.0.zip) and the
103 | [XNLI machine-translated training set](https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip)
104 | and then unpack both .zip files into some directory `$XNLI_DIR`.
105 |
106 | To run fine-tuning on XNLI, note that the language is hard-coded into `run_classifier.py`
107 | (Chinese by default), so please modify `XnliProcessor` if you want to run on
108 | another language.
109 |
110 | This is a large dataset, so training will take a few hours on a GPU
111 | (or about 30 minutes on a Cloud TPU). To run an experiment quickly for
112 | debugging, just set `num_train_epochs` to a small value like `0.1`.
113 |
114 | ```shell
115 | export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 # or multilingual_L-12_H-768_A-12
116 | export XNLI_DIR=/path/to/xnli
117 |
118 | python run_classifier.py \
119 | --task_name=XNLI \
120 | --do_train=true \
121 | --do_eval=true \
122 | --data_dir=$XNLI_DIR \
123 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
124 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
125 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
126 | --max_seq_length=128 \
127 | --train_batch_size=32 \
128 | --learning_rate=5e-5 \
129 | --num_train_epochs=2.0 \
130 | --output_dir=/tmp/xnli_output/
131 | ```
132 |
133 | With the Chinese-only model, the results should look something like this:
134 |
135 | ```
136 | ***** Eval results *****
137 | eval_accuracy = 0.774116
138 | eval_loss = 0.83554
139 | global_step = 24543
140 | loss = 0.74603
141 | ```
142 |
143 | ## Details
144 |
145 | ### Data Source and Sampling
146 |
147 | The languages chosen were the
148 | [top 100 languages with the largest Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias).
149 | The entire Wikipedia dump for each language (excluding user and talk pages) was
150 | taken as its training data.
151 |
152 | However, the size of the Wikipedia for a given language varies greatly, and
153 | therefore low-resource languages may be "under-represented" in terms of the
154 | neural network model (under the assumption that languages are "competing" for
155 | limited model capacity to some extent). At the same time, we also don't want
156 | to overfit the model by performing thousands of epochs over a tiny Wikipedia
157 | for a particular language.
158 |
159 | To balance these two factors, we performed exponentially smoothed weighting of
160 | the data during pre-training data creation (and WordPiece vocab creation). In
161 | other words, let's say that the probability of a language is *P(L)*, e.g.,
162 | *P(English) = 0.21* means that after concatenating all of the Wikipedias
163 | together, 21% of our data is English. We exponentiate each probability by some
164 | factor *S* and then re-normalize, and sample from that distribution. In our case
165 | we use *S=0.7*. So, high-resource languages like English will be under-sampled,
166 | and low-resource languages like Icelandic will be over-sampled. E.g., in the
167 | original distribution English would be sampled 1000x more than Icelandic, but
168 | after smoothing it's only sampled 100x more.
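
As a toy sketch of this smoothing (an illustration, not the original data pipeline; the English share of 0.21 is the example above, while the Icelandic share is a made-up stand-in):

```python
# Exponentially smoothed sampling with S=0.7: exponentiate each language's
# share of the data and re-normalize before sampling. Values are illustrative.
probs = {"en": 0.21, "is": 0.00021}   # "is" (Icelandic) share is hypothetical
S = 0.7

smoothed = {lang: p ** S for lang, p in probs.items()}
total = sum(smoothed.values())
smoothed = {lang: p / total for lang, p in smoothed.items()}

print(probs["en"] / probs["is"])        # 1000.0 -> English sampled 1000x more
print(smoothed["en"] / smoothed["is"])  # ~125.9 -> only ~100x more after smoothing
```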
169 |
170 | ### Tokenization
171 |
172 | For tokenization, we use a 110k shared WordPiece vocabulary. The word counts are
173 | weighted the same way as the data, so low-resource languages are upweighted by
174 | some factor. We intentionally do *not* use any marker to denote the input
175 | language (so that zero-shot training can work).
176 |
177 | Because Chinese (and Japanese Kanji and Korean Hanja) does not have whitespace
178 | characters, we add spaces around every character in the
179 | [CJK Unicode range](https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_\(Unicode_block\))
180 | before applying WordPiece. This means that Chinese is effectively
181 | character-tokenized. Note that the CJK Unicode block only includes
182 | Chinese-origin characters and does *not* include Hangul Korean or
183 | Katakana/Hiragana Japanese, which are tokenized with whitespace+WordPiece like
184 | all other languages.
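
The effect is easy to see with a small sketch (illustrative only; `BasicTokenizer._tokenize_chinese_chars` in `tokenization.py` in this repo implements the full set of CJK ranges):

```python
# Illustrative only: insert spaces around CJK codepoints so that WordPiece
# sees one "word" per Chinese character. Only the main CJK Unified Ideographs
# block is checked here; the real BasicTokenizer covers several more ranges.
def add_cjk_spaces(text):
    out = []
    for ch in text:
        if 0x4E00 <= ord(ch) <= 0x9FFF:
            out.extend([" ", ch, " "])
        else:
            out.append(ch)
    return "".join(out)

print(add_cjk_spaces("BERT模型很强").split())
# ['BERT', '模', '型', '很', '强']
```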
185 |
186 | For all other languages, we apply the
187 | [same recipe as English](https://github.com/google-research/bert#tokenization):
188 | (a) lower casing+accent removal, (b) punctuation splitting, (c) whitespace
189 | tokenization. We understand that accent markers have substantial meaning in some
190 | languages, but felt that the benefits of reducing the effective vocabulary make
191 | up for this. Generally the strong contextual models of BERT should make up for
192 | any ambiguity introduced by stripping accent markers.
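
A minimal sketch of step (a), lower casing plus accent removal (the same NFD-based approach used by `BasicTokenizer._run_strip_accents` in `tokenization.py`; the helper name here is just for illustration):

```python
import unicodedata

# Illustrative helper: lower case, decompose with NFD, and drop combining
# marks (Unicode category "Mn"), mirroring BasicTokenizer._run_strip_accents.
def lower_and_strip_accents(text):
    text = unicodedata.normalize("NFD", text.lower())
    return "".join(ch for ch in text if unicodedata.category(ch) != "Mn")

print(lower_and_strip_accents("Él está aquí"))  # -> "el esta aqui"
```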
193 |
194 | ### List of Languages
195 |
196 | The multilingual model supports the following languages. These languages were
197 | chosen because they are the top 100 languages with the largest Wikipedias:
198 |
199 | * Afrikaans
200 | * Albanian
201 | * Arabic
202 | * Aragonese
203 | * Armenian
204 | * Asturian
205 | * Azerbaijani
206 | * Bashkir
207 | * Basque
208 | * Bavarian
209 | * Belarusian
210 | * Bengali
211 | * Bishnupriya Manipuri
212 | * Bosnian
213 | * Breton
214 | * Bulgarian
215 | * Burmese
216 | * Catalan
217 | * Cebuano
218 | * Chechen
219 | * Chinese (Simplified)
220 | * Chinese (Traditional)
221 | * Chuvash
222 | * Croatian
223 | * Czech
224 | * Danish
225 | * Dutch
226 | * English
227 | * Estonian
228 | * Finnish
229 | * French
230 | * Galician
231 | * Georgian
232 | * German
233 | * Greek
234 | * Gujarati
235 | * Haitian
236 | * Hebrew
237 | * Hindi
238 | * Hungarian
239 | * Icelandic
240 | * Ido
241 | * Indonesian
242 | * Irish
243 | * Italian
244 | * Japanese
245 | * Javanese
246 | * Kannada
247 | * Kazakh
248 | * Kirghiz
249 | * Korean
250 | * Latin
251 | * Latvian
252 | * Lithuanian
253 | * Lombard
254 | * Low Saxon
255 | * Luxembourgish
256 | * Macedonian
257 | * Malagasy
258 | * Malay
259 | * Malayalam
260 | * Marathi
261 | * Minangkabau
262 | * Nepali
263 | * Newar
264 | * Norwegian (Bokmal)
265 | * Norwegian (Nynorsk)
266 | * Occitan
267 | * Persian (Farsi)
268 | * Piedmontese
269 | * Polish
270 | * Portuguese
271 | * Punjabi
272 | * Romanian
273 | * Russian
274 | * Scots
275 | * Serbian
276 | * Serbo-Croatian
277 | * Sicilian
278 | * Slovak
279 | * Slovenian
280 | * South Azerbaijani
281 | * Spanish
282 | * Sundanese
283 | * Swahili
284 | * Swedish
285 | * Tagalog
286 | * Tajik
287 | * Tamil
288 | * Tatar
289 | * Telugu
290 | * Turkish
291 | * Ukrainian
292 | * Urdu
293 | * Uzbek
294 | * Vietnamese
295 | * Volapük
296 | * Waray-Waray
297 | * Welsh
298 | * West Frisian
299 | * Western Punjabi
300 | * Yoruba
301 |
302 | The **Multilingual Cased (New)** release additionally contains **Thai** and
303 | **Mongolian**, which were not included in the original release.
304 |
--------------------------------------------------------------------------------
/baselines/models/bert_mrc/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 | """Runs end-to-end tokenziation."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 | # and generally don't have any Chinese data in them (there are Chinese
205 | # characters in the vocabulary because Wikipedia does have some Chinese
206 | # words in the English Wikipedia.).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 | # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 | """Runs WordPiece tokenziation."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 | already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `chars` is a whitespace character."""
364 | # \t, \n, and \r are technically control characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `chars` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat.startswith("C"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `chars` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
--------------------------------------------------------------------------------
/baselines/models/bert_wsc_csl/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 | """Runs end-to-end tokenziation."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 | # and generally don't have any Chinese data in them (there are Chinese
205 | # characters in the vocabulary because Wikipedia does have some Chinese
206 | # words in the English Wikipedia.).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 | # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 | """Runs WordPiece tokenziation."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 | already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `chars` is a whitespace character."""
364 | # \t, \n, and \r are technically control characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `chars` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat in ("Cc", "Cf"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `chars` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
--------------------------------------------------------------------------------
/baselines/models/bert_ner/tokenization.py:
--------------------------------------------------------------------------------
1 | """Tokenization classes."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import collections
8 | import re
9 | import unicodedata
10 | import six
11 | import tensorflow as tf
12 |
13 |
14 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
15 | """Checks whether the casing config is consistent with the checkpoint name."""
16 |
17 | # The casing has to be passed in by the user and there is no explicit check
18 | # as to whether it matches the checkpoint. The casing information probably
19 | # should have been stored in the bert_config.json file, but it's not, so
20 | # we have to heuristically detect it to validate.
21 |
22 | if not init_checkpoint:
23 | return
24 |
25 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
26 | if m is None:
27 | return
28 |
29 | model_name = m.group(1)
30 |
31 | lower_models = [
32 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
33 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
34 | ]
35 |
36 | cased_models = [
37 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
38 | "multi_cased_L-12_H-768_A-12"
39 | ]
40 |
41 | is_bad_config = False
42 | if model_name in lower_models and not do_lower_case:
43 | is_bad_config = True
44 | actual_flag = "False"
45 | case_name = "lowercased"
46 | opposite_flag = "True"
47 |
48 | if model_name in cased_models and do_lower_case:
49 | is_bad_config = True
50 | actual_flag = "True"
51 | case_name = "cased"
52 | opposite_flag = "False"
53 |
54 | if is_bad_config:
55 | raise ValueError(
56 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
57 | "However, `%s` seems to be a %s model, so you "
58 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
59 | "how the model was pre-training. If this error is wrong, please "
60 | "just comment out this check." % (actual_flag, init_checkpoint,
61 | model_name, case_name, opposite_flag))
62 |
63 |
64 | def convert_to_unicode(text):
65 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
66 | if six.PY3:
67 | if isinstance(text, str):
68 | return text
69 | elif isinstance(text, bytes):
70 | return text.decode("utf-8", "ignore")
71 | else:
72 | raise ValueError("Unsupported string type: %s" % (type(text)))
73 | elif six.PY2:
74 | if isinstance(text, str):
75 | return text.decode("utf-8", "ignore")
76 | elif isinstance(text, unicode):
77 | return text
78 | else:
79 | raise ValueError("Unsupported string type: %s" % (type(text)))
80 | else:
81 | raise ValueError("Not running on Python2 or Python 3?")
82 |
83 |
84 | def printable_text(text):
85 | """Returns text encoded in a way suitable for print or `tf.logging`."""
86 |
87 | # These functions want `str` for both Python2 and Python3, but in one case
88 | # it's a Unicode string and in the other it's a byte string.
89 | if six.PY3:
90 | if isinstance(text, str):
91 | return text
92 | elif isinstance(text, bytes):
93 | return text.decode("utf-8", "ignore")
94 | else:
95 | raise ValueError("Unsupported string type: %s" % (type(text)))
96 | elif six.PY2:
97 | if isinstance(text, str):
98 | return text
99 | elif isinstance(text, unicode):
100 | return text.encode("utf-8")
101 | else:
102 | raise ValueError("Unsupported string type: %s" % (type(text)))
103 | else:
104 | raise ValueError("Not running on Python2 or Python 3?")
105 |
106 |
107 | def load_vocab(vocab_file):
108 | """Loads a vocabulary file into a dictionary."""
109 | vocab = collections.OrderedDict()
110 | index = 0
111 | with tf.gfile.GFile(vocab_file, "r") as reader:
112 | while True:
113 | token = convert_to_unicode(reader.readline())
114 | if not token:
115 | break
116 | token = token.strip()
117 | vocab[token] = index
118 | index += 1
119 | return vocab
120 |
121 |
122 | def convert_by_vocab(vocab, items):
123 | """Converts a sequence of [tokens|ids] using the vocab."""
124 | output = []
125 | for item in items:
126 | if item in vocab:
127 | output.append(vocab[item])
128 | else:
129 | output.append(vocab['[UNK]'])
130 | return output
131 |
132 |
133 | def convert_tokens_to_ids(vocab, tokens):
134 | return convert_by_vocab(vocab, tokens)
135 |
136 |
137 | def convert_ids_to_tokens(inv_vocab, ids):
138 | return convert_by_vocab(inv_vocab, ids)
139 |
140 |
141 | def whitespace_tokenize(text):
142 | """Runs basic whitespace cleaning and splitting on a piece of text."""
143 | text = text.strip()
144 | if not text:
145 | return []
146 | tokens = text.split()
147 | return tokens
148 |
149 |
150 | class FullTokenizer(object):
151 | """Runs end-to-end tokenziation."""
152 |
153 | def __init__(self, vocab_file, do_lower_case=True):
154 | self.vocab = load_vocab(vocab_file)
155 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
156 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
157 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
158 |
159 | def tokenize(self, text):
160 | split_tokens = []
161 | for token in self.basic_tokenizer.tokenize(text):
162 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
163 | split_tokens.append(sub_token)
164 |
165 | return split_tokens
166 |
167 | def convert_tokens_to_ids(self, tokens):
168 | return convert_by_vocab(self.vocab, tokens)
169 |
170 | def convert_ids_to_tokens(self, ids):
171 | return convert_by_vocab(self.inv_vocab, ids)
172 |
173 |
174 | class BasicTokenizer(object):
175 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
176 |
177 | def __init__(self, do_lower_case=True):
178 | """Constructs a BasicTokenizer.
179 |
180 | Args:
181 | do_lower_case: Whether to lower case the input.
182 | """
183 | self.do_lower_case = do_lower_case
184 |
185 | def tokenize(self, text):
186 | """Tokenizes a piece of text."""
187 | text = convert_to_unicode(text)
188 | text = self._clean_text(text)
189 |
190 | # This was added on November 1st, 2018 for the multilingual and Chinese
191 | # models. This is also applied to the English models now, but it doesn't
192 | # matter since the English models were not trained on any Chinese data
193 | # and generally don't have any Chinese data in them (there are Chinese
194 | # characters in the vocabulary because Wikipedia does have some Chinese
195 | # words in the English Wikipedia.).
196 | text = self._tokenize_chinese_chars(text)
197 |
198 | orig_tokens = whitespace_tokenize(text)
199 | split_tokens = []
200 | for token in orig_tokens:
201 | if self.do_lower_case:
202 | token = token.lower()
203 | token = self._run_strip_accents(token)
204 | split_tokens.extend(self._run_split_on_punc(token))
205 |
206 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
207 | return output_tokens
208 |
209 | def _run_strip_accents(self, text):
210 | """Strips accents from a piece of text."""
211 | text = unicodedata.normalize("NFD", text)
212 | output = []
213 | for char in text:
214 | cat = unicodedata.category(char)
215 | if cat == "Mn":
216 | continue
217 | output.append(char)
218 | return "".join(output)
219 |
220 | def _run_split_on_punc(self, text):
221 | """Splits punctuation on a piece of text."""
222 | chars = list(text)
223 | i = 0
224 | start_new_word = True
225 | output = []
226 | while i < len(chars):
227 | char = chars[i]
228 | if _is_punctuation(char):
229 | output.append([char])
230 | start_new_word = True
231 | else:
232 | if start_new_word:
233 | output.append([])
234 | start_new_word = False
235 | output[-1].append(char)
236 | i += 1
237 |
238 | return ["".join(x) for x in output]
239 |
240 | def _tokenize_chinese_chars(self, text):
241 | """Adds whitespace around any CJK character."""
242 | output = []
243 | for char in text:
244 | cp = ord(char)
245 | if self._is_chinese_char(cp):
246 | output.append(" ")
247 | output.append(char)
248 | output.append(" ")
249 | else:
250 | output.append(char)
251 | return "".join(output)
252 |
253 | def _is_chinese_char(self, cp):
254 | """Checks whether CP is the codepoint of a CJK character."""
255 | # This defines a "chinese character" as anything in the CJK Unicode block:
256 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
257 | #
258 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
259 | # despite its name. The modern Korean Hangul alphabet is a different block,
260 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
261 | # space-separated words, so they are not treated specially and handled
262 | # like all of the other languages.
263 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
264 | (cp >= 0x3400 and cp <= 0x4DBF) or #
265 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
266 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
267 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
268 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
269 | (cp >= 0xF900 and cp <= 0xFAFF) or #
270 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
271 | return True
272 |
273 | return False
274 |
275 | def _clean_text(self, text):
276 | """Performs invalid character removal and whitespace cleanup on text."""
277 | output = []
278 | for char in text:
279 | cp = ord(char)
280 | if cp == 0 or cp == 0xfffd or _is_control(char):
281 | continue
282 | if _is_whitespace(char):
283 | output.append(" ")
284 | else:
285 | output.append(char)
286 | return "".join(output)
287 |
288 |
289 | class WordpieceTokenizer(object):
290 | """Runs WordPiece tokenziation."""
291 |
292 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
293 | self.vocab = vocab
294 | self.unk_token = unk_token
295 | self.max_input_chars_per_word = max_input_chars_per_word
296 |
297 | def tokenize(self, text):
298 | """Tokenizes a piece of text into its word pieces.
299 |
300 | This uses a greedy longest-match-first algorithm to perform tokenization
301 | using the given vocabulary.
302 |
303 | For example:
304 | input = "unaffable"
305 | output = ["un", "##aff", "##able"]
306 |
307 | Args:
308 | text: A single token or whitespace separated tokens. This should have
309 | already been passed through `BasicTokenizer`.
310 |
311 | Returns:
312 | A list of wordpiece tokens.
313 | """
314 |
315 | text = convert_to_unicode(text)
316 |
317 | output_tokens = []
318 | for token in whitespace_tokenize(text):
319 | chars = list(token)
320 | if len(chars) > self.max_input_chars_per_word:
321 | output_tokens.append(self.unk_token)
322 | continue
323 |
324 | is_bad = False
325 | start = 0
326 | sub_tokens = []
327 | while start < len(chars):
328 | end = len(chars)
329 | cur_substr = None
330 | while start < end:
331 | substr = "".join(chars[start:end])
332 | if start > 0:
333 | substr = "##" + substr
334 | if substr in self.vocab:
335 | cur_substr = substr
336 | break
337 | end -= 1
338 | if cur_substr is None:
339 | is_bad = True
340 | break
341 | sub_tokens.append(cur_substr)
342 | start = end
343 |
344 | if is_bad:
345 | output_tokens.append(self.unk_token)
346 | else:
347 | output_tokens.extend(sub_tokens)
348 | return output_tokens
349 |
350 |
351 | def _is_whitespace(char):
352 | """Checks whether `chars` is a whitespace character."""
353 | # \t, \n, and \r are technically control characters but we treat them
354 | # as whitespace since they are generally considered as such.
355 | if char == " " or char == "\t" or char == "\n" or char == "\r":
356 | return True
357 | cat = unicodedata.category(char)
358 | if cat == "Zs":
359 | return True
360 | return False
361 |
362 |
363 | def _is_control(char):
364 | """Checks whether `chars` is a control character."""
365 | # These are technically control characters but we count them as whitespace
366 | # characters.
367 | if char == "\t" or char == "\n" or char == "\r":
368 | return False
369 | cat = unicodedata.category(char)
370 | if cat.startswith("C"):
371 | return True
372 | return False
373 |
374 |
375 | def _is_punctuation(char):
376 | """Checks whether `chars` is a punctuation character."""
377 | cp = ord(char)
378 | # We treat all non-letter/number ASCII as punctuation.
379 | # Characters such as "^", "$", and "`" are not in the Unicode
380 | # Punctuation class but we treat them as punctuation anyways, for
381 | # consistency.
382 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
383 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
384 | return True
385 | cat = unicodedata.category(char)
386 | if cat.startswith("P"):
387 | return True
388 | return False
389 |
--------------------------------------------------------------------------------