├── .gitignore
├── banner.png
├── qrcode.jpg
├── sponsor.png
├── baseline
│   ├── __init__.py
│   ├── README.md
│   ├── cmrc2018_evaluate.py
│   ├── optimization.py
│   ├── tokenization.py
│   ├── modeling.py
│   └── run_cmrc2018_drcd_baseline.py
├── README_CN.md
├── README.md
├── data
│   └── cmrc2018_evaluate.py
├── squad-style-data
│   └── cmrc2018_evaluate.py
└── LICENCE
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | */.DS_Store
3 |
--------------------------------------------------------------------------------
/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ymcui/cmrc2018/HEAD/banner.png
--------------------------------------------------------------------------------
/qrcode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ymcui/cmrc2018/HEAD/qrcode.jpg
--------------------------------------------------------------------------------
/sponsor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ymcui/cmrc2018/HEAD/sponsor.png
--------------------------------------------------------------------------------
/baseline/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/README_CN.md:
--------------------------------------------------------------------------------
1 | [**Chinese**](./README_CN.md) | [**English**](./README.md)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | This directory contains the data used in [The Second Evaluation Workshop on Chinese Machine Reading Comprehension (CMRC 2018)](https://hfl-rc.github.io/cmrc2018/). The dataset paper has been accepted by [EMNLP 2019](http://emnlp-ijcnlp2019.org), a top-tier international conference in computational linguistics.
15 |
16 | **Title: A Span-Extraction Dataset for Chinese Machine Reading Comprehension**
17 | Authors: Yiming Cui, Ting Liu, Wanxiang Che, Li Xiao, Zhipeng Chen, Wentao Ma, Shijin Wang, Guoping Hu
18 | Link: [https://www.aclweb.org/anthology/D19-1600/](https://www.aclweb.org/anthology/D19-1600/)
19 | Venue: EMNLP-IJCNLP 2019
20 |
21 | ### Open Challenge Leaderboard (New!)
22 | Want to know which models perform best on CMRC 2018? Check out the leaderboard.
23 | [https://ymcui.github.io/cmrc2018/](https://ymcui.github.io/cmrc2018/)
24 |
25 | ### CMRC 2018 Public Datasets
26 | Please download the CMRC 2018 public datasets (training and development sets) via the CodaLab Worksheet below.
27 | [https://worksheets.codalab.org/worksheets/0x92a80d2fab4b4f79a2b4064f7ddca9ce](https://worksheets.codalab.org/worksheets/0x92a80d2fab4b4f79a2b4064f7ddca9ce)
28 |
29 | ### Submission Guidelines
30 | If you would like to **evaluate your model on the hidden test and challenge sets**, please submit your model by following the steps below.
31 | [https://worksheets.codalab.org/worksheets/0x96f61ee5e9914aee8b54bd11e66ec647/](https://worksheets.codalab.org/worksheets/0x96f61ee5e9914aee8b54bd11e66ec647/)
32 |
33 | **Note that the test set provided on [CLUE](https://github.com/CLUEbenchmark/CLUE) is only a subset of the CMRC 2018 test set. For the official evaluation, you still need to obtain results on the full test and challenge sets through the procedure above.**
34 |
35 |
36 | ### Quick Load Through 🤗datasets
37 | You can quickly load the dataset with the [HuggingFace `datasets` library](https://github.com/huggingface/datasets):
38 |
39 | ```python
40 | !pip install datasets
41 | from datasets import load_dataset
42 | dataset = load_dataset('cmrc2018')
43 | ```
44 | More details on the options and usage of the `datasets` library can be found at: https://github.com/huggingface/datasets
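
Once loaded, each split behaves like a list of examples. Below is a minimal sketch of inspecting one training example; the SQuAD-style field names (`context`, `question`, `answers`) are assumptions about the Hub version of this dataset:

```python
from datasets import load_dataset

dataset = load_dataset('cmrc2018')

# Inspect the first training example. 'answers' is assumed to hold the
# gold answer strings ('text') and their character offsets ('answer_start').
example = dataset['train'][0]
print(example['question'])
print(example['answers']['text'])
```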
45 |
46 | ### Citation
47 | If you use our data in your work, please cite:
48 |
49 | ```
50 | @inproceedings{cui-emnlp2019-cmrc2018,
51 | title = "A Span-Extraction Dataset for {C}hinese Machine Reading Comprehension",
52 | author = "Cui, Yiming and
53 | Liu, Ting and
54 | Che, Wanxiang and
55 | Xiao, Li and
56 | Chen, Zhipeng and
57 | Ma, Wentao and
58 | Wang, Shijin and
59 | Hu, Guoping",
60 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
61 | month = nov,
62 | year = "2019",
63 | address = "Hong Kong, China",
64 | publisher = "Association for Computational Linguistics",
65 | url = "https://www.aclweb.org/anthology/D19-1600",
66 | doi = "10.18653/v1/D19-1600",
67 | pages = "5886--5891",
68 | }
69 | ```
70 | ### International Standard Language Resource Number (ISLRN)
71 | ISLRN: 013-662-947-043-2
72 |
73 | http://www.islrn.org/resources/resources_info/7952/
74 |
75 | ### Official WeChat Account of HFL
76 | Follow the official WeChat account of the Joint Laboratory of HIT and iFLYTEK Research (HFL) for the latest updates.
77 |
78 | 
79 |
80 | ### Contact Us
81 | Please submit a GitHub issue.
82 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [**Chinese**](./README_CN.md) | [**English**](./README.md)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | This repository contains the data for [The Second Evaluation Workshop on Chinese Machine Reading Comprehension (CMRC 2018)](https://hfl-rc.github.io/cmrc2018/). The dataset paper has been accepted by [EMNLP 2019](http://emnlp-ijcnlp2019.org).
15 |
16 | **Title: A Span-Extraction Dataset for Chinese Machine Reading Comprehension**
17 | Authors: Yiming Cui, Ting Liu, Wanxiang Che, Li Xiao, Zhipeng Chen, Wentao Ma, Shijin Wang, Guoping Hu
18 | Link: [https://www.aclweb.org/anthology/D19-1600/](https://www.aclweb.org/anthology/D19-1600/)
19 | Venue: EMNLP-IJCNLP 2019
20 |
21 | ### Open Challenge Leaderboard (New!)
22 | Keep track of the latest state-of-the-art systems on the CMRC 2018 dataset.
23 | [https://ymcui.github.io/cmrc2018/](https://ymcui.github.io/cmrc2018/)
24 |
25 | ### CMRC 2018 Public Datasets
26 | Please download the CMRC 2018 public datasets (training and development sets) via the following CodaLab Worksheet.
27 | [https://worksheets.codalab.org/worksheets/0x92a80d2fab4b4f79a2b4064f7ddca9ce](https://worksheets.codalab.org/worksheets/0x92a80d2fab4b4f79a2b4064f7ddca9ce)
28 |
29 | ### Submission Guidelines
30 | If you would like to **test your model on the hidden test and challenge sets**, please follow the instructions below on how to submit your model via the CodaLab worksheet.
31 | [https://worksheets.codalab.org/worksheets/0x96f61ee5e9914aee8b54bd11e66ec647/](https://worksheets.codalab.org/worksheets/0x96f61ee5e9914aee8b54bd11e66ec647/)
32 |
33 | **Note that the test set on [CLUE](https://github.com/CLUEbenchmark/CLUE) is NOT the complete test set. If you wish to evaluate your model OFFICIALLY on CMRC 2018, you should follow the guidelines here.**
34 |
35 | ### Quick Load Through 🤗datasets
36 | You can also access this dataset through the [HuggingFace `datasets` library](https://github.com/huggingface/datasets) as follows:
37 |
38 | ```python
39 | !pip install datasets
40 | from datasets import load_dataset
41 | dataset = load_dataset('cmrc2018')
42 | ```
43 | More details on the options and usage of this library can be found in the `datasets` repository at https://github.com/huggingface/datasets
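
The loaded object is a `DatasetDict` keyed by split. As a minimal sketch of checking what was downloaded (the exact split names are an assumption here and are easy to verify by printing):

```python
from datasets import load_dataset

dataset = load_dataset('cmrc2018')

# Print each available split (e.g. train/validation/test) and its size.
for split_name, split in dataset.items():
    print(split_name, len(split))
```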
44 |
45 | ### Reference
46 | If you wish to use our data in your research, please cite:
47 |
48 | ```
49 | @inproceedings{cui-emnlp2019-cmrc2018,
50 | title = "A Span-Extraction Dataset for {C}hinese Machine Reading Comprehension",
51 | author = "Cui, Yiming and
52 | Liu, Ting and
53 | Che, Wanxiang and
54 | Xiao, Li and
55 | Chen, Zhipeng and
56 | Ma, Wentao and
57 | Wang, Shijin and
58 | Hu, Guoping",
59 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
60 | month = nov,
61 | year = "2019",
62 | address = "Hong Kong, China",
63 | publisher = "Association for Computational Linguistics",
64 | url = "https://www.aclweb.org/anthology/D19-1600",
65 | doi = "10.18653/v1/D19-1600",
66 | pages = "5886--5891",
67 | }
68 | ```
69 | ### International Standard Language Resource Number (ISLRN)
70 | ISLRN: 013-662-947-043-2
71 |
72 | http://www.islrn.org/resources/resources_info/7952/
73 |
74 | ### Official HFL WeChat Account
75 | Follow Joint Laboratory of HIT and iFLYTEK Research (HFL) on WeChat.
76 |
77 | 
78 |
79 | ### Contact us
80 | Please submit an issue.
81 |
--------------------------------------------------------------------------------
/data/cmrc2018_evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Evaluation script for CMRC 2018
4 | version: v5
5 | Note:
6 |     v5: formatted output, add usage description
7 |     v4: fixed segmentation issues
8 | '''
9 | from __future__ import print_function
10 | from collections import Counter, OrderedDict
11 | import string
12 | import re
13 | import argparse
14 | import json
15 | import sys
16 | reload(sys)
17 | sys.setdefaultencoding('utf8')
18 | import nltk
19 | import pdb
20 |
21 | # segment mixed Chinese/English text: Chinese character by character, English via nltk.word_tokenize
22 | def mixed_segmentation(in_str, rm_punc=False):
23 | in_str = str(in_str).decode('utf-8').lower().strip()
24 | segs_out = []
25 | temp_str = ""
26 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
27 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
28 | '「','」','(',')','-','~','『','』']
29 | for char in in_str:
30 | if rm_punc and char in sp_char:
31 | continue
32 | if re.search(ur'[\u4e00-\u9fa5]', char) or char in sp_char:
33 | if temp_str != "":
34 | ss = nltk.word_tokenize(temp_str)
35 | segs_out.extend(ss)
36 | temp_str = ""
37 | segs_out.append(char)
38 | else:
39 | temp_str += char
40 |
41 | #handling last part
42 | if temp_str != "":
43 | ss = nltk.word_tokenize(temp_str)
44 | segs_out.extend(ss)
45 |
46 | return segs_out
47 |
48 |
49 | # remove punctuation
50 | def remove_punctuation(in_str):
51 | in_str = str(in_str).decode('utf-8').lower().strip()
52 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
53 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
54 | '「','」','(',')','-','~','『','』']
55 | out_segs = []
56 | for char in in_str:
57 | if char in sp_char:
58 | continue
59 | else:
60 | out_segs.append(char)
61 | return ''.join(out_segs)
62 |
63 |
64 | # find the longest common substring of two token sequences via dynamic programming
65 | def find_lcs(s1, s2):
66 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)]
67 | mmax = 0
68 | p = 0
69 | for i in range(len(s1)):
70 | for j in range(len(s2)):
71 | if s1[i] == s2[j]:
72 | m[i+1][j+1] = m[i][j]+1
73 | if m[i+1][j+1] > mmax:
74 | mmax=m[i+1][j+1]
75 | p=i+1
76 | return s1[p-mmax:p], mmax
77 |
78 | # compute average F1 and EM over all questions in the dataset
79 | def evaluate(ground_truth_file, prediction_file):
80 | f1 = 0
81 | em = 0
82 | total_count = 0
83 | skip_count = 0
84 | for instance in ground_truth_file:
85 | context_id = instance['context_id'].strip()
86 | context_text = instance['context_text'].strip()
87 | for qas in instance['qas']:
88 | total_count += 1
89 | query_id = qas['query_id'].strip()
90 | query_text = qas['query_text'].strip()
91 | answers = qas['answers']
92 |
93 | if query_id not in prediction_file:
94 | sys.stderr.write('Unanswered question: {}\n'.format(query_id))
95 | skip_count += 1
96 | continue
97 |
98 | prediction = str(prediction_file[query_id])
99 | f1 += calc_f1_score(answers, prediction)
100 | em += calc_em_score(answers, prediction)
101 |
102 | f1_score = 100.0 * f1 / total_count
103 | em_score = 100.0 * em / total_count
104 | return f1_score, em_score, total_count, skip_count
105 |
106 |
107 | def calc_f1_score(answers, prediction):
108 | f1_scores = []
109 | for ans in answers:
110 | ans_segs = mixed_segmentation(ans, rm_punc=True)
111 | prediction_segs = mixed_segmentation(prediction, rm_punc=True)
112 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
113 | if lcs_len == 0:
114 | f1_scores.append(0)
115 | continue
116 | precision = 1.0*lcs_len/len(prediction_segs)
117 | recall = 1.0*lcs_len/len(ans_segs)
118 | f1 = (2*precision*recall)/(precision+recall)
119 | f1_scores.append(f1)
120 | return max(f1_scores)
121 |
122 |
123 | def calc_em_score(answers, prediction):
124 | em = 0
125 | for ans in answers:
126 | ans_ = remove_punctuation(ans)
127 | prediction_ = remove_punctuation(prediction)
128 | if ans_ == prediction_:
129 | em = 1
130 | break
131 | return em
132 |
133 | if __name__ == '__main__':
134 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018')
135 | parser.add_argument('dataset_file', help='Official dataset file')
136 | parser.add_argument('prediction_file', help='Your prediction File')
137 | args = parser.parse_args()
138 | ground_truth_file = json.load(open(args.dataset_file, 'rb'))
139 | prediction_file = json.load(open(args.prediction_file, 'rb'))
140 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
141 | AVG = (EM+F1)*0.5
142 | output_result = OrderedDict()
143 | output_result['AVERAGE'] = '%.3f' % AVG
144 | output_result['F1'] = '%.3f' % F1
145 | output_result['EM'] = '%.3f' % EM
146 | output_result['TOTAL'] = TOTAL
147 | output_result['SKIP'] = SKIP
148 | output_result['FILE'] = args.prediction_file
149 | print(json.dumps(output_result))
150 |
--------------------------------------------------------------------------------
/baseline/README.md:
--------------------------------------------------------------------------------
1 | # BERT Baselines for CMRC 2018 and DRCD
2 | This repository contains the BERT baseline systems for CMRC 2018 and DRCD.
3 | These are two publicly available Chinese span-extraction machine reading comprehension datasets (similar to SQuAD).
4 |
5 | ## Content
6 |
7 | | Section | Description |
8 | |-|-|
9 | | [Datasets](#datasets) | CMRC 2018 and DRCD |
10 | | [Usage](#usage) | How to train and test |
11 | | [Baseline Results](#baseline-results) | BERT baseline results |
12 |
13 | ## Datasets
14 | CMRC 2018 (Simplified Chinese): [https://github.com/ymcui/cmrc2018](https://github.com/ymcui/cmrc2018)
15 |
16 | DRCD (Traditional Chinese): [https://github.com/DRCSolutionService/DRCD](https://github.com/DRCSolutionService/DRCD)
17 |
18 | You can download these datasets through the links above, or directly from the `data` directory. Note that we use the SQuAD-style CMRC 2018 datasets, which are available at [this link](https://github.com/ymcui/cmrc2018/tree/master/squad-style-data).
19 |
20 | For more Chinese machine reading comprehension datasets, please refer to: [https://github.com/ymcui/Chinese-RC-Datasets](https://github.com/ymcui/Chinese-RC-Datasets)
21 |
22 |
23 | ## Dependency Requirements
24 | There are no special dependency requirements except for `TensorFlow==1.12`. The code may also work with other versions of TensorFlow, but this has not been tested.
25 | 
26 | The code is based on the official BERT implementation of `run_squad.py`.
27 |
28 | Check: https://github.com/google-research/bert/blob/master/run_squad.py
29 |
30 | ## Usage
31 | ### Step 1: Download BERT weights (skip if you have them)
32 | - [Chinese (base)](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)
33 | - [Multi-lingual (base)](https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip)
34 |
35 | ### Step 2: Set the following environment variables
36 | - `$PATH_TO_BERT`: the path to BERT weights (TensorFlow version)
37 | - `$DATA_DIR`: the path to your dataset
38 | - `$MODEL_DIR`: output directory for your model
39 |
40 | ### Step 3: Training
41 | We use the following script for training, taking the CMRC 2018 dataset and multilingual BERT as an example.
42 | ```
43 | python run_cmrc2018_drcd_baseline.py \
44 | --vocab_file=${PATH_TO_BERT}/multi_cased_L-12_H-768_A-12/vocab.txt \
45 | --bert_config_file=${PATH_TO_BERT}/multi_cased_L-12_H-768_A-12/bert_config.json \
46 | --init_checkpoint=${PATH_TO_BERT}/multi_cased_L-12_H-768_A-12/bert_model.ckpt \
47 | --do_train=True \
48 | --train_file=${DATA_DIR}/cmrc2018_train.json \
49 | --do_predict=True \
50 | --predict_file=${DATA_DIR}/cmrc2018_dev.json \
51 | --train_batch_size=32 \
52 | --num_train_epochs=2 \
53 | --max_seq_length=512 \
54 | --doc_stride=128 \
55 | --learning_rate=3e-5 \
56 | --save_checkpoints_steps=1000 \
57 | --output_dir=${MODEL_DIR} \
58 | --do_lower_case=False \
59 | --use_tpu=False
60 | ```
61 |
62 | ### Step 4: Evaluation
63 | We use the official evaluation script for CMRC 2018 and DRCD.
64 | Note that, as DRCD does not provide an official evaluation script, we also use `cmrc2018_evaluate.py` for DRCD.
65 |
66 | `python cmrc2018_evaluate.py cmrc2018_dev.json predictions.json`
67 |
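The evaluation script can also be driven from Python. Below is a minimal sketch (file paths are placeholders) mirroring what the `__main__` block of `cmrc2018_evaluate.py` does:

```python
import json
from cmrc2018_evaluate import evaluate

# Load the gold dev set and your predictions ({question_id: answer_string});
# both paths below are placeholders.
ground_truth = json.load(open('cmrc2018_dev.json', 'rb'))
predictions = json.load(open('predictions.json', 'rb'))

# evaluate() returns average F1, average EM, and the question counts.
f1, em, total, skip = evaluate(ground_truth, predictions)
print('F1=%.3f EM=%.3f (total=%d, skipped=%d)' % (f1, em, total, skip))
```
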
68 |
69 | ## Baseline Results
70 | We provide both `BERT-Chinese` and `BERT-multilingual` baselines.
71 |
72 | **Note that each baseline was run 10 times, and we report the average of these runs for reliability (this does not apply to the hidden sets).**
73 |
74 | ### CMRC 2018
75 | | System | DEV-EM | DEV-F1 | TEST-EM | TEST-F1 | CHALLENGE-EM | CHALLENGE-F1 | Note |
76 | | :------ | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
77 | | BERT (Multi-lingual) | 63.6 | 84.0 | 68.4 | 86.5 | 18.5 | 43.0 | - |
78 | | BERT (Chinese) | 63.5 | 83.6 | 67.5 | 85.6 | 18.4 | 42.1 | - |
79 | | P-Reader (single model) | 59.894 | 81.499 | 65.189 | 84.386 | 15.079 | 39.583 | - |
80 | | GM-Reader (ensemble) | 58.931 | 80.069 | 64.045 | 83.046 | 15.675 | 37.315 | - |
81 | | MCA-Reader (ensemble) | 66.698 | 85.538 | 71.175 | 88.090 | 15.476 | 37.104 | - |
82 | | Z-Reader (single model) | 79.776 | 92.696 | 74.178 | 88.145 | 13.889 | 37.422 | - |
83 | | [SRC->DS(±) (Yang et al., 2019)](https://arxiv.org/abs/1904.06652) | 49.2 | 65.4 | - | - | - | - | - |
84 |
85 | > Leaderboard: https://hfl-rc.github.io/cmrc2018/open_challenge/
86 | > Note that some of the previous submissions also used the development set for training.
87 |
88 |
89 | ### DRCD
90 | | System | DEV-EM | DEV-F1 | TEST-EM | TEST-F1 | Note |
91 | | :------ | :-----: | :-----: | :-----: | :-----: | :-----: |
92 | | BERT (Multi-lingual) | 83.0 | 89.7 | 83.2 | 89.8 | - |
93 | | BERT (Chinese) | 82.2 | 89.3 | 81.6 | 88.8 | - |
94 | | [SRC + DS(±) (Yang et al., 2019)](https://arxiv.org/abs/1904.06652) | 55.4 | 67.7 | - | - | - |
95 | | r-net (single model) | - | - | 29.1 | 44.4 | - |
96 |
97 | ## Contact
98 | For help or issues, please submit a GitHub issue.
99 |
--------------------------------------------------------------------------------
/squad-style-data/cmrc2018_evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Evaluation script for CMRC 2018
4 | version: v5 - special
5 | Note:
6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets
7 | v5: formatted output, add usage description
8 | v4: fixed segmentation issues
9 | '''
10 | from __future__ import print_function
11 | from collections import Counter, OrderedDict
12 | import string
13 | import re
14 | import argparse
15 | import json
16 | import sys
17 | reload(sys)
18 | sys.setdefaultencoding('utf8')
19 | import nltk
20 | import pdb
21 |
22 | # segment mixed Chinese/English text: Chinese character by character, English via nltk.word_tokenize
23 | def mixed_segmentation(in_str, rm_punc=False):
24 | in_str = str(in_str).decode('utf-8').lower().strip()
25 | segs_out = []
26 | temp_str = ""
27 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
28 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
29 | '「','」','(',')','-','~','『','』']
30 | for char in in_str:
31 | if rm_punc and char in sp_char:
32 | continue
33 | if re.search(ur'[\u4e00-\u9fa5]', char) or char in sp_char:
34 | if temp_str != "":
35 | ss = nltk.word_tokenize(temp_str)
36 | segs_out.extend(ss)
37 | temp_str = ""
38 | segs_out.append(char)
39 | else:
40 | temp_str += char
41 |
42 | #handling last part
43 | if temp_str != "":
44 | ss = nltk.word_tokenize(temp_str)
45 | segs_out.extend(ss)
46 |
47 | return segs_out
48 |
49 |
50 | # remove punctuation
51 | def remove_punctuation(in_str):
52 | in_str = str(in_str).decode('utf-8').lower().strip()
53 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
54 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
55 | '「','」','(',')','-','~','『','』']
56 | out_segs = []
57 | for char in in_str:
58 | if char in sp_char:
59 | continue
60 | else:
61 | out_segs.append(char)
62 | return ''.join(out_segs)
63 |
64 |
65 | # find the longest common substring of two token sequences via dynamic programming
66 | def find_lcs(s1, s2):
67 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)]
68 | mmax = 0
69 | p = 0
70 | for i in range(len(s1)):
71 | for j in range(len(s2)):
72 | if s1[i] == s2[j]:
73 | m[i+1][j+1] = m[i][j]+1
74 | if m[i+1][j+1] > mmax:
75 | mmax=m[i+1][j+1]
76 | p=i+1
77 | return s1[p-mmax:p], mmax
78 |
79 | # compute average F1 and EM over all questions in the dataset
80 | def evaluate(ground_truth_file, prediction_file):
81 | f1 = 0
82 | em = 0
83 | total_count = 0
84 | skip_count = 0
85 | for instance in ground_truth_file["data"]:
86 | #context_id = instance['context_id'].strip()
87 | #context_text = instance['context_text'].strip()
88 | for para in instance["paragraphs"]:
89 | for qas in para['qas']:
90 | total_count += 1
91 | query_id = qas['id'].strip()
92 | query_text = qas['question'].strip()
93 | answers = [x["text"] for x in qas['answers']]
94 |
95 | if query_id not in prediction_file:
96 | sys.stderr.write('Unanswered question: {}\n'.format(query_id))
97 | skip_count += 1
98 | continue
99 |
100 | prediction = str(prediction_file[query_id]).decode('utf-8')
101 | f1 += calc_f1_score(answers, prediction)
102 | em += calc_em_score(answers, prediction)
103 |
104 | f1_score = 100.0 * f1 / total_count
105 | em_score = 100.0 * em / total_count
106 | return f1_score, em_score, total_count, skip_count
107 |
108 |
109 | def calc_f1_score(answers, prediction):
110 | f1_scores = []
111 | for ans in answers:
112 | ans_segs = mixed_segmentation(ans, rm_punc=True)
113 | prediction_segs = mixed_segmentation(prediction, rm_punc=True)
114 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
115 | if lcs_len == 0:
116 | f1_scores.append(0)
117 | continue
118 | precision = 1.0*lcs_len/len(prediction_segs)
119 | recall = 1.0*lcs_len/len(ans_segs)
120 | f1 = (2*precision*recall)/(precision+recall)
121 | f1_scores.append(f1)
122 | return max(f1_scores)
123 |
124 |
125 | def calc_em_score(answers, prediction):
126 | em = 0
127 | for ans in answers:
128 | ans_ = remove_punctuation(ans)
129 | prediction_ = remove_punctuation(prediction)
130 | if ans_ == prediction_:
131 | em = 1
132 | break
133 | return em
134 |
135 | if __name__ == '__main__':
136 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018')
137 | parser.add_argument('dataset_file', help='Official dataset file')
138 | parser.add_argument('prediction_file', help='Your prediction File')
139 | args = parser.parse_args()
140 | ground_truth_file = json.load(open(args.dataset_file, 'rb'))
141 | prediction_file = json.load(open(args.prediction_file, 'rb'))
142 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
143 | AVG = (EM+F1)*0.5
144 | output_result = OrderedDict()
145 | output_result['AVERAGE'] = '%.3f' % AVG
146 | output_result['F1'] = '%.3f' % F1
147 | output_result['EM'] = '%.3f' % EM
148 | output_result['TOTAL'] = TOTAL
149 | output_result['SKIP'] = SKIP
150 | output_result['FILE'] = args.prediction_file
151 | print(json.dumps(output_result))
152 |
153 |
--------------------------------------------------------------------------------
/baseline/cmrc2018_evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Evaluation script for CMRC 2018
4 | Note:
5 |     v6: compatible with both the original and SQuAD-style CMRC 2018 datasets; supports Python 3.
6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets
7 | v5: formatted output, add usage description
8 | v4: fixed segmentation issues
9 | '''
10 | from __future__ import print_function
11 | from collections import Counter, OrderedDict
12 | import string
13 | import re
14 | import argparse
15 | import json
16 | import sys
17 | import nltk
18 | import pdb
19 |
20 | # segment mixed Chinese/English text: Chinese character by character, English via nltk.word_tokenize
21 | def mixed_segmentation(in_str, rm_punc=False):
22 | in_str = str(in_str).lower().strip()
23 | segs_out = []
24 | temp_str = ""
25 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
26 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
27 | '「','」','(',')','-','~','『','』']
28 | for char in in_str:
29 | if rm_punc and char in sp_char:
30 | continue
31 | if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
32 | if temp_str != "":
33 | ss = nltk.word_tokenize(temp_str)
34 | segs_out.extend(ss)
35 | temp_str = ""
36 | segs_out.append(char)
37 | else:
38 | temp_str += char
39 |
40 | #handling last part
41 | if temp_str != "":
42 | ss = nltk.word_tokenize(temp_str)
43 | segs_out.extend(ss)
44 |
45 | return segs_out
46 |
47 | # remove punctuation
48 | def remove_punctuation(in_str):
49 | in_str = str(in_str).lower().strip()
50 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
51 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
52 | '「','」','(',')','-','~','『','』']
53 | out_segs = []
54 | for char in in_str:
55 | if char in sp_char:
56 | continue
57 | else:
58 | out_segs.append(char)
59 | return ''.join(out_segs)
60 |
61 |
62 | # find the longest common substring of two token sequences via dynamic programming
63 | def find_lcs(s1, s2):
64 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)]
65 | mmax = 0
66 | p = 0
67 | for i in range(len(s1)):
68 | for j in range(len(s2)):
69 | if s1[i] == s2[j]:
70 | m[i+1][j+1] = m[i][j]+1
71 | if m[i+1][j+1] > mmax:
72 | mmax=m[i+1][j+1]
73 | p=i+1
74 | return s1[p-mmax:p], mmax
75 |
76 | # compute average F1 and EM over all questions in the dataset
77 | def evaluate(ground_truth_file, prediction_file):
78 | f1 = 0
79 | em = 0
80 | total_count = 0
81 | skip_count = 0
82 |
83 | data_list = ground_truth_file['data'] if 'data' in ground_truth_file else ground_truth_file
84 | for instance in data_list:
85 | para_list = instance['paragraphs'] if 'paragraphs' in instance else [instance]
86 | for para in para_list:
87 | for qas in para['qas']:
88 | total_count += 1
89 | query_id = qas['id'] if 'id' in qas else qas['query_id']
90 | answers = [x['text'] if isinstance(x, dict) else x for x in qas['answers']]
91 |
92 | if query_id not in prediction_file:
93 | sys.stderr.write('Unanswered question: {}\n'.format(query_id))
94 | skip_count += 1
95 | continue
96 |
97 | prediction = str(prediction_file[query_id])
98 | f1 += calc_f1_score(answers, prediction)
99 | em += calc_em_score(answers, prediction)
100 |
101 | f1_score = 100.0 * f1 / total_count
102 | em_score = 100.0 * em / total_count
103 | return f1_score, em_score, total_count, skip_count
104 |
105 |
106 | def calc_f1_score(answers, prediction):
107 | f1_scores = []
108 | for ans in answers:
109 | ans_segs = mixed_segmentation(ans, rm_punc=True)
110 | prediction_segs = mixed_segmentation(prediction, rm_punc=True)
111 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
112 | if lcs_len == 0:
113 | f1_scores.append(0)
114 | continue
115 | precision = 1.0*lcs_len/len(prediction_segs)
116 | recall = 1.0*lcs_len/len(ans_segs)
117 | f1 = (2*precision*recall)/(precision+recall)
118 | f1_scores.append(f1)
119 | return max(f1_scores)
120 |
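# Worked example: for the gold answer "北京大学" and the prediction "北京",
# mixed_segmentation yields ['北', '京', '大', '学'] and ['北', '京'], and the
# longest common substring has length 2; precision = 2/2 = 1.0,
# recall = 2/4 = 0.5, so F1 = 2*1.0*0.5/1.5 ≈ 0.667.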
121 |
122 | def calc_em_score(answers, prediction):
123 | em = 0
124 | for ans in answers:
125 | ans_ = remove_punctuation(ans)
126 | prediction_ = remove_punctuation(prediction)
127 | if ans_ == prediction_:
128 | em = 1
129 | break
130 | return em
131 |
132 | if __name__ == '__main__':
133 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018')
134 | parser.add_argument('dataset_file', help='Official dataset file')
135 | parser.add_argument('prediction_file', help='Your prediction File')
136 | args = parser.parse_args()
137 | ground_truth_file = json.load(open(args.dataset_file, 'rb'))
138 | prediction_file = json.load(open(args.prediction_file, 'rb'))
139 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
140 | AVG = (EM+F1)*0.5
141 | output_result = OrderedDict()
142 | output_result['AVERAGE'] = '%.3f' % AVG
143 | output_result['F1'] = '%.3f' % F1
144 | output_result['EM'] = '%.3f' % EM
145 | output_result['TOTAL'] = TOTAL
146 | output_result['SKIP'] = SKIP
147 | output_result['FILE'] = args.prediction_file
148 | print(json.dumps(output_result))
149 |
150 |
--------------------------------------------------------------------------------
/baseline/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 | beta_2=0.999,
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 | # Normally the global step update is done inside of `apply_gradients`.
80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
81 | # a different optimizer, you should probably take this line out.
82 | new_global_step = global_step + 1
83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
84 | return train_op
85 |
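# A minimal usage sketch (hyperparameter values are illustrative): given a
# scalar `loss` tensor,
#   train_op = create_optimizer(loss, init_lr=3e-5, num_train_steps=1000,
#                               num_warmup_steps=100, use_tpu=False)
# builds a training op whose learning rate warms up linearly over the first
# 100 steps, then decays linearly to zero at step 1000.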
86 |
87 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
88 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
89 |
90 | def __init__(self,
91 | learning_rate,
92 | weight_decay_rate=0.0,
93 | beta_1=0.9,
94 | beta_2=0.999,
95 | epsilon=1e-6,
96 | exclude_from_weight_decay=None,
97 | name="AdamWeightDecayOptimizer"):
98 |     """Constructs an AdamWeightDecayOptimizer."""
99 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
100 |
101 | self.learning_rate = learning_rate
102 | self.weight_decay_rate = weight_decay_rate
103 | self.beta_1 = beta_1
104 | self.beta_2 = beta_2
105 | self.epsilon = epsilon
106 | self.exclude_from_weight_decay = exclude_from_weight_decay
107 |
108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
109 | """See base class."""
110 | assignments = []
111 | for (grad, param) in grads_and_vars:
112 | if grad is None or param is None:
113 | continue
114 |
115 | param_name = self._get_variable_name(param.name)
116 |
117 | m = tf.get_variable(
118 | name=param_name + "/adam_m",
119 | shape=param.shape.as_list(),
120 | dtype=tf.float32,
121 | trainable=False,
122 | initializer=tf.zeros_initializer())
123 | v = tf.get_variable(
124 | name=param_name + "/adam_v",
125 | shape=param.shape.as_list(),
126 | dtype=tf.float32,
127 | trainable=False,
128 | initializer=tf.zeros_initializer())
129 |
130 | # Standard Adam update.
131 | next_m = (
132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
133 | next_v = (
134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
135 | tf.square(grad)))
136 |
137 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
138 |
139 | # Just adding the square of the weights to the loss function is *not*
140 | # the correct way of using L2 regularization/weight decay with Adam,
141 | # since that will interact with the m and v parameters in strange ways.
142 | #
143 |       # Instead we want to decay the weights in a manner that doesn't interact
144 | # with the m/v parameters. This is equivalent to adding the square
145 | # of the weights to the loss with plain (non-momentum) SGD.
146 | if self._do_use_weight_decay(param_name):
147 | update += self.weight_decay_rate * param
148 |
149 | update_with_lr = self.learning_rate * update
150 |
151 | next_param = param - update_with_lr
152 |
153 | assignments.extend(
154 | [param.assign(next_param),
155 | m.assign(next_m),
156 | v.assign(next_v)])
157 | return tf.group(*assignments, name=name)
158 |
159 | def _do_use_weight_decay(self, param_name):
160 | """Whether to use L2 weight decay for `param_name`."""
161 | if not self.weight_decay_rate:
162 | return False
163 | if self.exclude_from_weight_decay:
164 | for r in self.exclude_from_weight_decay:
165 | if re.search(r, param_name) is not None:
166 | return False
167 | return True
168 |
169 | def _get_variable_name(self, param_name):
170 | """Get the variable name from the tensor name."""
171 | m = re.match("^(.*):\\d+$", param_name)
172 | if m is not None:
173 | param_name = m.group(1)
174 | return param_name
175 |
--------------------------------------------------------------------------------
/baseline/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 |         "how the model was pre-trained. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 |   """Runs end-to-end tokenization."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
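# A minimal usage sketch (the vocab path is a placeholder): FullTokenizer
# chains BasicTokenizer and WordpieceTokenizer end to end.
#   tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)
#   tokens = tokenizer.tokenize(u"CMRC 2018 is a span-extraction dataset")
#   ids = tokenizer.convert_tokens_to_ids(tokens)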
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 | # and generally don't have any Chinese data in them (there are Chinese
205 | # characters in the vocabulary because Wikipedia does have some Chinese
206 | # words in the English Wikipedia.).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 |     # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 |   """Runs WordPiece tokenization."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 |         already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `chars` is a whitespace character."""
364 |   # \t, \n, and \r are technically control characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `chars` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat.startswith("C"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `chars` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | Attribution-ShareAlike 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-ShareAlike 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-ShareAlike 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. BY-SA Compatible License means a license listed at
88 | creativecommons.org/compatiblelicenses, approved by Creative
89 | Commons as essentially the equivalent of this Public License.
90 |
91 | d. Copyright and Similar Rights means copyright and/or similar rights
92 | closely related to copyright including, without limitation,
93 | performance, broadcast, sound recording, and Sui Generis Database
94 | Rights, without regard to how the rights are labeled or
95 | categorized. For purposes of this Public License, the rights
96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
97 | Rights.
98 |
99 | e. Effective Technological Measures means those measures that, in the
100 | absence of proper authority, may not be circumvented under laws
101 | fulfilling obligations under Article 11 of the WIPO Copyright
102 | Treaty adopted on December 20, 1996, and/or similar international
103 | agreements.
104 |
105 | f. Exceptions and Limitations means fair use, fair dealing, and/or
106 | any other exception or limitation to Copyright and Similar Rights
107 | that applies to Your use of the Licensed Material.
108 |
109 | g. License Elements means the license attributes listed in the name
110 | of a Creative Commons Public License. The License Elements of this
111 | Public License are Attribution and ShareAlike.
112 |
113 | h. Licensed Material means the artistic or literary work, database,
114 | or other material to which the Licensor applied this Public
115 | License.
116 |
117 | i. Licensed Rights means the rights granted to You subject to the
118 | terms and conditions of this Public License, which are limited to
119 | all Copyright and Similar Rights that apply to Your use of the
120 | Licensed Material and that the Licensor has authority to license.
121 |
122 | j. Licensor means the individual(s) or entity(ies) granting rights
123 | under this Public License.
124 |
125 | k. Share means to provide material to the public by any means or
126 | process that requires permission under the Licensed Rights, such
127 | as reproduction, public display, public performance, distribution,
128 | dissemination, communication, or importation, and to make material
129 | available to the public including in ways that members of the
130 | public may access the material from a place and at a time
131 | individually chosen by them.
132 |
133 | l. Sui Generis Database Rights means rights other than copyright
134 | resulting from Directive 96/9/EC of the European Parliament and of
135 | the Council of 11 March 1996 on the legal protection of databases,
136 | as amended and/or succeeded, as well as other essentially
137 | equivalent rights anywhere in the world.
138 |
139 | m. You means the individual or entity exercising the Licensed Rights
140 | under this Public License. Your has a corresponding meaning.
141 |
142 |
143 | Section 2 -- Scope.
144 |
145 | a. License grant.
146 |
147 | 1. Subject to the terms and conditions of this Public License,
148 | the Licensor hereby grants You a worldwide, royalty-free,
149 | non-sublicensable, non-exclusive, irrevocable license to
150 | exercise the Licensed Rights in the Licensed Material to:
151 |
152 | a. reproduce and Share the Licensed Material, in whole or
153 | in part; and
154 |
155 | b. produce, reproduce, and Share Adapted Material.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. Additional offer from the Licensor -- Adapted Material.
186 | Every recipient of Adapted Material from You
187 | automatically receives an offer from the Licensor to
188 | exercise the Licensed Rights in the Adapted Material
189 | under the conditions of the Adapter's License You apply.
190 |
191 | c. No downstream restrictions. You may not offer or impose
192 | any additional or different terms or conditions on, or
193 | apply any Effective Technological Measures to, the
194 | Licensed Material if doing so restricts exercise of the
195 | Licensed Rights by any recipient of the Licensed
196 | Material.
197 |
198 | 6. No endorsement. Nothing in this Public License constitutes or
199 | may be construed as permission to assert or imply that You
200 | are, or that Your use of the Licensed Material is, connected
201 | with, or sponsored, endorsed, or granted official status by,
202 | the Licensor or others designated to receive attribution as
203 | provided in Section 3(a)(1)(A)(i).
204 |
205 | b. Other rights.
206 |
207 | 1. Moral rights, such as the right of integrity, are not
208 | licensed under this Public License, nor are publicity,
209 | privacy, and/or other similar personality rights; however, to
210 | the extent possible, the Licensor waives and/or agrees not to
211 | assert any such rights held by the Licensor to the limited
212 | extent necessary to allow You to exercise the Licensed
213 | Rights, but not otherwise.
214 |
215 | 2. Patent and trademark rights are not licensed under this
216 | Public License.
217 |
218 | 3. To the extent possible, the Licensor waives any right to
219 | collect royalties from You for the exercise of the Licensed
220 | Rights, whether directly or through a collecting society
221 | under any voluntary or waivable statutory or compulsory
222 | licensing scheme. In all other cases the Licensor expressly
223 | reserves any right to collect such royalties.
224 |
225 |
226 | Section 3 -- License Conditions.
227 |
228 | Your exercise of the Licensed Rights is expressly made subject to the
229 | following conditions.
230 |
231 | a. Attribution.
232 |
233 | 1. If You Share the Licensed Material (including in modified
234 | form), You must:
235 |
236 | a. retain the following if it is supplied by the Licensor
237 | with the Licensed Material:
238 |
239 | i. identification of the creator(s) of the Licensed
240 | Material and any others designated to receive
241 | attribution, in any reasonable manner requested by
242 | the Licensor (including by pseudonym if
243 | designated);
244 |
245 | ii. a copyright notice;
246 |
247 | iii. a notice that refers to this Public License;
248 |
249 | iv. a notice that refers to the disclaimer of
250 | warranties;
251 |
252 | v. a URI or hyperlink to the Licensed Material to the
253 | extent reasonably practicable;
254 |
255 | b. indicate if You modified the Licensed Material and
256 | retain an indication of any previous modifications; and
257 |
258 | c. indicate the Licensed Material is licensed under this
259 | Public License, and include the text of, or the URI or
260 | hyperlink to, this Public License.
261 |
262 | 2. You may satisfy the conditions in Section 3(a)(1) in any
263 | reasonable manner based on the medium, means, and context in
264 | which You Share the Licensed Material. For example, it may be
265 | reasonable to satisfy the conditions by providing a URI or
266 | hyperlink to a resource that includes the required
267 | information.
268 |
269 | 3. If requested by the Licensor, You must remove any of the
270 | information required by Section 3(a)(1)(A) to the extent
271 | reasonably practicable.
272 |
273 | b. ShareAlike.
274 |
275 | In addition to the conditions in Section 3(a), if You Share
276 | Adapted Material You produce, the following conditions also apply.
277 |
278 | 1. The Adapter's License You apply must be a Creative Commons
279 | license with the same License Elements, this version or
280 | later, or a BY-SA Compatible License.
281 |
282 | 2. You must include the text of, or the URI or hyperlink to, the
283 | Adapter's License You apply. You may satisfy this condition
284 | in any reasonable manner based on the medium, means, and
285 | context in which You Share Adapted Material.
286 |
287 | 3. You may not offer or impose any additional or different terms
288 | or conditions on, or apply any Effective Technological
289 | Measures to, Adapted Material that restrict exercise of the
290 | rights granted under the Adapter's License You apply.
291 |
292 |
293 | Section 4 -- Sui Generis Database Rights.
294 |
295 | Where the Licensed Rights include Sui Generis Database Rights that
296 | apply to Your use of the Licensed Material:
297 |
298 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
299 | to extract, reuse, reproduce, and Share all or a substantial
300 | portion of the contents of the database;
301 |
302 | b. if You include all or a substantial portion of the database
303 | contents in a database in which You have Sui Generis Database
304 | Rights, then the database in which You have Sui Generis Database
305 | Rights (but not its individual contents) is Adapted Material,
306 |
307 | including for purposes of Section 3(b); and
308 | c. You must comply with the conditions in Section 3(a) if You Share
309 | all or a substantial portion of the contents of the database.
310 |
311 | For the avoidance of doubt, this Section 4 supplements and does not
312 | replace Your obligations under this Public License where the Licensed
313 | Rights include other Copyright and Similar Rights.
314 |
315 |
316 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
317 |
318 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
319 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
320 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
321 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
322 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
323 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
324 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
325 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
326 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
327 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
328 |
329 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
330 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
331 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
332 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
333 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
334 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
335 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
336 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
337 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
338 |
339 | c. The disclaimer of warranties and limitation of liability provided
340 | above shall be interpreted in a manner that, to the extent
341 | possible, most closely approximates an absolute disclaimer and
342 | waiver of all liability.
343 |
344 |
345 | Section 6 -- Term and Termination.
346 |
347 | a. This Public License applies for the term of the Copyright and
348 | Similar Rights licensed here. However, if You fail to comply with
349 | this Public License, then Your rights under this Public License
350 | terminate automatically.
351 |
352 | b. Where Your right to use the Licensed Material has terminated under
353 | Section 6(a), it reinstates:
354 |
355 | 1. automatically as of the date the violation is cured, provided
356 | it is cured within 30 days of Your discovery of the
357 | violation; or
358 |
359 | 2. upon express reinstatement by the Licensor.
360 |
361 | For the avoidance of doubt, this Section 6(b) does not affect any
362 | right the Licensor may have to seek remedies for Your violations
363 | of this Public License.
364 |
365 | c. For the avoidance of doubt, the Licensor may also offer the
366 | Licensed Material under separate terms or conditions or stop
367 | distributing the Licensed Material at any time; however, doing so
368 | will not terminate this Public License.
369 |
370 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
371 | License.
372 |
373 |
374 | Section 7 -- Other Terms and Conditions.
375 |
376 | a. The Licensor shall not be bound by any additional or different
377 | terms or conditions communicated by You unless expressly agreed.
378 |
379 | b. Any arrangements, understandings, or agreements regarding the
380 | Licensed Material not stated herein are separate from and
381 | independent of the terms and conditions of this Public License.
382 |
383 |
384 | Section 8 -- Interpretation.
385 |
386 | a. For the avoidance of doubt, this Public License does not, and
387 | shall not be interpreted to, reduce, limit, restrict, or impose
388 | conditions on any use of the Licensed Material that could lawfully
389 | be made without permission under this Public License.
390 |
391 | b. To the extent possible, if any provision of this Public License is
392 | deemed unenforceable, it shall be automatically reformed to the
393 | minimum extent necessary to make it enforceable. If the provision
394 | cannot be reformed, it shall be severed from this Public License
395 | without affecting the enforceability of the remaining terms and
396 | conditions.
397 |
398 | c. No term or condition of this Public License will be waived and no
399 | failure to comply consented to unless expressly agreed to by the
400 | Licensor.
401 |
402 | d. Nothing in this Public License constitutes or may be interpreted
403 | as a limitation upon, or waiver of, any privileges and immunities
404 | that apply to the Licensor or You, including from the legal
405 | processes of any jurisdiction or authority.
406 |
407 |
408 | =======================================================================
409 |
410 | Creative Commons is not a party to its public
411 | licenses. Notwithstanding, Creative Commons may elect to apply one of
412 | its public licenses to material it publishes and in those instances
413 | will be considered the "Licensor." The text of the Creative Commons
414 | public licenses is dedicated to the public domain under the CC0 Public
415 | Domain Dedication. Except for the limited purpose of indicating that
416 | material is shared under a Creative Commons public license or as
417 | otherwise permitted by the Creative Commons policies published at
418 | creativecommons.org/policies, Creative Commons does not authorize the
419 | use of the trademark "Creative Commons" or any other trademark or logo
420 | of Creative Commons without its prior written consent including,
421 | without limitation, in connection with any unauthorized modifications
422 | to any of its public licenses or any other arrangements,
423 | understandings, or agreements concerning use of licensed material. For
424 | the avoidance of doubt, this paragraph does not form part of the
425 | public licenses.
426 |
427 | Creative Commons may be contacted at creativecommons.org.
428 |
--------------------------------------------------------------------------------
/baseline/modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The main BERT model and related functions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import copy
23 | import json
24 | import math
25 | import re
26 | import six
27 | import tensorflow as tf
28 |
29 |
30 | class BertConfig(object):
31 | """Configuration for `BertModel`."""
32 |
33 | def __init__(self,
34 | vocab_size,
35 | hidden_size=768,
36 | num_hidden_layers=12,
37 | num_attention_heads=12,
38 | intermediate_size=3072,
39 | hidden_act="gelu",
40 | hidden_dropout_prob=0.1,
41 | attention_probs_dropout_prob=0.1,
42 | max_position_embeddings=512,
43 | type_vocab_size=16,
44 | initializer_range=0.02):
45 | """Constructs BertConfig.
46 |
47 | Args:
48 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
49 | hidden_size: Size of the encoder layers and the pooler layer.
50 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
51 | num_attention_heads: Number of attention heads for each attention layer in
52 | the Transformer encoder.
53 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
54 | layer in the Transformer encoder.
55 | hidden_act: The non-linear activation function (function or string) in the
56 | encoder and pooler.
57 | hidden_dropout_prob: The dropout probability for all fully connected
58 | layers in the embeddings, encoder, and pooler.
59 | attention_probs_dropout_prob: The dropout ratio for the attention
60 | probabilities.
61 | max_position_embeddings: The maximum sequence length that this model might
62 | ever be used with. Typically set this to something large just in case
63 | (e.g., 512 or 1024 or 2048).
64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
65 | `BertModel`.
66 | initializer_range: The stdev of the truncated_normal_initializer for
67 | initializing all weight matrices.
68 | """
69 | self.vocab_size = vocab_size
70 | self.hidden_size = hidden_size
71 | self.num_hidden_layers = num_hidden_layers
72 | self.num_attention_heads = num_attention_heads
73 | self.hidden_act = hidden_act
74 | self.intermediate_size = intermediate_size
75 | self.hidden_dropout_prob = hidden_dropout_prob
76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
77 | self.max_position_embeddings = max_position_embeddings
78 | self.type_vocab_size = type_vocab_size
79 | self.initializer_range = initializer_range
80 |
81 | @classmethod
82 | def from_dict(cls, json_object):
83 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
84 | config = BertConfig(vocab_size=None)
85 | for (key, value) in six.iteritems(json_object):
86 | config.__dict__[key] = value
87 | return config
88 |
89 | @classmethod
90 | def from_json_file(cls, json_file):
91 | """Constructs a `BertConfig` from a json file of parameters."""
92 | with tf.gfile.GFile(json_file, "r") as reader:
93 | text = reader.read()
94 | return cls.from_dict(json.loads(text))
95 |
96 | def to_dict(self):
97 | """Serializes this instance to a Python dictionary."""
98 | output = copy.deepcopy(self.__dict__)
99 | return output
100 |
101 | def to_json_string(self):
102 | """Serializes this instance to a JSON string."""
103 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
104 |
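# A minimal round-trip sketch (the config path below is a placeholder, not a
# file shipped with this repo):
#
#   config = BertConfig.from_json_file("bert_config.json")
#   json_text = config.to_json_string()
#   restored = BertConfig.from_dict(json.loads(json_text))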
105 |
106 | class BertModel(object):
107 | """BERT model ("Bidirectional Encoder Representations from Transformers").
108 |
109 | Example usage:
110 |
111 | ```python
112 | # Already been converted into WordPiece token ids
113 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
114 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
115 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
116 |
117 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
118 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
119 |
120 | model = modeling.BertModel(config=config, is_training=True,
121 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
122 |
123 | label_embeddings = tf.get_variable(...)
124 | pooled_output = model.get_pooled_output()
125 | logits = tf.matmul(pooled_output, label_embeddings)
126 | ...
127 | ```
128 | """
129 |
130 | def __init__(self,
131 | config,
132 | is_training,
133 | input_ids,
134 | input_mask=None,
135 | token_type_ids=None,
136 | use_one_hot_embeddings=True,
137 | scope=None):
138 | """Constructor for BertModel.
139 |
140 | Args:
141 | config: `BertConfig` instance.
142 | is_training: bool. true for training model, false for eval model. Controls
143 | whether dropout will be applied.
144 | input_ids: int32 Tensor of shape [batch_size, seq_length].
145 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
146 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
147 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
148 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU,
149 | it is much faster if this is True, on the CPU or GPU, it is faster if
150 | this is False.
151 | scope: (optional) variable scope. Defaults to "bert".
152 |
153 | Raises:
154 | ValueError: The config is invalid or one of the input tensor shapes
155 | is invalid.
156 | """
157 | config = copy.deepcopy(config)
158 | if not is_training:
159 | config.hidden_dropout_prob = 0.0
160 | config.attention_probs_dropout_prob = 0.0
161 |
162 | input_shape = get_shape_list(input_ids, expected_rank=2)
163 | batch_size = input_shape[0]
164 | seq_length = input_shape[1]
165 |
166 | if input_mask is None:
167 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
168 |
169 | if token_type_ids is None:
170 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
171 |
172 | with tf.variable_scope(scope, default_name="bert"):
173 | with tf.variable_scope("embeddings"):
174 | # Perform embedding lookup on the word ids.
175 | (self.embedding_output, self.embedding_table) = embedding_lookup(
176 | input_ids=input_ids,
177 | vocab_size=config.vocab_size,
178 | embedding_size=config.hidden_size,
179 | initializer_range=config.initializer_range,
180 | word_embedding_name="word_embeddings",
181 | use_one_hot_embeddings=use_one_hot_embeddings)
182 |
183 | # Add positional embeddings and token type embeddings, then layer
184 | # normalize and perform dropout.
185 | self.embedding_output = embedding_postprocessor(
186 | input_tensor=self.embedding_output,
187 | use_token_type=True,
188 | token_type_ids=token_type_ids,
189 | token_type_vocab_size=config.type_vocab_size,
190 | token_type_embedding_name="token_type_embeddings",
191 | use_position_embeddings=True,
192 | position_embedding_name="position_embeddings",
193 | initializer_range=config.initializer_range,
194 | max_position_embeddings=config.max_position_embeddings,
195 | dropout_prob=config.hidden_dropout_prob)
196 |
197 | with tf.variable_scope("encoder"):
198 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
199 | # mask of shape [batch_size, seq_length, seq_length] which is used
200 | # for the attention scores.
201 | attention_mask = create_attention_mask_from_input_mask(
202 | input_ids, input_mask)
203 |
204 | # Run the stacked transformer.
205 | # `sequence_output` shape = [batch_size, seq_length, hidden_size].
206 | self.all_encoder_layers = transformer_model(
207 | input_tensor=self.embedding_output,
208 | attention_mask=attention_mask,
209 | hidden_size=config.hidden_size,
210 | num_hidden_layers=config.num_hidden_layers,
211 | num_attention_heads=config.num_attention_heads,
212 | intermediate_size=config.intermediate_size,
213 | intermediate_act_fn=get_activation(config.hidden_act),
214 | hidden_dropout_prob=config.hidden_dropout_prob,
215 | attention_probs_dropout_prob=config.attention_probs_dropout_prob,
216 | initializer_range=config.initializer_range,
217 | do_return_all_layers=True)
218 |
219 | self.sequence_output = self.all_encoder_layers[-1]
220 | # The "pooler" converts the encoded sequence tensor of shape
221 | # [batch_size, seq_length, hidden_size] to a tensor of shape
222 | # [batch_size, hidden_size]. This is necessary for segment-level
223 | # (or segment-pair-level) classification tasks where we need a fixed
224 | # dimensional representation of the segment.
225 | with tf.variable_scope("pooler"):
226 | # We "pool" the model by simply taking the hidden state corresponding
227 |       # to the first token. We assume that this has been pre-trained.
228 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
229 | self.pooled_output = tf.layers.dense(
230 | first_token_tensor,
231 | config.hidden_size,
232 | activation=tf.tanh,
233 | kernel_initializer=create_initializer(config.initializer_range))
234 |
235 | def get_pooled_output(self):
236 | return self.pooled_output
237 |
238 | def get_sequence_output(self):
239 | """Gets final hidden layer of encoder.
240 |
241 | Returns:
242 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
243 |       to the final hidden layer of the transformer encoder.
244 | """
245 | return self.sequence_output
246 |
247 | def get_all_encoder_layers(self):
248 | return self.all_encoder_layers
249 |
250 | def get_embedding_output(self):
251 | """Gets output of the embedding lookup (i.e., input to the transformer).
252 |
253 | Returns:
254 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
255 | to the output of the embedding layer, after summing the word
256 | embeddings with the positional embeddings and the token type embeddings,
257 | then performing layer normalization. This is the input to the transformer.
258 | """
259 | return self.embedding_output
260 |
261 | def get_embedding_table(self):
262 | return self.embedding_table
263 |
264 |
265 | def gelu(input_tensor):
266 | """Gaussian Error Linear Unit.
267 |
268 | This is a smoother version of the RELU.
269 | Original paper: https://arxiv.org/abs/1606.08415
270 |
271 | Args:
272 | input_tensor: float Tensor to perform activation.
273 |
274 | Returns:
275 | `input_tensor` with the GELU activation applied.
276 | """
277 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
278 | return input_tensor * cdf
279 |
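# Quick sanity check: gelu(x) == x * Phi(x), with Phi the standard normal CDF,
# so gelu(0.0) = 0.0 and gelu(1.0) ~= 0.8413. Note this is the exact erf-based
# form, not the tanh approximation used by some other BERT implementations.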
280 |
281 | def get_activation(activation_string):
282 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
283 |
284 | Args:
285 | activation_string: String name of the activation function.
286 |
287 | Returns:
288 | A Python function corresponding to the activation function. If
289 | `activation_string` is None, empty, or "linear", this will return None.
290 | If `activation_string` is not a string, it will return `activation_string`.
291 |
292 | Raises:
293 | ValueError: The `activation_string` does not correspond to a known
294 | activation.
295 | """
296 |
297 |   # We assume that anything that's not a string is already an activation
298 | # function, so we just return it.
299 | if not isinstance(activation_string, six.string_types):
300 | return activation_string
301 |
302 | if not activation_string:
303 | return None
304 |
305 | act = activation_string.lower()
306 | if act == "linear":
307 | return None
308 | elif act == "relu":
309 | return tf.nn.relu
310 | elif act == "gelu":
311 | return gelu
312 | elif act == "tanh":
313 | return tf.tanh
314 | else:
315 | raise ValueError("Unsupported activation: %s" % act)
316 |
317 |
318 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
319 | """Compute the union of the current variables and checkpoint variables."""
320 | assignment_map = {}
321 | initialized_variable_names = {}
322 |
323 | name_to_variable = collections.OrderedDict()
324 | for var in tvars:
325 | name = var.name
326 | m = re.match("^(.*):\\d+$", name)
327 | if m is not None:
328 | name = m.group(1)
329 | name_to_variable[name] = var
330 |
331 | init_vars = tf.train.list_variables(init_checkpoint)
332 |
333 | assignment_map = collections.OrderedDict()
334 | for x in init_vars:
335 | (name, var) = (x[0], x[1])
336 | if name not in name_to_variable:
337 | continue
338 | assignment_map[name] = name
339 | initialized_variable_names[name] = 1
340 | initialized_variable_names[name + ":0"] = 1
341 |
342 | return (assignment_map, initialized_variable_names)
343 |
344 |
345 | def dropout(input_tensor, dropout_prob):
346 | """Perform dropout.
347 |
348 | Args:
349 | input_tensor: float Tensor.
350 | dropout_prob: Python float. The probability of dropping out a value (NOT of
351 | *keeping* a dimension as in `tf.nn.dropout`).
352 |
353 | Returns:
354 | A version of `input_tensor` with dropout applied.
355 | """
356 | if dropout_prob is None or dropout_prob == 0.0:
357 | return input_tensor
358 |
359 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
360 | return output
361 |
362 |
363 | def layer_norm(input_tensor, name=None):
364 | """Run layer normalization on the last dimension of the tensor."""
365 | return tf.contrib.layers.layer_norm(
366 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
367 |
368 |
369 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
370 | """Runs layer normalization followed by dropout."""
371 | output_tensor = layer_norm(input_tensor, name)
372 | output_tensor = dropout(output_tensor, dropout_prob)
373 | return output_tensor
374 |
375 |
376 | def create_initializer(initializer_range=0.02):
377 | """Creates a `truncated_normal_initializer` with the given range."""
378 | return tf.truncated_normal_initializer(stddev=initializer_range)
379 |
380 |
381 | def embedding_lookup(input_ids,
382 | vocab_size,
383 | embedding_size=128,
384 | initializer_range=0.02,
385 | word_embedding_name="word_embeddings",
386 | use_one_hot_embeddings=False):
387 | """Looks up words embeddings for id tensor.
388 |
389 | Args:
390 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
391 | ids.
392 | vocab_size: int. Size of the embedding vocabulary.
393 | embedding_size: int. Width of the word embeddings.
394 | initializer_range: float. Embedding initialization range.
395 | word_embedding_name: string. Name of the embedding table.
396 | use_one_hot_embeddings: bool. If True, use one-hot method for word
397 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better
398 | for TPUs.
399 |
400 | Returns:
401 | float Tensor of shape [batch_size, seq_length, embedding_size].
402 | """
403 | # This function assumes that the input is of shape [batch_size, seq_length,
404 | # num_inputs].
405 | #
406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we
407 | # reshape to [batch_size, seq_length, 1].
408 | if input_ids.shape.ndims == 2:
409 | input_ids = tf.expand_dims(input_ids, axis=[-1])
410 |
411 | embedding_table = tf.get_variable(
412 | name=word_embedding_name,
413 | shape=[vocab_size, embedding_size],
414 | initializer=create_initializer(initializer_range))
415 |
416 | if use_one_hot_embeddings:
417 | flat_input_ids = tf.reshape(input_ids, [-1])
418 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
419 | output = tf.matmul(one_hot_input_ids, embedding_table)
420 | else:
421 | output = tf.nn.embedding_lookup(embedding_table, input_ids)
422 |
423 | input_shape = get_shape_list(input_ids)
424 |
425 | output = tf.reshape(output,
426 | input_shape[0:-1] + [input_shape[-1] * embedding_size])
427 | return (output, embedding_table)
428 |
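# Shape sketch (illustrative numbers): input_ids of shape [32, 128] with
# embedding_size=768 is expanded to [32, 128, 1], looked up to
# [32, 128, 1, 768], and reshaped back to [32, 128, 768].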
429 |
430 | def embedding_postprocessor(input_tensor,
431 | use_token_type=False,
432 | token_type_ids=None,
433 | token_type_vocab_size=16,
434 | token_type_embedding_name="token_type_embeddings",
435 | use_position_embeddings=True,
436 | position_embedding_name="position_embeddings",
437 | initializer_range=0.02,
438 | max_position_embeddings=512,
439 | dropout_prob=0.1):
440 | """Performs various post-processing on a word embedding tensor.
441 |
442 | Args:
443 | input_tensor: float Tensor of shape [batch_size, seq_length,
444 | embedding_size].
445 | use_token_type: bool. Whether to add embeddings for `token_type_ids`.
446 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
447 | Must be specified if `use_token_type` is True.
448 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
449 | token_type_embedding_name: string. The name of the embedding table variable
450 | for token type ids.
451 | use_position_embeddings: bool. Whether to add position embeddings for the
452 | position of each token in the sequence.
453 | position_embedding_name: string. The name of the embedding table variable
454 | for positional embeddings.
455 | initializer_range: float. Range of the weight initialization.
456 | max_position_embeddings: int. Maximum sequence length that might ever be
457 | used with this model. This can be longer than the sequence length of
458 | input_tensor, but cannot be shorter.
459 | dropout_prob: float. Dropout probability applied to the final output tensor.
460 |
461 | Returns:
462 | float tensor with same shape as `input_tensor`.
463 |
464 | Raises:
465 | ValueError: One of the tensor shapes or input values is invalid.
466 | """
467 | input_shape = get_shape_list(input_tensor, expected_rank=3)
468 | batch_size = input_shape[0]
469 | seq_length = input_shape[1]
470 | width = input_shape[2]
471 |
472 | output = input_tensor
473 |
474 | if use_token_type:
475 | if token_type_ids is None:
476 |       raise ValueError("`token_type_ids` must be specified if "
477 |                        "`use_token_type` is True.")
478 | token_type_table = tf.get_variable(
479 | name=token_type_embedding_name,
480 | shape=[token_type_vocab_size, width],
481 | initializer=create_initializer(initializer_range))
482 | # This vocab will be small so we always do one-hot here, since it is always
483 | # faster for a small vocabulary.
484 | flat_token_type_ids = tf.reshape(token_type_ids, [-1])
485 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
486 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
487 | token_type_embeddings = tf.reshape(token_type_embeddings,
488 | [batch_size, seq_length, width])
489 | output += token_type_embeddings
490 |
491 | if use_position_embeddings:
492 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
493 | with tf.control_dependencies([assert_op]):
494 | full_position_embeddings = tf.get_variable(
495 | name=position_embedding_name,
496 | shape=[max_position_embeddings, width],
497 | initializer=create_initializer(initializer_range))
498 | # Since the position embedding table is a learned variable, we create it
499 | # using a (long) sequence length `max_position_embeddings`. The actual
500 | # sequence length might be shorter than this, for faster training of
501 | # tasks that do not have long sequences.
502 | #
503 | # So `full_position_embeddings` is effectively an embedding table
504 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
505 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
506 | # perform a slice.
507 | position_embeddings = tf.slice(full_position_embeddings, [0, 0],
508 | [seq_length, -1])
509 | num_dims = len(output.shape.as_list())
510 |
511 | # Only the last two dimensions are relevant (`seq_length` and `width`), so
512 | # we broadcast among the first dimensions, which is typically just
513 | # the batch size.
514 | position_broadcast_shape = []
515 | for _ in range(num_dims - 2):
516 | position_broadcast_shape.append(1)
517 | position_broadcast_shape.extend([seq_length, width])
518 | position_embeddings = tf.reshape(position_embeddings,
519 | position_broadcast_shape)
520 | output += position_embeddings
521 |
522 | output = layer_norm_and_dropout(output, dropout_prob)
523 | return output
524 |
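# Broadcast sketch: for `output` of shape [batch_size, seq_length, width],
# num_dims is 3, so position_broadcast_shape is [1, seq_length, width] and the
# same position embeddings are added to every example in the batch.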
525 |
526 | def create_attention_mask_from_input_mask(from_tensor, to_mask):
527 | """Create 3D attention mask from a 2D tensor mask.
528 |
529 | Args:
530 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
531 | to_mask: int32 Tensor of shape [batch_size, to_seq_length].
532 |
533 | Returns:
534 | float Tensor of shape [batch_size, from_seq_length, to_seq_length].
535 | """
536 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
537 | batch_size = from_shape[0]
538 | from_seq_length = from_shape[1]
539 |
540 | to_shape = get_shape_list(to_mask, expected_rank=2)
541 | to_seq_length = to_shape[1]
542 |
543 | to_mask = tf.cast(
544 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
545 |
546 | # We don't assume that `from_tensor` is a mask (although it could be). We
547 |   # don't actually care if we attend *from* padding tokens (only *to* padding
548 |   # tokens), so we create a tensor of all ones.
549 | #
550 | # `broadcast_ones` = [batch_size, from_seq_length, 1]
551 | broadcast_ones = tf.ones(
552 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32)
553 |
554 | # Here we broadcast along two dimensions to create the mask.
555 | mask = broadcast_ones * to_mask
556 |
557 | return mask
558 |
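# Worked example: to_mask = [[1, 1, 0]] (last token is padding) broadcasts to
# mask[0] = [[1., 1., 0.], [1., 1., 0.], [1., 1., 0.]], i.e. every query
# position may attend to tokens 0-1 but never to the padded position.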
559 |
560 | def attention_layer(from_tensor,
561 | to_tensor,
562 | attention_mask=None,
563 | num_attention_heads=1,
564 | size_per_head=512,
565 | query_act=None,
566 | key_act=None,
567 | value_act=None,
568 | attention_probs_dropout_prob=0.0,
569 | initializer_range=0.02,
570 | do_return_2d_tensor=False,
571 | batch_size=None,
572 | from_seq_length=None,
573 | to_seq_length=None):
574 | """Performs multi-headed attention from `from_tensor` to `to_tensor`.
575 |
576 |   This is an implementation of multi-headed attention based on "Attention
577 |   Is All You Need". If `from_tensor` and `to_tensor` are the same, then
578 |   this is self-attention. Each timestep in `from_tensor` attends to the
579 |   corresponding sequence in `to_tensor`, and returns a fixed-width vector.
580 |
581 | This function first projects `from_tensor` into a "query" tensor and
582 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list
583 | of tensors of length `num_attention_heads`, where each tensor is of shape
584 | [batch_size, seq_length, size_per_head].
585 |
586 | Then, the query and key tensors are dot-producted and scaled. These are
587 | softmaxed to obtain attention probabilities. The value tensors are then
588 | interpolated by these probabilities, then concatenated back to a single
589 | tensor and returned.
590 |
591 |   In practice, the multi-headed attention is done with transposes and
592 | reshapes rather than actual separate tensors.
593 |
594 | Args:
595 | from_tensor: float Tensor of shape [batch_size, from_seq_length,
596 | from_width].
597 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
598 | attention_mask: (optional) int32 Tensor of shape [batch_size,
599 | from_seq_length, to_seq_length]. The values should be 1 or 0. The
600 | attention scores will effectively be set to -infinity for any positions in
601 | the mask that are 0, and will be unchanged for positions that are 1.
602 | num_attention_heads: int. Number of attention heads.
603 | size_per_head: int. Size of each attention head.
604 | query_act: (optional) Activation function for the query transform.
605 | key_act: (optional) Activation function for the key transform.
606 | value_act: (optional) Activation function for the value transform.
607 | attention_probs_dropout_prob: (optional) float. Dropout probability of the
608 | attention probabilities.
609 | initializer_range: float. Range of the weight initializer.
610 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
611 | * from_seq_length, num_attention_heads * size_per_head]. If False, the
612 | output will be of shape [batch_size, from_seq_length, num_attention_heads
613 | * size_per_head].
614 | batch_size: (Optional) int. If the input is 2D, this might be the batch size
615 | of the 3D version of the `from_tensor` and `to_tensor`.
616 | from_seq_length: (Optional) If the input is 2D, this might be the seq length
617 | of the 3D version of the `from_tensor`.
618 | to_seq_length: (Optional) If the input is 2D, this might be the seq length
619 | of the 3D version of the `to_tensor`.
620 |
621 | Returns:
622 | float Tensor of shape [batch_size, from_seq_length,
623 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
624 | true, this will be of shape [batch_size * from_seq_length,
625 | num_attention_heads * size_per_head]).
626 |
627 | Raises:
628 | ValueError: Any of the arguments or tensor shapes are invalid.
629 | """
630 |
631 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
632 | seq_length, width):
633 | output_tensor = tf.reshape(
634 | input_tensor, [batch_size, seq_length, num_attention_heads, width])
635 |
636 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
637 | return output_tensor
638 |
639 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
640 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
641 |
642 | if len(from_shape) != len(to_shape):
643 | raise ValueError(
644 | "The rank of `from_tensor` must match the rank of `to_tensor`.")
645 |
646 | if len(from_shape) == 3:
647 | batch_size = from_shape[0]
648 | from_seq_length = from_shape[1]
649 | to_seq_length = to_shape[1]
650 | elif len(from_shape) == 2:
651 | if (batch_size is None or from_seq_length is None or to_seq_length is None):
652 | raise ValueError(
653 | "When passing in rank 2 tensors to attention_layer, the values "
654 | "for `batch_size`, `from_seq_length`, and `to_seq_length` "
655 | "must all be specified.")
656 |
657 | # Scalar dimensions referenced here:
658 | # B = batch size (number of sequences)
659 | # F = `from_tensor` sequence length
660 | # T = `to_tensor` sequence length
661 | # N = `num_attention_heads`
662 | # H = `size_per_head`
663 |
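  # Shape trace for a typical BERT-base setting (B=8, F=T=128, N=12, H=64):
  # the query/key/value projections below are [B*F, N*H] = [1024, 768];
  # transpose_for_scores turns them into [B, N, F, H] = [8, 12, 128, 64];
  # the resulting score matrix is [B, N, F, T] = [8, 12, 128, 128].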
664 | from_tensor_2d = reshape_to_matrix(from_tensor)
665 | to_tensor_2d = reshape_to_matrix(to_tensor)
666 |
667 | # `query_layer` = [B*F, N*H]
668 | query_layer = tf.layers.dense(
669 | from_tensor_2d,
670 | num_attention_heads * size_per_head,
671 | activation=query_act,
672 | name="query",
673 | kernel_initializer=create_initializer(initializer_range))
674 |
675 | # `key_layer` = [B*T, N*H]
676 | key_layer = tf.layers.dense(
677 | to_tensor_2d,
678 | num_attention_heads * size_per_head,
679 | activation=key_act,
680 | name="key",
681 | kernel_initializer=create_initializer(initializer_range))
682 |
683 | # `value_layer` = [B*T, N*H]
684 | value_layer = tf.layers.dense(
685 | to_tensor_2d,
686 | num_attention_heads * size_per_head,
687 | activation=value_act,
688 | name="value",
689 | kernel_initializer=create_initializer(initializer_range))
690 |
691 | # `query_layer` = [B, N, F, H]
692 | query_layer = transpose_for_scores(query_layer, batch_size,
693 | num_attention_heads, from_seq_length,
694 | size_per_head)
695 |
696 | # `key_layer` = [B, N, T, H]
697 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
698 | to_seq_length, size_per_head)
699 |
700 | # Take the dot product between "query" and "key" to get the raw
701 | # attention scores.
702 | # `attention_scores` = [B, N, F, T]
703 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
704 | attention_scores = tf.multiply(attention_scores,
705 | 1.0 / math.sqrt(float(size_per_head)))
706 |
707 | if attention_mask is not None:
708 | # `attention_mask` = [B, 1, F, T]
709 | attention_mask = tf.expand_dims(attention_mask, axis=[1])
710 |
711 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
712 | # masked positions, this operation will create a tensor which is 0.0 for
713 | # positions we want to attend and -10000.0 for masked positions.
714 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
715 |
716 | # Since we are adding it to the raw scores before the softmax, this is
717 | # effectively the same as removing these entirely.
718 | attention_scores += adder
719 |
720 | # Normalize the attention scores to probabilities.
721 | # `attention_probs` = [B, N, F, T]
722 | attention_probs = tf.nn.softmax(attention_scores)
723 |
724 | # This is actually dropping out entire tokens to attend to, which might
725 | # seem a bit unusual, but is taken from the original Transformer paper.
726 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
727 |
728 | # `value_layer` = [B, T, N, H]
729 | value_layer = tf.reshape(
730 | value_layer,
731 | [batch_size, to_seq_length, num_attention_heads, size_per_head])
732 |
733 | # `value_layer` = [B, N, T, H]
734 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
735 |
736 | # `context_layer` = [B, N, F, H]
737 | context_layer = tf.matmul(attention_probs, value_layer)
738 |
739 | # `context_layer` = [B, F, N, H]
740 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
741 |
742 | if do_return_2d_tensor:
743 | # `context_layer` = [B*F, N*H]
744 | context_layer = tf.reshape(
745 | context_layer,
746 | [batch_size * from_seq_length, num_attention_heads * size_per_head])
747 | else:
748 | # `context_layer` = [B, F, N*H]
749 | context_layer = tf.reshape(
750 | context_layer,
751 | [batch_size, from_seq_length, num_attention_heads * size_per_head])
752 |
753 | return context_layer
754 |
755 |
756 | def transformer_model(input_tensor,
757 | attention_mask=None,
758 | hidden_size=768,
759 | num_hidden_layers=12,
760 | num_attention_heads=12,
761 | intermediate_size=3072,
762 | intermediate_act_fn=gelu,
763 | hidden_dropout_prob=0.1,
764 | attention_probs_dropout_prob=0.1,
765 | initializer_range=0.02,
766 | do_return_all_layers=False):
767 | """Multi-headed, multi-layer Transformer from "Attention is All You Need".
768 |
769 | This is almost an exact implementation of the original Transformer encoder.
770 |
771 | See the original paper:
772 | https://arxiv.org/abs/1706.03762
773 |
774 | Also see:
775 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
776 |
777 | Args:
778 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
779 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
780 | seq_length], with 1 for positions that can be attended to and 0 in
781 | positions that should not be.
782 | hidden_size: int. Hidden size of the Transformer.
783 | num_hidden_layers: int. Number of layers (blocks) in the Transformer.
784 | num_attention_heads: int. Number of attention heads in the Transformer.
785 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed
786 | forward) layer.
787 | intermediate_act_fn: function. The non-linear activation function to apply
788 | to the output of the intermediate/feed-forward layer.
789 | hidden_dropout_prob: float. Dropout probability for the hidden layers.
790 | attention_probs_dropout_prob: float. Dropout probability of the attention
791 | probabilities.
792 | initializer_range: float. Range of the initializer (stddev of truncated
793 | normal).
794 | do_return_all_layers: Whether to also return all layers or just the final
795 | layer.
796 |
797 | Returns:
798 | float Tensor of shape [batch_size, seq_length, hidden_size], the final
799 | hidden layer of the Transformer.
800 |
801 | Raises:
802 | ValueError: A Tensor shape or parameter is invalid.
803 | """
804 | if hidden_size % num_attention_heads != 0:
805 | raise ValueError(
806 | "The hidden size (%d) is not a multiple of the number of attention "
807 | "heads (%d)" % (hidden_size, num_attention_heads))
808 |
809 | attention_head_size = int(hidden_size / num_attention_heads)
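  # For example, BERT-base uses hidden_size=768 with 12 heads, giving
  # attention_head_size = 768 / 12 = 64.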
810 | input_shape = get_shape_list(input_tensor, expected_rank=3)
811 | batch_size = input_shape[0]
812 | seq_length = input_shape[1]
813 | input_width = input_shape[2]
814 |
815 |   # The Transformer uses residual (sum) connections on every layer, so the
816 |   # input width needs to match the hidden size.
817 | if input_width != hidden_size:
818 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
819 | (input_width, hidden_size))
820 |
821 | # We keep the representation as a 2D tensor to avoid re-shaping it back and
822 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
823 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
824 | # help the optimizer.
825 | prev_output = reshape_to_matrix(input_tensor)
826 |
827 | all_layer_outputs = []
828 | for layer_idx in range(num_hidden_layers):
829 | with tf.variable_scope("layer_%d" % layer_idx):
830 | layer_input = prev_output
831 |
832 | with tf.variable_scope("attention"):
833 | attention_heads = []
834 | with tf.variable_scope("self"):
835 | attention_head = attention_layer(
836 | from_tensor=layer_input,
837 | to_tensor=layer_input,
838 | attention_mask=attention_mask,
839 | num_attention_heads=num_attention_heads,
840 | size_per_head=attention_head_size,
841 | attention_probs_dropout_prob=attention_probs_dropout_prob,
842 | initializer_range=initializer_range,
843 | do_return_2d_tensor=True,
844 | batch_size=batch_size,
845 | from_seq_length=seq_length,
846 | to_seq_length=seq_length)
847 | attention_heads.append(attention_head)
848 |
849 | attention_output = None
850 | if len(attention_heads) == 1:
851 | attention_output = attention_heads[0]
852 | else:
853 | # In the case where we have other sequences, we just concatenate
854 | # them to the self-attention head before the projection.
855 | attention_output = tf.concat(attention_heads, axis=-1)
856 |
857 | # Run a linear projection of `hidden_size` then add a residual
858 | # with `layer_input`.
859 | with tf.variable_scope("output"):
860 | attention_output = tf.layers.dense(
861 | attention_output,
862 | hidden_size,
863 | kernel_initializer=create_initializer(initializer_range))
864 | attention_output = dropout(attention_output, hidden_dropout_prob)
865 | attention_output = layer_norm(attention_output + layer_input)
866 |
867 | # The activation is only applied to the "intermediate" hidden layer.
868 | with tf.variable_scope("intermediate"):
869 | intermediate_output = tf.layers.dense(
870 | attention_output,
871 | intermediate_size,
872 | activation=intermediate_act_fn,
873 | kernel_initializer=create_initializer(initializer_range))
874 |
875 | # Down-project back to `hidden_size` then add the residual.
876 | with tf.variable_scope("output"):
877 | layer_output = tf.layers.dense(
878 | intermediate_output,
879 | hidden_size,
880 | kernel_initializer=create_initializer(initializer_range))
881 | layer_output = dropout(layer_output, hidden_dropout_prob)
882 | layer_output = layer_norm(layer_output + attention_output)
883 | prev_output = layer_output
884 | all_layer_outputs.append(layer_output)
885 |
886 | if do_return_all_layers:
887 | final_outputs = []
888 | for layer_output in all_layer_outputs:
889 | final_output = reshape_from_matrix(layer_output, input_shape)
890 | final_outputs.append(final_output)
891 | return final_outputs
892 | else:
893 | final_output = reshape_from_matrix(prev_output, input_shape)
894 | return final_output
895 |
896 |
897 | def get_shape_list(tensor, expected_rank=None, name=None):
898 | """Returns a list of the shape of tensor, preferring static dimensions.
899 |
900 | Args:
901 | tensor: A tf.Tensor object to find the shape of.
902 | expected_rank: (optional) int. The expected rank of `tensor`. If this is
903 |       specified and the `tensor` has a different rank, an exception will be
904 | thrown.
905 | name: Optional name of the tensor for the error message.
906 |
907 | Returns:
908 | A list of dimensions of the shape of tensor. All static dimensions will
909 | be returned as python integers, and dynamic dimensions will be returned
910 | as tf.Tensor scalars.
911 | """
912 | if name is None:
913 | name = tensor.name
914 |
915 | if expected_rank is not None:
916 | assert_rank(tensor, expected_rank, name)
917 |
918 | shape = tensor.shape.as_list()
919 |
920 | non_static_indexes = []
921 | for (index, dim) in enumerate(shape):
922 | if dim is None:
923 | non_static_indexes.append(index)
924 |
925 | if not non_static_indexes:
926 | return shape
927 |
928 | dyn_shape = tf.shape(tensor)
929 | for index in non_static_indexes:
930 | shape[index] = dyn_shape[index]
931 | return shape
932 |
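# Example: for a tensor of static shape [None, 128], this returns a two-element
# list whose first entry is a scalar tf.Tensor (the dynamic batch size) and
# whose second entry is the plain Python int 128.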
933 |
934 | def reshape_to_matrix(input_tensor):
935 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
936 | ndims = input_tensor.shape.ndims
937 | if ndims < 2:
938 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
939 | (input_tensor.shape))
940 | if ndims == 2:
941 | return input_tensor
942 |
943 | width = input_tensor.shape[-1]
944 | output_tensor = tf.reshape(input_tensor, [-1, width])
945 | return output_tensor
946 |
947 |
948 | def reshape_from_matrix(output_tensor, orig_shape_list):
949 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
950 | if len(orig_shape_list) == 2:
951 | return output_tensor
952 |
953 | output_shape = get_shape_list(output_tensor)
954 |
955 | orig_dims = orig_shape_list[0:-1]
956 | width = output_shape[-1]
957 |
958 | return tf.reshape(output_tensor, orig_dims + [width])
959 |
960 |
961 | def assert_rank(tensor, expected_rank, name=None):
962 | """Raises an exception if the tensor rank is not of the expected rank.
963 |
964 | Args:
965 | tensor: A tf.Tensor to check the rank of.
966 | expected_rank: Python integer or list of integers, expected rank.
967 | name: Optional name of the tensor for the error message.
968 |
969 | Raises:
970 | ValueError: If the expected shape doesn't match the actual shape.
971 | """
972 | if name is None:
973 | name = tensor.name
974 |
975 | expected_rank_dict = {}
976 | if isinstance(expected_rank, six.integer_types):
977 | expected_rank_dict[expected_rank] = True
978 | else:
979 | for x in expected_rank:
980 | expected_rank_dict[x] = True
981 |
982 | actual_rank = tensor.shape.ndims
983 | if actual_rank not in expected_rank_dict:
984 | scope_name = tf.get_variable_scope().name
985 | raise ValueError(
986 | "For the tensor `%s` in scope `%s`, the actual rank "
987 | "`%d` (shape = %s) is not equal to the expected rank `%s`" %
988 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
989 |
--------------------------------------------------------------------------------
/baseline/run_cmrc2018_drcd_baseline.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Run BERT on SQuAD 1.1 and SQuAD 2.0."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import json
23 | import math
24 | import os
25 | import random
26 | import modeling
27 | import optimization
28 | import tokenization
29 | import six
30 | import tensorflow as tf
31 | import numpy
32 | import pdb
33 |
34 | flags = tf.flags
35 | FLAGS = flags.FLAGS
36 |
37 | ## Required parameters
38 | flags.DEFINE_string(
39 | "bert_config_file", None,
40 | "The config json file corresponding to the pre-trained BERT model. "
41 | "This specifies the model architecture.")
42 |
43 | flags.DEFINE_string("vocab_file", None,
44 | "The vocabulary file that the BERT model was trained on.")
45 |
46 | flags.DEFINE_string(
47 | "output_dir", None,
48 | "The output directory where the model checkpoints will be written.")
49 |
50 | ## Other parameters
51 | flags.DEFINE_string("train_file", None,
52 | "SQuAD json for training. E.g., train-v1.1.json")
53 |
54 | flags.DEFINE_string(
55 | "predict_file", None,
56 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
57 |
58 | flags.DEFINE_string(
59 | "eval_file", None,
60 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
61 |
62 | flags.DEFINE_string(
63 | "init_checkpoint", None,
64 | "Initial checkpoint (usually from a pre-trained BERT model).")
65 |
66 | flags.DEFINE_bool(
67 | "do_lower_case", False,
68 | "Whether to lower case the input text. Should be True for uncased "
69 | "models and False for cased models.")
70 |
71 | flags.DEFINE_integer(
72 | "max_seq_length", 384,
73 | "The maximum total input sequence length after WordPiece tokenization. "
74 | "Sequences longer than this will be truncated, and sequences shorter "
75 | "than this will be padded.")
76 |
77 | flags.DEFINE_integer(
78 | "doc_stride", 128,
79 | "When splitting up a long document into chunks, how much stride to "
80 | "take between chunks.")
81 |
82 | flags.DEFINE_integer(
83 | "max_query_length", 64,
84 | "The maximum number of tokens for the question. Questions longer than "
85 | "this will be truncated to this length.")
86 |
87 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
88 |
89 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
90 |
91 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
92 |
93 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
94 |
95 | flags.DEFINE_integer("predict_batch_size", 8,
96 | "Total batch size for predictions.")
97 |
98 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
99 |
100 | flags.DEFINE_float("num_train_epochs", 3.0,
101 | "Total number of training epochs to perform.")
102 |
103 | flags.DEFINE_float(
104 | "warmup_proportion", 0.1,
105 | "Proportion of training to perform linear learning rate warmup for. "
106 | "E.g., 0.1 = 10% of training.")
107 |
108 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
109 | "How often to save the model checkpoint.")
110 |
111 | flags.DEFINE_integer("iterations_per_loop", 1000,
112 | "How many steps to make in each estimator call.")
113 |
114 | flags.DEFINE_integer(
115 | "n_best_size", 20,
116 | "The total number of n-best predictions to generate in the "
117 | "nbest_predictions.json output file.")
118 |
119 | flags.DEFINE_integer(
120 | "max_answer_length", 30,
121 | "The maximum length of an answer that can be generated. This is needed "
122 | "because the start and end predictions are not conditioned on one another.")
123 |
124 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
125 |
126 | tf.flags.DEFINE_string(
127 | "tpu_name", None,
128 | "The Cloud TPU to use for training. This should be either the name "
129 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
130 | "url.")
131 |
132 | tf.flags.DEFINE_string(
133 | "tpu_zone", None,
134 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
135 | "specified, we will attempt to automatically detect the GCE project from "
136 | "metadata.")
137 |
138 | tf.flags.DEFINE_string(
139 | "gcp_project", None,
140 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
141 | "specified, we will attempt to automatically detect the GCE project from "
142 | "metadata.")
143 |
144 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
145 |
146 | flags.DEFINE_integer(
147 | "num_tpu_cores", 8,
148 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
149 |
150 | flags.DEFINE_bool(
151 | "verbose_logging", False,
152 | "If true, all of the warnings related to data processing will be printed. "
153 | "A number of warnings are expected for a normal SQuAD evaluation.")
154 |
155 | flags.DEFINE_integer("rand_seed", 12345, "set random seed")
156 |
157 | # set random seed (i don't know whether it works or not)
158 | numpy.random.seed(int(FLAGS.rand_seed))
159 | tf.set_random_seed(int(FLAGS.rand_seed))
160 |
161 | #
162 | class SquadExample(object):
163 |   """A single training/test example for span-extraction question answering.
164 |
165 | For examples without an answer, the start and end position are -1.
166 | """
167 |
168 | def __init__(self,
169 | qas_id,
170 | question_text,
171 | doc_tokens,
172 | orig_answer_text=None,
173 | start_position=None,
174 | end_position=None):
175 | self.qas_id = qas_id
176 | self.question_text = question_text
177 | self.doc_tokens = doc_tokens
178 | self.orig_answer_text = orig_answer_text
179 | self.start_position = start_position
180 | self.end_position = end_position
181 |
182 | def __str__(self):
183 | return self.__repr__()
184 |
185 | def __repr__(self):
186 | s = ""
187 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
188 | s += ", question_text: %s" % (
189 | tokenization.printable_text(self.question_text))
190 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
191 |     if self.start_position is not None:
192 |       s += ", start_position: %d" % (self.start_position)
193 |     if self.end_position is not None:
194 |       s += ", end_position: %d" % (self.end_position)
195 | return s
196 |
197 |
198 | class InputFeatures(object):
199 | """A single set of features of data."""
200 |
201 | def __init__(self,
202 | unique_id,
203 | example_index,
204 | doc_span_index,
205 | tokens,
206 | token_to_orig_map,
207 | token_is_max_context,
208 | input_ids,
209 | input_mask,
210 | segment_ids,
211 | input_span_mask,
212 | start_position=None,
213 | end_position=None):
214 | self.unique_id = unique_id
215 | self.example_index = example_index
216 | self.doc_span_index = doc_span_index
217 | self.tokens = tokens
218 | self.token_to_orig_map = token_to_orig_map
219 | self.token_is_max_context = token_is_max_context
220 | self.input_ids = input_ids
221 | self.input_mask = input_mask
222 | self.segment_ids = segment_ids
223 | self.input_span_mask = input_span_mask
224 | self.start_position = start_position
225 | self.end_position = end_position
226 |
227 | #
228 | def customize_tokenizer(text, do_lower_case=False):
229 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
230 | temp_x = ""
231 | text = tokenization.convert_to_unicode(text)
232 | for c in text:
233 | if tokenizer._is_chinese_char(ord(c)) or tokenization._is_punctuation(c) or tokenization._is_whitespace(c) or tokenization._is_control(c):
234 | temp_x += " " + c + " "
235 | else:
236 | temp_x += c
237 | if do_lower_case:
238 | temp_x = temp_x.lower()
239 | return temp_x.split()
240 |
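    | # A minimal usage sketch (illustrative, with a hypothetical input): every CJK
    | # character, punctuation mark, whitespace and control character is padded
    | # with spaces, so mixed Chinese/English text splits into single characters
    | # plus contiguous digit/Latin chunks:
    | #
    | #   customize_tokenizer(u"他于1963年出生。")
    | #   # -> [u"他", u"于", u"1963", u"年", u"出", u"生", u"。"]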
241 | #
242 | class ChineseFullTokenizer(object):
243 |   """Runs end-to-end tokenization."""
244 |
245 | def __init__(self, vocab_file, do_lower_case=False):
246 | self.vocab = tokenization.load_vocab(vocab_file)
247 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
248 | self.wordpiece_tokenizer = tokenization.WordpieceTokenizer(vocab=self.vocab)
249 | self.do_lower_case = do_lower_case
250 | def tokenize(self, text):
251 | split_tokens = []
252 | for token in customize_tokenizer(text, do_lower_case=self.do_lower_case):
253 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
254 | split_tokens.append(sub_token)
255 |
256 | return split_tokens
257 |
258 | def convert_tokens_to_ids(self, tokens):
259 | return tokenization.convert_by_vocab(self.vocab, tokens)
260 |
261 | def convert_ids_to_tokens(self, ids):
262 | return tokenization.convert_by_vocab(self.inv_vocab, ids)
263 |
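    | # Usage sketch ("vocab.txt" is a placeholder path): the tokenizer composes
    | # the character-level splitter above with BERT WordPiece, and token <-> id
    | # conversion round-trips:
    | #
    | #   tokenizer = ChineseFullTokenizer(vocab_file="vocab.txt")
    | #   tokens = tokenizer.tokenize(u"明朝建立于1368年")
    | #   ids = tokenizer.convert_tokens_to_ids(tokens)
    | #   assert tokenizer.convert_ids_to_tokens(ids) == tokens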
264 | #
265 | def read_squad_examples(input_file, is_training):
266 | """Read a SQuAD json file into a list of SquadExample."""
267 | with tf.gfile.Open(input_file, "r") as reader:
268 | input_data = json.load(reader)["data"]
269 |
270 | #
271 | examples = []
272 | for entry in input_data:
273 | for paragraph in entry["paragraphs"]:
274 | paragraph_text = paragraph["context"]
275 | raw_doc_tokens = customize_tokenizer(paragraph_text, do_lower_case=FLAGS.do_lower_case)
276 | doc_tokens = []
277 | char_to_word_offset = []
278 | prev_is_whitespace = True
279 |
280 | k = 0
281 | temp_word = ""
282 | for c in paragraph_text:
283 | if tokenization._is_whitespace(c):
284 | char_to_word_offset.append(k-1)
285 | continue
286 | else:
287 | temp_word += c
288 | char_to_word_offset.append(k)
289 | if FLAGS.do_lower_case:
290 | temp_word = temp_word.lower()
291 | if temp_word == raw_doc_tokens[k]:
292 | doc_tokens.append(temp_word)
293 | temp_word = ""
294 | k += 1
295 |
296 |       assert k == len(raw_doc_tokens)
297 |
298 | for qa in paragraph["qas"]:
299 | qas_id = qa["id"]
300 | question_text = qa["question"]
301 | start_position = None
302 | end_position = None
303 | orig_answer_text = None
304 |
305 | if is_training:
306 | answer = qa["answers"][0]
307 | orig_answer_text = answer["text"]
308 |
309 | if orig_answer_text not in paragraph_text:
310 | tf.logging.warning("Could not find answer")
311 | else:
312 | answer_offset = paragraph_text.index(orig_answer_text)
313 | answer_length = len(orig_answer_text)
314 | start_position = char_to_word_offset[answer_offset]
315 | end_position = char_to_word_offset[answer_offset + answer_length - 1]
316 |
317 |           # Only add answers where the text can be exactly recovered from the
318 |           # document. If it can't be, it's most likely due to weird Unicode
319 |           # handling, so we just skip the example.
320 |           #
321 |           # Note that this means that, in training mode, not every example is
322 |           # guaranteed to be preserved.
323 | actual_text = "".join(
324 | doc_tokens[start_position:(end_position + 1)])
325 | cleaned_answer_text = "".join(
326 | tokenization.whitespace_tokenize(orig_answer_text))
327 | if FLAGS.do_lower_case:
328 | cleaned_answer_text = cleaned_answer_text.lower()
329 | if actual_text.find(cleaned_answer_text) == -1:
330 |             tf.logging.warning("Could not find answer: '%s' vs. '%s'",
331 |                                actual_text, cleaned_answer_text)
332 | continue
333 |
334 | example = SquadExample(
335 | qas_id=qas_id,
336 | question_text=question_text,
337 | doc_tokens=doc_tokens,
338 | orig_answer_text=orig_answer_text,
339 | start_position=start_position,
340 | end_position=end_position)
341 | examples.append(example)
342 | tf.logging.info("**********read_squad_examples complete!**********")
343 |
344 | return examples
345 |
346 |
347 | def convert_examples_to_features(examples, tokenizer, max_seq_length,
348 | doc_stride, max_query_length, is_training,
349 | output_fn):
350 | """Loads a data file into a list of `InputBatch`s."""
351 |
352 | unique_id = 1000000000
353 |   tokenizer = ChineseFullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)  # overrides the passed-in `tokenizer`
354 |
355 | for (example_index, example) in enumerate(examples):
356 | query_tokens = tokenizer.tokenize(example.question_text)
357 |
358 | if len(query_tokens) > max_query_length:
359 | query_tokens = query_tokens[0:max_query_length]
360 |
361 | tok_to_orig_index = []
362 | orig_to_tok_index = []
363 | all_doc_tokens = []
364 | for (i, token) in enumerate(example.doc_tokens):
365 | orig_to_tok_index.append(len(all_doc_tokens))
366 | sub_tokens = tokenizer.tokenize(token)
367 | for sub_token in sub_tokens:
368 | tok_to_orig_index.append(i)
369 | all_doc_tokens.append(sub_token)
370 |
371 | tok_start_position = None
372 | tok_end_position = None
373 | if is_training:
374 | tok_start_position = orig_to_tok_index[example.start_position]
375 | if example.end_position < len(example.doc_tokens) - 1:
376 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
377 | else:
378 | tok_end_position = len(all_doc_tokens) - 1
379 | (tok_start_position, tok_end_position) = _improve_answer_span(
380 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
381 | example.orig_answer_text)
382 |
383 | # The -3 accounts for [CLS], [SEP] and [SEP]
384 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
385 |
386 | # We can have documents that are longer than the maximum sequence length.
387 | # To deal with this we do a sliding window approach, where we take chunks
388 |     # of up to our max length with a stride of `doc_stride`.
389 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
390 | "DocSpan", ["start", "length"])
391 | doc_spans = []
392 | start_offset = 0
393 | while start_offset < len(all_doc_tokens):
394 | length = len(all_doc_tokens) - start_offset
395 | if length > max_tokens_for_doc:
396 | length = max_tokens_for_doc
397 | doc_spans.append(_DocSpan(start=start_offset, length=length))
398 | if start_offset + length == len(all_doc_tokens):
399 | break
400 | start_offset += min(length, doc_stride)
401 |
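    |     # Worked example (illustrative numbers): with 1,000 document tokens, a
    |     # 30-token query, max_seq_length=384 and doc_stride=128,
    |     # max_tokens_for_doc = 384 - 30 - 3 = 351, producing spans (0, 351),
    |     # (128, 351), ..., (640, 351) and a final (768, 232) that ends exactly
    |     # at the last token.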
402 | for (doc_span_index, doc_span) in enumerate(doc_spans):
403 | tokens = []
404 | token_to_orig_map = {}
405 | token_is_max_context = {}
406 | segment_ids = []
407 | input_span_mask = []
408 | tokens.append("[CLS]")
409 | segment_ids.append(0)
410 | input_span_mask.append(1)
411 | for token in query_tokens:
412 | tokens.append(token)
413 | segment_ids.append(0)
414 | input_span_mask.append(0)
415 | tokens.append("[SEP]")
416 | segment_ids.append(0)
417 | input_span_mask.append(0)
418 |
419 | for i in range(doc_span.length):
420 | split_token_index = doc_span.start + i
421 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
422 |
423 | is_max_context = _check_is_max_context(doc_spans, doc_span_index,
424 | split_token_index)
425 | token_is_max_context[len(tokens)] = is_max_context
426 | tokens.append(all_doc_tokens[split_token_index])
427 | segment_ids.append(1)
428 | input_span_mask.append(1)
429 | tokens.append("[SEP]")
430 | segment_ids.append(1)
431 | input_span_mask.append(0)
432 |
433 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
434 |
435 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
436 | # tokens are attended to.
437 | input_mask = [1] * len(input_ids)
438 |
439 | # Zero-pad up to the sequence length.
440 | while len(input_ids) < max_seq_length:
441 | input_ids.append(0)
442 | input_mask.append(0)
443 | segment_ids.append(0)
444 | input_span_mask.append(0)
445 |
446 | assert len(input_ids) == max_seq_length
447 | assert len(input_mask) == max_seq_length
448 | assert len(segment_ids) == max_seq_length
449 | assert len(input_span_mask) == max_seq_length
450 |
451 | start_position = None
452 | end_position = None
453 | if is_training:
454 | # For training, if our document chunk does not contain an annotation
455 | # we throw it out, since there is nothing to predict.
456 | doc_start = doc_span.start
457 | doc_end = doc_span.start + doc_span.length - 1
458 | out_of_span = False
459 | if not (tok_start_position >= doc_start and
460 | tok_end_position <= doc_end):
461 | out_of_span = True
462 | if out_of_span:
463 | start_position = 0
464 | end_position = 0
465 | else:
466 | doc_offset = len(query_tokens) + 2
467 | start_position = tok_start_position - doc_start + doc_offset
468 | end_position = tok_end_position - doc_start + doc_offset
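    |           # E.g. a 10-token query gives doc_offset = 12 ([CLS] + 10 query
    |           # tokens + [SEP]), so the document token at `doc_start` maps to
    |           # index 12 of the packed sequence.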
469 |
470 | if example_index < 3:
471 | tf.logging.info("*** Example ***")
472 | tf.logging.info("unique_id: %s" % (unique_id))
473 | tf.logging.info("example_index: %s" % (example_index))
474 | tf.logging.info("doc_span_index: %s" % (doc_span_index))
475 | tf.logging.info("tokens: %s" % " ".join(
476 | [tokenization.printable_text(x) for x in tokens]))
477 | tf.logging.info("token_to_orig_map: %s" % " ".join(
478 | ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)]))
479 | tf.logging.info("token_is_max_context: %s" % " ".join([
480 | "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
481 | ]))
482 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
483 | tf.logging.info(
484 | "input_mask: %s" % " ".join([str(x) for x in input_mask]))
485 | tf.logging.info(
486 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
487 | tf.logging.info(
488 | "input_span_mask: %s" % " ".join([str(x) for x in input_span_mask]))
489 | if is_training:
490 | answer_text = " ".join(tokens[start_position:(end_position + 1)])
491 | tf.logging.info("start_position: %d" % (start_position))
492 | tf.logging.info("end_position: %d" % (end_position))
493 | tf.logging.info(
494 | "answer: %s" % (tokenization.printable_text(answer_text)))
495 |
496 |
497 | feature = InputFeatures(
498 | unique_id=unique_id,
499 | example_index=example_index,
500 | doc_span_index=doc_span_index,
501 | tokens=tokens,
502 | token_to_orig_map=token_to_orig_map,
503 | token_is_max_context=token_is_max_context,
504 | input_ids=input_ids,
505 | input_mask=input_mask,
506 | segment_ids=segment_ids,
507 | input_span_mask=input_span_mask,
508 | start_position=start_position,
509 | end_position=end_position)
510 |
511 | # Run callback
512 | output_fn(feature)
513 |
514 | unique_id += 1
515 |
516 |
517 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
518 | orig_answer_text):
519 | """Returns tokenized answer spans that better match the annotated answer."""
520 |
521 | # The SQuAD annotations are character based. We first project them to
522 | # whitespace-tokenized words. But then after WordPiece tokenization, we can
523 | # often find a "better match". For example:
524 | #
525 | # Question: What year was John Smith born?
526 | # Context: The leader was John Smith (1895-1943).
527 | # Answer: 1895
528 | #
529 | # The original whitespace-tokenized answer will be "(1895-1943).". However
530 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
531 | # the exact answer, 1895.
532 | #
533 | # However, this is not always possible. Consider the following:
534 | #
535 |   #   Question: What country is the top exporter of electronics?
536 |   #   Context: The Japanese electronics industry is the largest in the world.
537 | # Answer: Japan
538 | #
539 | # In this case, the annotator chose "Japan" as a character sub-span of
540 | # the word "Japanese". Since our WordPiece tokenizer does not split
541 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
542 | # in SQuAD, but does happen.
543 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
544 |
545 | for new_start in range(input_start, input_end + 1):
546 | for new_end in range(input_end, new_start - 1, -1):
547 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
548 | if text_span == tok_answer_text:
549 | return (new_start, new_end)
550 |
551 | return (input_start, input_end)
552 |
553 |
554 | def _check_is_max_context(doc_spans, cur_span_index, position):
555 | """Check if this is the 'max context' doc span for the token."""
556 |
557 | # Because of the sliding window approach taken to scoring documents, a single
558 |   # token can appear in multiple doc spans. E.g.
559 | # Doc: the man went to the store and bought a gallon of milk
560 | # Span A: the man went to the
561 | # Span B: to the store and bought
562 | # Span C: and bought a gallon of
563 | # ...
564 | #
565 | # Now the word 'bought' will have two scores from spans B and C. We only
566 | # want to consider the score with "maximum context", which we define as
567 | # the *minimum* of its left and right context (the *sum* of left and
568 | # right context will always be the same, of course).
569 | #
570 | # In the example the maximum context for 'bought' would be span C since
571 | # it has 1 left context and 3 right context, while span B has 4 left context
572 | # and 0 right context.
573 | best_score = None
574 | best_span_index = None
575 | for (span_index, doc_span) in enumerate(doc_spans):
576 | end = doc_span.start + doc_span.length - 1
577 | if position < doc_span.start:
578 | continue
579 | if position > end:
580 | continue
581 | num_left_context = position - doc_span.start
582 | num_right_context = end - position
583 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
584 | if best_score is None or score > best_score:
585 | best_score = score
586 | best_span_index = span_index
587 |
588 | return cur_span_index == best_span_index
589 |
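    | # Plugging the docstring example into the score above: for 'bought', span B
    | # scores min(4, 0) + 0.01 * 5 = 0.05 while span C scores
    | # min(1, 3) + 0.01 * 5 = 1.05, so span C is the max-context span.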
590 |
591 | #
592 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, input_span_mask,
593 | use_one_hot_embeddings):
594 |   """Creates a span-extraction (QA) model."""
595 | model = modeling.BertModel(
596 | config=bert_config,
597 | is_training=is_training,
598 | input_ids=input_ids,
599 | input_mask=input_mask,
600 | token_type_ids=segment_ids,
601 | use_one_hot_embeddings=use_one_hot_embeddings)
602 |
603 |
604 | final_hidden = model.get_sequence_output()
605 |
606 | final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
607 | batch_size = final_hidden_shape[0]
608 | seq_length = final_hidden_shape[1]
609 | hidden_size = final_hidden_shape[2]
610 |
611 | output_weights = tf.get_variable(
612 | "cls/squad/output_weights", [2, hidden_size],
613 | initializer=tf.truncated_normal_initializer(stddev=0.02))
614 |
615 | output_bias = tf.get_variable(
616 | "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())
617 |
618 | final_hidden_matrix = tf.reshape(final_hidden,
619 | [batch_size * seq_length, hidden_size])
620 | logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
621 | logits = tf.nn.bias_add(logits, output_bias)
622 |
623 | logits = tf.reshape(logits, [batch_size, seq_length, 2])
624 | logits = tf.transpose(logits, [2, 0, 1])
625 |
626 | unstacked_logits = tf.unstack(logits, axis=0)
627 |
628 | (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])
629 |
630 | # apply output mask
631 | adder = (1.0 - tf.cast(input_span_mask, tf.float32)) * -10000.0
632 | start_logits += adder
633 | end_logits += adder
634 |
635 | return (start_logits, end_logits)
636 |
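    | # Shape walk-through: with batch size B, sequence length L and hidden size
    | # H, `final_hidden` is [B, L, H]; the dense head yields [B, L, 2], the
    | # transpose/unstack gives start/end logits of shape [B, L], and positions
    | # outside `input_span_mask` are pushed toward -10000 so softmax assigns
    | # them near-zero probability.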
637 |
638 | #
639 | def model_fn_builder(bert_config, init_checkpoint, learning_rate,
640 | num_train_steps, num_warmup_steps, use_tpu,
641 | use_one_hot_embeddings):
642 | """Returns `model_fn` closure for TPUEstimator."""
643 |
644 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
645 | """The `model_fn` for TPUEstimator."""
646 |
647 | tf.logging.info("*** Features ***")
648 | for name in sorted(features.keys()):
649 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
650 |
651 | unique_ids = features["unique_ids"]
652 | input_ids = features["input_ids"]
653 | input_mask = features["input_mask"]
654 | segment_ids = features["segment_ids"]
655 | input_span_mask = features["input_span_mask"]
656 |
657 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
658 |
659 | (start_logits, end_logits) = create_model(
660 | bert_config=bert_config,
661 | is_training=is_training,
662 | input_ids=input_ids,
663 | input_mask=input_mask,
664 | segment_ids=segment_ids,
665 | input_span_mask=input_span_mask,
666 | use_one_hot_embeddings=use_one_hot_embeddings)
667 |
668 | tvars = tf.trainable_variables()
669 |
670 | initialized_variable_names = {}
671 | scaffold_fn = None
672 | if init_checkpoint:
673 | (assignment_map, initialized_variable_names
674 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
675 | if use_tpu:
676 |
677 | def tpu_scaffold():
678 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
679 | return tf.train.Scaffold()
680 |
681 | scaffold_fn = tpu_scaffold
682 | else:
683 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
684 |
685 | tf.logging.info("**** Trainable Variables ****")
686 | for var in tvars:
687 | init_string = ""
688 | if var.name in initialized_variable_names:
689 | init_string = ", *INIT_FROM_CKPT*"
690 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string)
691 |
692 | output_spec = None
693 | if mode == tf.estimator.ModeKeys.TRAIN:
694 | seq_length = modeling.get_shape_list(input_ids)[1]
695 |
696 | def compute_loss(logits, positions):
697 |         one_hot_positions = tf.one_hot(positions, depth=seq_length, dtype=tf.float32)
698 |         log_probs = tf.nn.log_softmax(logits, axis=-1)
699 |         loss = -tf.reduce_mean(tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
700 | return loss
701 |
702 | start_positions = features["start_positions"]
703 | end_positions = features["end_positions"]
704 |
705 | start_loss = compute_loss(start_logits, start_positions)
706 | end_loss = compute_loss(end_logits, end_positions)
707 | total_loss = (start_loss + end_loss) / 2
708 |
709 | train_op = optimization.create_optimizer(
710 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
711 |
712 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
713 | mode=mode,
714 | loss=total_loss,
715 | train_op=train_op,
716 | scaffold_fn=scaffold_fn)
717 | elif mode == tf.estimator.ModeKeys.PREDICT:
718 | start_logits = tf.nn.log_softmax(start_logits, axis=-1)
719 | end_logits = tf.nn.log_softmax(end_logits, axis=-1)
720 |
721 | predictions = {
722 | "unique_ids": unique_ids,
723 | "start_logits": start_logits,
724 | "end_logits": end_logits,
725 | }
726 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
727 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
728 | else:
729 | raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % (mode))
730 |
731 | return output_spec
732 |
733 | return model_fn
734 |
735 |
736 | def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
737 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
738 |
739 | name_to_features = {
740 | "unique_ids": tf.FixedLenFeature([], tf.int64),
741 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
742 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
743 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
744 | "input_span_mask": tf.FixedLenFeature([seq_length], tf.int64),
745 | }
746 |
747 | if is_training:
748 | name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
749 | name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)
750 |
751 | def _decode_record(record, name_to_features):
752 | """Decodes a record to a TensorFlow example."""
753 | example = tf.parse_single_example(record, name_to_features)
754 |
755 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
756 | # So cast all int64 to int32.
757 | for name in list(example.keys()):
758 | t = example[name]
759 | if t.dtype == tf.int64:
760 | t = tf.to_int32(t)
761 | example[name] = t
762 |
763 | return example
764 |
765 | def input_fn(params):
766 | """The actual input function."""
767 | batch_size = params["batch_size"]
768 |
769 | # For training, we want a lot of parallel reading and shuffling.
770 | # For eval, we want no shuffling and parallel reading doesn't matter.
771 | d = tf.data.TFRecordDataset(input_file)
772 | if is_training:
773 | d = d.repeat()
774 | d = d.shuffle(buffer_size=100)
775 |
776 | d = d.apply(
777 | tf.contrib.data.map_and_batch(
778 | lambda record: _decode_record(record, name_to_features),
779 | batch_size=batch_size,
780 | drop_remainder=drop_remainder))
781 |
782 | return d
783 |
784 | return input_fn
785 |
786 |
787 | RawResult = collections.namedtuple("RawResult",
788 | ["unique_id", "start_logits", "end_logits"])
789 |
790 | def write_predictions(all_examples, all_features, all_results, n_best_size,
791 | max_answer_length, do_lower_case, output_prediction_file,
792 | output_nbest_file):
793 | """Write final predictions to the json file and log-odds of null if needed."""
794 | tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
795 | tf.logging.info("Writing nbest to: %s" % (output_nbest_file))
796 |
797 | example_index_to_features = collections.defaultdict(list)
798 | for feature in all_features:
799 | example_index_to_features[feature.example_index].append(feature)
800 |
801 | unique_id_to_result = {}
802 | for result in all_results:
803 | unique_id_to_result[result.unique_id] = result
804 |
805 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
806 | "PrelimPrediction",
807 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
808 |
809 | all_predictions = collections.OrderedDict()
810 | all_nbest_json = collections.OrderedDict()
811 |
812 | for (example_index, example) in enumerate(all_examples):
813 | features = example_index_to_features[example_index]
814 | prelim_predictions = []
815 |
816 |     for (feature_index, feature) in enumerate(features):  # one feature per doc span (chunk)
817 | result = unique_id_to_result[feature.unique_id]
818 | start_indexes = _get_best_indexes(result.start_logits, n_best_size)
819 | end_indexes = _get_best_indexes(result.end_logits, n_best_size)
820 | for start_index in start_indexes:
821 | for end_index in end_indexes:
822 | # We could hypothetically create invalid predictions, e.g., predict
823 | # that the start of the span is in the question. We throw out all
824 | # invalid predictions.
825 | if start_index >= len(feature.tokens):
826 | continue
827 | if end_index >= len(feature.tokens):
828 | continue
829 | if start_index not in feature.token_to_orig_map:
830 | continue
831 | if end_index not in feature.token_to_orig_map:
832 | continue
833 | if not feature.token_is_max_context.get(start_index, False):
834 | continue
835 | if end_index < start_index:
836 | continue
837 | length = end_index - start_index + 1
838 | if length > max_answer_length:
839 | continue
840 | prelim_predictions.append(
841 | _PrelimPrediction(
842 | feature_index=feature_index,
843 | start_index=start_index,
844 | end_index=end_index,
845 | start_logit=result.start_logits[start_index],
846 | end_logit=result.end_logits[end_index]))
847 |
848 | prelim_predictions = sorted(
849 | prelim_predictions,
850 | key=lambda x: (x.start_logit + x.end_logit),
851 | reverse=True)
852 |
853 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
854 | "NbestPrediction", ["text", "start_logit", "end_logit", "start_index", "end_index"])
855 |
856 | seen_predictions = {}
857 | nbest = []
858 | for pred in prelim_predictions:
859 | if len(nbest) >= n_best_size:
860 | break
861 | feature = features[pred.feature_index]
862 | if pred.start_index > 0: # this is a non-null prediction
863 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
864 | orig_doc_start = feature.token_to_orig_map[pred.start_index]
865 | orig_doc_end = feature.token_to_orig_map[pred.end_index]
866 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
867 | tok_text = " ".join(tok_tokens)
868 |
869 | # De-tokenize WordPieces that have been split off.
870 | tok_text = tok_text.replace(" ##", "")
871 | tok_text = tok_text.replace("##", "")
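    |         # E.g. "19 ##63 年" -> "1963 年": space-joined WordPieces are glued
    |         # back together before the whitespace cleanup below.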
872 |
873 | # Clean whitespace
874 | tok_text = tok_text.strip()
875 | tok_text = " ".join(tok_text.split())
876 | orig_text = " ".join(orig_tokens)
877 |
878 | final_text = get_final_text(tok_text, orig_text, do_lower_case)
879 | final_text = final_text.replace(' ','')
880 | if final_text in seen_predictions:
881 | continue
882 |
883 | seen_predictions[final_text] = True
884 | else:
885 | final_text = ""
886 | seen_predictions[final_text] = True
887 |
888 | nbest.append(
889 | _NbestPrediction(
890 | text=final_text,
891 | start_logit=pred.start_logit,
892 | end_logit=pred.end_logit,
893 | start_index=pred.start_index,
894 | end_index=pred.end_index))
895 |
896 | # In very rare edge cases we could have no valid predictions. So we
897 | # just create a nonce prediction in this case to avoid failure.
898 | if not nbest:
899 | nbest.append(
900 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0, start_index=0, end_index=0))
901 |
902 | assert len(nbest) >= 1
903 |
904 | total_scores = []
905 | best_non_null_entry = None
906 | for entry in nbest:
907 | total_scores.append(entry.start_logit + entry.end_logit)
908 | if not best_non_null_entry:
909 | if entry.text:
910 | best_non_null_entry = entry
911 |
912 | probs = _compute_softmax(total_scores)
913 |
914 | nbest_json = []
915 | for (i, entry) in enumerate(nbest):
916 | output = collections.OrderedDict()
917 | output["text"] = entry.text
918 | output["probability"] = probs[i]
919 | output["start_logit"] = entry.start_logit
920 | output["end_logit"] = entry.end_logit
921 | output["start_index"] = entry.start_index
922 | output["end_index"] = entry.end_index
923 | nbest_json.append(output)
924 |
925 | assert len(nbest_json) >= 1
926 |
927 | all_predictions[example.qas_id] = best_non_null_entry.text
928 | all_nbest_json[example.qas_id] = nbest_json
929 |
930 | with tf.gfile.GFile(output_prediction_file, "w") as writer:
931 | writer.write(json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n")
932 |
933 | with tf.gfile.GFile(output_nbest_file, "w") as writer:
934 |     writer.write(json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + "\n")
935 |
936 |
937 | def get_final_text(pred_text, orig_text, do_lower_case):
938 | """Project the tokenized prediction back to the original text."""
939 |
940 | # When we created the data, we kept track of the alignment between original
941 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
942 | # now `orig_text` contains the span of our original text corresponding to the
943 | # span that we predicted.
944 | #
945 | # However, `orig_text` may contain extra characters that we don't want in
946 | # our prediction.
947 | #
948 | # For example, let's say:
949 | # pred_text = steve smith
950 | # orig_text = Steve Smith's
951 | #
952 | # We don't want to return `orig_text` because it contains the extra "'s".
953 | #
954 | # We don't want to return `pred_text` because it's already been normalized
955 | # (the SQuAD eval script also does punctuation stripping/lower casing but
956 | # our tokenizer does additional normalization like stripping accent
957 | # characters).
958 | #
959 | # What we really want to return is "Steve Smith".
960 | #
961 |   # Therefore, we have to apply a semi-complicated alignment heuristic between
962 |   # `pred_text` and `orig_text` to get a character-to-character alignment. This
963 | # can fail in certain cases in which case we just return `orig_text`.
964 |
965 | def _strip_spaces(text):
966 | ns_chars = []
967 | ns_to_s_map = collections.OrderedDict()
968 | for (i, c) in enumerate(text):
969 | if c == " ":
970 | continue
971 | ns_to_s_map[len(ns_chars)] = i
972 | ns_chars.append(c)
973 | ns_text = "".join(ns_chars)
974 | return (ns_text, ns_to_s_map)
975 |
976 | # We first tokenize `orig_text`, strip whitespace from the result
977 | # and `pred_text`, and check if they are the same length. If they are
978 | # NOT the same length, the heuristic has failed. If they are the same
979 | # length, we assume the characters are one-to-one aligned.
980 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
981 |
982 | tok_text = " ".join(tokenizer.tokenize(orig_text))
983 |
984 | start_position = tok_text.find(pred_text)
985 | if start_position == -1:
986 | if FLAGS.verbose_logging:
987 | tf.logging.info(
988 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
989 | return orig_text
990 | end_position = start_position + len(pred_text) - 1
991 |
992 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
993 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
994 |
995 | if len(orig_ns_text) != len(tok_ns_text):
996 | if FLAGS.verbose_logging:
997 | tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
998 | orig_ns_text, tok_ns_text)
999 | return orig_text
1000 |
1001 | # We then project the characters in `pred_text` back to `orig_text` using
1002 | # the character-to-character alignment.
1003 | tok_s_to_ns_map = {}
1004 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
1005 | tok_s_to_ns_map[tok_index] = i
1006 |
1007 | orig_start_position = None
1008 | if start_position in tok_s_to_ns_map:
1009 | ns_start_position = tok_s_to_ns_map[start_position]
1010 | if ns_start_position in orig_ns_to_s_map:
1011 | orig_start_position = orig_ns_to_s_map[ns_start_position]
1012 |
1013 | if orig_start_position is None:
1014 | if FLAGS.verbose_logging:
1015 | tf.logging.info("Couldn't map start position")
1016 | return orig_text
1017 |
1018 | orig_end_position = None
1019 | if end_position in tok_s_to_ns_map:
1020 | ns_end_position = tok_s_to_ns_map[end_position]
1021 | if ns_end_position in orig_ns_to_s_map:
1022 | orig_end_position = orig_ns_to_s_map[ns_end_position]
1023 |
1024 | if orig_end_position is None:
1025 | if FLAGS.verbose_logging:
1026 | tf.logging.info("Couldn't map end position")
1027 | return orig_text
1028 |
1029 | output_text = orig_text[orig_start_position:(orig_end_position + 1)]
1030 | return output_text
1031 |
1032 |
1033 | def _get_best_indexes(logits, n_best_size):
1034 | """Get the n-best logits from a list."""
1035 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
1036 |
1037 | best_indexes = []
1038 | for i in range(len(index_and_score)):
1039 | if i >= n_best_size:
1040 | break
1041 | best_indexes.append(index_and_score[i][0])
1042 | return best_indexes
1043 |
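    | # E.g. _get_best_indexes([0.1, 2.3, -0.5, 1.7], 2) returns [1, 3], the
    | # positions of the two largest logits in descending score order.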
1044 |
1045 | def _compute_softmax(scores):
1046 | """Compute softmax probability over raw logits."""
1047 | if not scores:
1048 | return []
1049 |
1050 | max_score = None
1051 | for score in scores:
1052 | if max_score is None or score > max_score:
1053 | max_score = score
1054 |
1055 | exp_scores = []
1056 | total_sum = 0.0
1057 | for score in scores:
1058 | x = math.exp(score - max_score)
1059 | exp_scores.append(x)
1060 | total_sum += x
1061 |
1062 | probs = []
1063 | for score in exp_scores:
1064 | probs.append(score / total_sum)
1065 | return probs
1066 |
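    | # E.g. _compute_softmax([1.0, 2.0, 3.0]) returns roughly
    | # [0.090, 0.245, 0.665]; subtracting max_score first keeps the
    | # exponentials numerically stable.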
1067 |
1068 | class FeatureWriter(object):
1069 | """Writes InputFeature to TF example file."""
1070 |
1071 | def __init__(self, filename, is_training):
1072 | self.filename = filename
1073 | self.is_training = is_training
1074 | self.num_features = 0
1075 | self._writer = tf.python_io.TFRecordWriter(filename)
1076 |
1077 | def process_feature(self, feature):
1078 |     """Write an InputFeature to the TFRecordWriter as a tf.train.Example."""
1079 | self.num_features += 1
1080 |
1081 | def create_int_feature(values):
1082 | feature = tf.train.Feature(
1083 | int64_list=tf.train.Int64List(value=list(values)))
1084 | return feature
1085 |
1086 | features = collections.OrderedDict()
1087 | features["unique_ids"] = create_int_feature([feature.unique_id])
1088 | features["input_ids"] = create_int_feature(feature.input_ids)
1089 | features["input_mask"] = create_int_feature(feature.input_mask)
1090 | features["segment_ids"] = create_int_feature(feature.segment_ids)
1091 | features["input_span_mask"] = create_int_feature(feature.input_span_mask)
1092 |
1093 | if self.is_training:
1094 | features["start_positions"] = create_int_feature([feature.start_position])
1095 | features["end_positions"] = create_int_feature([feature.end_position])
1096 |
1097 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
1098 | self._writer.write(tf_example.SerializeToString())
1099 |
1100 | def close(self):
1101 | self._writer.close()
1102 |
1103 |
1104 | def validate_flags_or_throw(bert_config):
1105 | """Validate the input FLAGS or throw an exception."""
1106 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
1107 | FLAGS.init_checkpoint)
1108 |
1109 |   if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.do_eval:
1110 |     raise ValueError("At least one of `do_train`, `do_predict` or `do_eval` must be True.")
1111 |
1112 | if FLAGS.do_train:
1113 | if not FLAGS.train_file:
1114 | raise ValueError(
1115 | "If `do_train` is True, then `train_file` must be specified.")
1116 | if FLAGS.do_predict:
1117 | if not FLAGS.predict_file:
1118 | raise ValueError(
1119 | "If `do_predict` is True, then `predict_file` must be specified.")
1120 | if FLAGS.do_eval:
1121 | if not FLAGS.eval_file:
1122 | raise ValueError(
1123 | "If `do_eval` is True, then `eval_file` must be specified.")
1124 |
1125 | if FLAGS.max_seq_length > bert_config.max_position_embeddings:
1126 | raise ValueError(
1127 | "Cannot use sequence length %d because the BERT model "
1128 | "was only trained up to sequence length %d" %
1129 | (FLAGS.max_seq_length, bert_config.max_position_embeddings))
1130 |
1131 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
1132 | raise ValueError(
1133 | "The max_seq_length (%d) must be greater than max_query_length "
1134 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))
1135 |
1136 |
1137 | def main(_):
1138 | tf.logging.set_verbosity(tf.logging.INFO)
1139 |
1140 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
1141 |
1142 | validate_flags_or_throw(bert_config)
1143 |
1144 | tf.gfile.MakeDirs(FLAGS.output_dir)
1145 |
1146 | tokenizer = tokenization.FullTokenizer(
1147 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
1148 |
1149 | tpu_cluster_resolver = None
1150 | if FLAGS.use_tpu and FLAGS.tpu_name:
1151 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
1152 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
1153 |
1154 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
1155 | run_config = tf.contrib.tpu.RunConfig(
1156 | cluster=tpu_cluster_resolver,
1157 | master=FLAGS.master,
1158 | model_dir=FLAGS.output_dir,
1159 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
1160 | keep_checkpoint_max=2,
1161 | tpu_config=tf.contrib.tpu.TPUConfig(
1162 | iterations_per_loop=FLAGS.iterations_per_loop,
1163 | num_shards=FLAGS.num_tpu_cores,
1164 | per_host_input_for_training=is_per_host))
1165 |
1166 | train_examples = None
1167 | num_train_steps = None
1168 | num_warmup_steps = None
1169 | if FLAGS.do_train:
1170 | train_examples = read_squad_examples(input_file=FLAGS.train_file, is_training=True)
1171 |
1172 | # Pre-shuffle the input to avoid having to make a very large shuffle
1173 |     # buffer in the `input_fn`.
1174 | rng = random.Random(int(FLAGS.rand_seed))
1175 | rng.shuffle(train_examples)
1176 |
1177 | # We write to a temporary file to avoid storing very large constant tensors
1178 | # in memory.
1179 | train_writer = FeatureWriter(
1180 | filename=os.path.join(FLAGS.output_dir, "train.tf_record"),
1181 | is_training=True)
1182 | convert_examples_to_features(
1183 | examples=train_examples,
1184 | tokenizer=tokenizer,
1185 | max_seq_length=FLAGS.max_seq_length,
1186 | doc_stride=FLAGS.doc_stride,
1187 | max_query_length=FLAGS.max_query_length,
1188 | is_training=True,
1189 | output_fn=train_writer.process_feature)
1190 | train_writer.close()
1191 | num_features = train_writer.num_features
1192 | train_examples_len = len(train_examples)
1193 | del train_examples
1194 |
1195 | num_train_steps = int(num_features / FLAGS.train_batch_size * FLAGS.num_train_epochs)
1196 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
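    |     # E.g. (illustrative numbers) 19,200 features with train_batch_size=32
    |     # and num_train_epochs=3.0 give 1,800 training steps, the first 180 of
    |     # which (warmup_proportion=0.1) use linear learning-rate warmup.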
1197 |
1198 | tf.logging.info("***** Running training *****")
1199 | tf.logging.info(" Num orig examples = %d", train_examples_len)
1200 | tf.logging.info(" Num split examples = %d", num_features)
1201 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
1202 | tf.logging.info(" Num steps = %d", num_train_steps)
1203 |
1204 | model_fn = model_fn_builder(
1205 | bert_config=bert_config,
1206 | init_checkpoint=FLAGS.init_checkpoint,
1207 | learning_rate=FLAGS.learning_rate,
1208 | num_train_steps=num_train_steps,
1209 | num_warmup_steps=num_warmup_steps,
1210 | use_tpu=FLAGS.use_tpu,
1211 | use_one_hot_embeddings=FLAGS.use_tpu)
1212 |
1213 | # If TPU is not available, this will fall back to normal Estimator on CPU
1214 | # or GPU.
1215 | estimator = tf.contrib.tpu.TPUEstimator(
1216 | use_tpu=FLAGS.use_tpu,
1217 | model_fn=model_fn,
1218 | config=run_config,
1219 | train_batch_size=FLAGS.train_batch_size,
1220 | predict_batch_size=FLAGS.predict_batch_size)
1221 |
1222 | # do training
1223 | if FLAGS.do_train:
1224 | train_writer_filename = train_writer.filename
1225 |
1226 | train_input_fn = input_fn_builder(
1227 | input_file=train_writer_filename,
1228 | seq_length=FLAGS.max_seq_length,
1229 | is_training=True,
1230 | drop_remainder=True)
1231 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
1232 |
1233 | # do predictions
1234 | if FLAGS.do_predict:
1235 | eval_examples = read_squad_examples(
1236 | input_file=FLAGS.predict_file, is_training=False)
1237 |
1238 | eval_writer = FeatureWriter(
1239 | filename=os.path.join(FLAGS.output_dir, "predict.tf_record"),
1240 | is_training=False)
1241 | eval_features = []
1242 |
1243 | def append_feature(feature):
1244 | eval_features.append(feature)
1245 | eval_writer.process_feature(feature)
1246 |
1247 | convert_examples_to_features(
1248 | examples=eval_examples,
1249 | tokenizer=tokenizer,
1250 | max_seq_length=FLAGS.max_seq_length,
1251 | doc_stride=FLAGS.doc_stride,
1252 | max_query_length=FLAGS.max_query_length,
1253 | is_training=False,
1254 | output_fn=append_feature)
1255 | eval_writer.close()
1256 |
1257 | tf.logging.info("***** Running predictions *****")
1258 | tf.logging.info(" Num orig examples = %d", len(eval_examples))
1259 | tf.logging.info(" Num split examples = %d", len(eval_features))
1260 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
1261 |
1262 |
1263 |
1264 | predict_input_fn = input_fn_builder(
1265 | input_file=eval_writer.filename,
1266 | seq_length=FLAGS.max_seq_length,
1267 | is_training=False,
1268 | drop_remainder=False)
1269 |
1270 | # If running eval on the TPU, you will need to specify the number of
1271 | # steps.
1272 | all_results = []
1273 | for result in estimator.predict(
1274 | predict_input_fn, yield_single_examples=True):
1275 | if len(all_results) % 1000 == 0:
1276 | tf.logging.info("Processing example: %d" % (len(all_results)))
1277 | unique_id = int(result["unique_ids"])
1278 | start_logits = [float(x) for x in result["start_logits"].flat]
1279 | end_logits = [float(x) for x in result["end_logits"].flat]
1280 | all_results.append(
1281 | RawResult(
1282 | unique_id=unique_id,
1283 | start_logits=start_logits,
1284 | end_logits=end_logits))
1285 |
1286 | output_json_name = "dev_predictions.json"
1287 | output_nbest_name = "dev_nbest_predictions.json"
1288 |
1289 | output_prediction_file = os.path.join(FLAGS.output_dir, output_json_name)
1290 | output_nbest_file = os.path.join(FLAGS.output_dir, output_nbest_name)
1291 |
1292 | write_predictions(eval_examples, eval_features, all_results,
1293 | FLAGS.n_best_size, FLAGS.max_answer_length,
1294 | FLAGS.do_lower_case, output_prediction_file,
1295 | output_nbest_file)
1296 |
1297 |   # do predictions on the eval set
1298 | if FLAGS.do_eval:
1299 | eval_examples = read_squad_examples(
1300 | input_file=FLAGS.eval_file, is_training=False)
1301 |
1302 | eval_writer = FeatureWriter(
1303 | filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
1304 | is_training=False)
1305 | eval_features = []
1306 |
1307 | def append_feature(feature):
1308 | eval_features.append(feature)
1309 | eval_writer.process_feature(feature)
1310 |
1311 | convert_examples_to_features(
1312 | examples=eval_examples,
1313 | tokenizer=tokenizer,
1314 | max_seq_length=FLAGS.max_seq_length,
1315 | doc_stride=FLAGS.doc_stride,
1316 | max_query_length=FLAGS.max_query_length,
1317 | is_training=False,
1318 | output_fn=append_feature)
1319 | eval_writer.close()
1320 |
1321 | tf.logging.info("***** Running evals *****")
1322 | tf.logging.info(" Num orig examples = %d", len(eval_examples))
1323 | tf.logging.info(" Num split examples = %d", len(eval_features))
1324 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
1325 |
1326 |
1327 |
1328 | predict_input_fn = input_fn_builder(
1329 | input_file=eval_writer.filename,
1330 | seq_length=FLAGS.max_seq_length,
1331 | is_training=False,
1332 | drop_remainder=False)
1333 |
1334 | # If running eval on the TPU, you will need to specify the number of
1335 | # steps.
1336 | all_results = []
1337 | for result in estimator.predict(
1338 | predict_input_fn, yield_single_examples=True):
1339 | if len(all_results) % 1000 == 0:
1340 | tf.logging.info("Processing example: %d" % (len(all_results)))
1341 | unique_id = int(result["unique_ids"])
1342 | start_logits = [float(x) for x in result["start_logits"].flat]
1343 | end_logits = [float(x) for x in result["end_logits"].flat]
1344 | all_results.append(
1345 | RawResult(
1346 | unique_id=unique_id,
1347 | start_logits=start_logits,
1348 | end_logits=end_logits))
1349 |
1350 | output_json_name = "test_predictions.json"
1351 | output_nbest_name = "test_nbest_predictions.json"
1352 |
1353 | output_prediction_file = os.path.join(FLAGS.output_dir, output_json_name)
1354 | output_nbest_file = os.path.join(FLAGS.output_dir, output_nbest_name)
1355 |
1356 | write_predictions(eval_examples, eval_features, all_results,
1357 | FLAGS.n_best_size, FLAGS.max_answer_length,
1358 | FLAGS.do_lower_case, output_prediction_file,
1359 | output_nbest_file)
1360 |
1361 |
1362 | if __name__ == "__main__":
1363 | flags.mark_flag_as_required("vocab_file")
1364 | flags.mark_flag_as_required("bert_config_file")
1365 | flags.mark_flag_as_required("output_dir")
1366 | tf.app.run()
1367 |
1368 |
1369 |
1370 |
--------------------------------------------------------------------------------