├── .gitignore
├── banner.png
├── qrcode.jpg
├── sponsor.png
├── baseline
│   ├── __init__.py
│   ├── README.md
│   ├── cmrc2018_evaluate.py
│   ├── optimization.py
│   ├── tokenization.py
│   ├── modeling.py
│   └── run_cmrc2018_drcd_baseline.py
├── README_CN.md
├── README.md
├── data
│   └── cmrc2018_evaluate.py
├── squad-style-data
│   └── cmrc2018_evaluate.py
└── LICENCE

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | */.DS_Store
3 |

--------------------------------------------------------------------------------
/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ymcui/cmrc2018/HEAD/banner.png

--------------------------------------------------------------------------------
/qrcode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ymcui/cmrc2018/HEAD/qrcode.jpg

--------------------------------------------------------------------------------
/sponsor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ymcui/cmrc2018/HEAD/sponsor.png

--------------------------------------------------------------------------------
/baseline/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](./README_CN.md) | [**English**](./README.md) 2 | 3 |


13 | 14 | This directory contains the data used in [the Second "iFLYTEK Cup" Evaluation Workshop on Chinese Machine Reading Comprehension (CMRC 2018)](https://hfl-rc.github.io/cmrc2018/). The dataset paper has been accepted at [EMNLP 2019](http://emnlp-ijcnlp2019.org), a top international conference in computational linguistics. 15 | 16 | **Title: A Span-Extraction Dataset for Chinese Machine Reading Comprehension** 17 | Authors: Yiming Cui, Ting Liu, Wanxiang Che, Li Xiao, Zhipeng Chen, Wentao Ma, Shijin Wang, Guoping Hu 18 | Link: [https://www.aclweb.org/anthology/D19-1600/](https://www.aclweb.org/anthology/D19-1600/) 19 | Venue: EMNLP-IJCNLP 2019 20 | 21 | ### Open Challenge Leaderboard (new!) 22 | Want to know which models perform best on CMRC 2018? Please check the leaderboard. 23 | [https://ymcui.github.io/cmrc2018/](https://ymcui.github.io/cmrc2018/) 24 | 25 | ### CMRC 2018 Public Datasets 26 | Please download the CMRC 2018 public datasets (training and development sets) via the CodaLab Worksheet below. 27 | [https://worksheets.codalab.org/worksheets/0x92a80d2fab4b4f79a2b4064f7ddca9ce](https://worksheets.codalab.org/worksheets/0x92a80d2fab4b4f79a2b4064f7ddca9ce) 28 | 29 | ### Submission Guidelines 30 | If you want to **test your model on the hidden test set and challenge set**, please follow the steps below to submit your model. 31 | [https://worksheets.codalab.org/worksheets/0x96f61ee5e9914aee8b54bd11e66ec647/](https://worksheets.codalab.org/worksheets/0x96f61ee5e9914aee8b54bd11e66ec647/) 32 | 33 | **Note that the test set provided on [CLUE](https://github.com/CLUEbenchmark/CLUE) is only a subset of the CMRC 2018 test set. For official evaluation, you still need to obtain results on the complete test set and challenge set through the procedure above.** 34 | 35 | 36 | ### Quick Load via 🤗datasets 37 | You can quickly load the dataset with the [HuggingFace `datasets` library](https://github.com/huggingface/datasets): 38 | 39 | ```python 40 | # Install first: pip install datasets (in a Jupyter/Colab notebook: !pip install datasets) 41 | from datasets import load_dataset 42 | dataset = load_dataset('cmrc2018') 43 | ``` 44 | For more options and usage details of the `datasets` library, see: https://github.com/huggingface/datasets 45 | 46 | ### Citation 47 | If you use our data in your work, please cite the following paper: 48 | 49 | ``` 50 | @inproceedings{cui-emnlp2019-cmrc2018, 51 | title = "A Span-Extraction Dataset for {C}hinese Machine Reading Comprehension", 52 | author = "Cui, Yiming and 53 | Liu, Ting and 54 | Che, Wanxiang and 55 | Xiao, Li and 56 | Chen, Zhipeng and 57 | Ma, Wentao and 58 | Wang, Shijin and 59 | Hu, Guoping", 60 | booktitle = "Proceedings of the 2019 Conference on Empirical
Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 61 | month = nov, 62 | year = "2019", 63 | address = "Hong Kong, China", 64 | publisher = "Association for Computational Linguistics", 65 | url = "https://www.aclweb.org/anthology/D19-1600", 66 | doi = "10.18653/v1/D19-1600", 67 | pages = "5886--5891", 68 | } 69 | ``` 70 | ### International Standard Language Resource Number (ISLRN) 71 | ISLRN: 013-662-947-043-2 72 | 73 | http://www.islrn.org/resources/resources_info/7952/ 74 | 75 | ### Official HFL WeChat Account 76 | Follow the official WeChat account of the Joint Laboratory of HIT and iFLYTEK Research (HFL) to keep up with our latest technical updates. 77 | 78 | ![qrcode.png](./qrcode.jpg) 79 | 80 | ### Contact Us 81 | Please submit an issue. 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](./README_CN.md) | [**English**](./README.md) 2 | 3 |


13 | 14 | This repository contains the data for [The Second Evaluation Workshop on Chinese Machine Reading Comprehension (CMRC 2018)](https://hfl-rc.github.io/cmrc2018/). The accompanying paper was published at [EMNLP 2019](http://emnlp-ijcnlp2019.org). 15 | 16 | **Title: A Span-Extraction Dataset for Chinese Machine Reading Comprehension** 17 | Authors: Yiming Cui, Ting Liu, Wanxiang Che, Li Xiao, Zhipeng Chen, Wentao Ma, Shijin Wang, Guoping Hu 18 | Link: [https://www.aclweb.org/anthology/D19-1600/](https://www.aclweb.org/anthology/D19-1600/) 19 | Venue: EMNLP-IJCNLP 2019 20 | 21 | ### Open Challenge Leaderboard (New!) 22 | Keep track of the latest state-of-the-art systems on the CMRC 2018 dataset. 23 | [https://ymcui.github.io/cmrc2018/](https://ymcui.github.io/cmrc2018/) 24 | 25 | ### CMRC 2018 Public Datasets 26 | Please download the CMRC 2018 public datasets (training and development sets) via the following CodaLab worksheet. 27 | [https://worksheets.codalab.org/worksheets/0x92a80d2fab4b4f79a2b4064f7ddca9ce](https://worksheets.codalab.org/worksheets/0x92a80d2fab4b4f79a2b4064f7ddca9ce) 28 | 29 | ### Submission Guidelines 30 | If you would like to **test your model on the hidden test and challenge sets**, please follow the instructions on how to submit your model via the CodaLab worksheet below. 31 | [https://worksheets.codalab.org/worksheets/0x96f61ee5e9914aee8b54bd11e66ec647/](https://worksheets.codalab.org/worksheets/0x96f61ee5e9914aee8b54bd11e66ec647/) 32 | 33 | **Note that the test set on [CLUE](https://github.com/CLUEbenchmark/CLUE) is NOT the complete test set. If you wish to evaluate your model OFFICIALLY on CMRC 2018, follow the submission guidelines above.**
34 | 35 | ### Quick Load Through 🤗datasets 36 | You can also load this dataset through the [HuggingFace `datasets` library](https://github.com/huggingface/datasets) as follows: 37 | 38 | ```python 39 | # Install first: pip install datasets (in a Jupyter/Colab notebook: !pip install datasets) 40 | from datasets import load_dataset 41 | dataset = load_dataset('cmrc2018') 42 | ``` 43 | More details on the options and usage of this library can be found in the `datasets` repository at https://github.com/huggingface/datasets 44 | 45 | ### Reference 46 | If you wish to use our data in your research, please cite: 47 | 48 | ``` 49 | @inproceedings{cui-emnlp2019-cmrc2018, 50 | title = "A Span-Extraction Dataset for {C}hinese Machine Reading Comprehension", 51 | author = "Cui, Yiming and 52 | Liu, Ting and 53 | Che, Wanxiang and 54 | Xiao, Li and 55 | Chen, Zhipeng and 56 | Ma, Wentao and 57 | Wang, Shijin and 58 | Hu, Guoping", 59 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 60 | month = nov, 61 | year = "2019", 62 | address = "Hong Kong, China", 63 | publisher = "Association for Computational Linguistics", 64 | url = "https://www.aclweb.org/anthology/D19-1600", 65 | doi = "10.18653/v1/D19-1600", 66 | pages = "5886--5891", 67 | } 68 | ``` 69 | ### International Standard Language Resource Number (ISLRN) 70 | ISLRN: 013-662-947-043-2 71 | 72 | http://www.islrn.org/resources/resources_info/7952/ 73 | 74 | ### Official HFL WeChat Account 75 | Follow the Joint Laboratory of HIT and iFLYTEK Research (HFL) on WeChat. 76 | 77 | ![qrcode.png](./qrcode.jpg) 78 | 79 | ### Contact us 80 | Please submit an issue.
81 | -------------------------------------------------------------------------------- /data/cmrc2018_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Evaluation script for CMRC 2018 4 | version: v5 5 | Note: 6 | v5 formatted output, add usage description 7 | v4 fixed segmentation issues 8 | ''' 9 | from __future__ import print_function 10 | from collections import Counter, OrderedDict 11 | import string 12 | import re 13 | import argparse 14 | import json 15 | import sys 16 | reload(sys) 17 | sys.setdefaultencoding('utf8') 18 | import nltk 19 | import pdb 20 | 21 | # split Chinese with English 22 | def mixed_segmentation(in_str, rm_punc=False): 23 | in_str = str(in_str).decode('utf-8').lower().strip() 24 | segs_out = [] 25 | temp_str = "" 26 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 27 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 28 | '「','」','(',')','-','~','『','』'] 29 | for char in in_str: 30 | if rm_punc and char in sp_char: 31 | continue 32 | if re.search(ur'[\u4e00-\u9fa5]', char) or char in sp_char: 33 | if temp_str != "": 34 | ss = nltk.word_tokenize(temp_str) 35 | segs_out.extend(ss) 36 | temp_str = "" 37 | segs_out.append(char) 38 | else: 39 | temp_str += char 40 | 41 | #handling last part 42 | if temp_str != "": 43 | ss = nltk.word_tokenize(temp_str) 44 | segs_out.extend(ss) 45 | 46 | return segs_out 47 | 48 | 49 | # remove punctuation 50 | def remove_punctuation(in_str): 51 | in_str = str(in_str).decode('utf-8').lower().strip() 52 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 53 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 54 | '「','」','(',')','-','~','『','』'] 55 | out_segs = [] 56 | for char in in_str: 57 | if char in sp_char: 58 | continue 59 | else: 60 | out_segs.append(char) 61 | return ''.join(out_segs) 62 | 63 | 64 | # find longest common string 65 | def find_lcs(s1, s2): 66 | m = [[0 for i in 
range(len(s2)+1)] for j in range(len(s1)+1)] 67 | mmax = 0 68 | p = 0 69 | for i in range(len(s1)): 70 | for j in range(len(s2)): 71 | if s1[i] == s2[j]: 72 | m[i+1][j+1] = m[i][j]+1 73 | if m[i+1][j+1] > mmax: 74 | mmax=m[i+1][j+1] 75 | p=i+1 76 | return s1[p-mmax:p], mmax 77 | 78 | # 79 | def evaluate(ground_truth_file, prediction_file): 80 | f1 = 0 81 | em = 0 82 | total_count = 0 83 | skip_count = 0 84 | for instance in ground_truth_file: 85 | context_id = instance['context_id'].strip() 86 | context_text = instance['context_text'].strip() 87 | for qas in instance['qas']: 88 | total_count += 1 89 | query_id = qas['query_id'].strip() 90 | query_text = qas['query_text'].strip() 91 | answers = qas['answers'] 92 | 93 | if query_id not in prediction_file: 94 | sys.stderr.write('Unanswered question: {}\n'.format(query_id)) 95 | skip_count += 1 96 | continue 97 | 98 | prediction = str(prediction_file[query_id]) 99 | f1 += calc_f1_score(answers, prediction) 100 | em += calc_em_score(answers, prediction) 101 | 102 | f1_score = 100.0 * f1 / total_count 103 | em_score = 100.0 * em / total_count 104 | return f1_score, em_score, total_count, skip_count 105 | 106 | 107 | def calc_f1_score(answers, prediction): 108 | f1_scores = [] 109 | for ans in answers: 110 | ans_segs = mixed_segmentation(ans, rm_punc=True) 111 | prediction_segs = mixed_segmentation(prediction, rm_punc=True) 112 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs) 113 | if lcs_len == 0: 114 | f1_scores.append(0) 115 | continue 116 | precision = 1.0*lcs_len/len(prediction_segs) 117 | recall = 1.0*lcs_len/len(ans_segs) 118 | f1 = (2*precision*recall)/(precision+recall) 119 | f1_scores.append(f1) 120 | return max(f1_scores) 121 | 122 | 123 | def calc_em_score(answers, prediction): 124 | em = 0 125 | for ans in answers: 126 | ans_ = remove_punctuation(ans) 127 | prediction_ = remove_punctuation(prediction) 128 | if ans_ == prediction_: 129 | em = 1 130 | break 131 | return em 132 | 133 | if __name__ == 
'__main__': 134 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018') 135 | parser.add_argument('dataset_file', help='Official dataset file') 136 | parser.add_argument('prediction_file', help='Your prediction File') 137 | args = parser.parse_args() 138 | ground_truth_file = json.load(open(args.dataset_file, 'rb')) 139 | prediction_file = json.load(open(args.prediction_file, 'rb')) 140 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file) 141 | AVG = (EM+F1)*0.5 142 | output_result = OrderedDict() 143 | output_result['AVERAGE'] = '%.3f' % AVG 144 | output_result['F1'] = '%.3f' % F1 145 | output_result['EM'] = '%.3f' % EM 146 | output_result['TOTAL'] = TOTAL 147 | output_result['SKIP'] = SKIP 148 | output_result['FILE'] = args.prediction_file 149 | print(json.dumps(output_result)) 150 | -------------------------------------------------------------------------------- /baseline/README.md: -------------------------------------------------------------------------------- 1 | # BERT Baselines for CMRC 2018 and DRCD 2 | This directory contains the BERT baseline systems for CMRC 2018 and DRCD. 3 | Both are publicly available Chinese span-extraction machine reading comprehension datasets (similar to SQuAD). 4 | 5 | ## Contents 6 | 7 | | Section | Description | 8 | |-|-| 9 | | [Datasets](#Datasets) | CMRC 2018 and DRCD | 10 | | [Usage](#Usage) | How to train and test | 11 | | [Baseline Results](#Baseline-Results) | BERT baseline results | 12 | 13 | ## Datasets 14 | CMRC 2018 (Simplified Chinese): [https://github.com/ymcui/cmrc2018](https://github.com/ymcui/cmrc2018) 15 | 16 | DRCD (Traditional Chinese): [https://github.com/DRCSolutionService/DRCD](https://github.com/DRCSolutionService/DRCD) 17 | 18 | You can download these datasets through the links above, or directly from the `data` directory.
Note that we use the SQuAD-style CMRC 2018 datasets, which can be accessed through [this link](https://github.com/ymcui/cmrc2018/tree/master/squad-style-data). 19 | 20 | For more Chinese machine reading comprehension datasets, please refer to: [https://github.com/ymcui/Chinese-RC-Datasets](https://github.com/ymcui/Chinese-RC-Datasets) 21 | 22 | 23 | ## Dependency Requirements 24 | There are no special dependency requirements except `TensorFlow==1.12`; the evaluation script `cmrc2018_evaluate.py` additionally requires `nltk`. The code may also work with other TensorFlow versions (not tested). 25 | 26 | The code is based on the official BERT implementation of `run_squad.py`. 27 | 28 | Check: https://github.com/google-research/bert/blob/master/run_squad.py 29 | 30 | ## Usage 31 | ### Step 1: Download BERT weights (skip if you already have them) 32 | - [Chinese (base)](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip) 33 | - [Multi-lingual (base)](https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip) 34 | 35 | ### Step 2: Set local variables 36 | - `$PATH_TO_BERT`: the directory containing the unzipped BERT weights (TensorFlow version) 37 | - `$DATA_DIR`: the path to your dataset 38 | - `$MODEL_DIR`: the output directory for your model 39 | 40 | ### Step 3: Training 41 | Use the following script for training. We take the CMRC 2018 dataset and multilingual BERT as an example.
42 | ``` 43 | python run_cmrc2018_drcd_baseline.py \ 44 | --vocab_file=${PATH_TO_BERT}/multi_cased_L-12_H-768_A-12/vocab.txt \ 45 | --bert_config_file=${PATH_TO_BERT}/multi_cased_L-12_H-768_A-12/bert_config.json \ 46 | --init_checkpoint=${PATH_TO_BERT}/multi_cased_L-12_H-768_A-12/bert_model.ckpt \ 47 | --do_train=True \ 48 | --train_file=${DATA_DIR}/cmrc2018_train.json \ 49 | --do_predict=True \ 50 | --predict_file=${DATA_DIR}/cmrc2018_dev.json \ 51 | --train_batch_size=32 \ 52 | --num_train_epochs=2 \ 53 | --max_seq_length=512 \ 54 | --doc_stride=128 \ 55 | --learning_rate=3e-5 \ 56 | --save_checkpoints_steps=1000 \ 57 | --output_dir=${MODEL_DIR} \ 58 | --do_lower_case=False \ 59 | --use_tpu=False 60 | ``` 61 | 62 | ### Step 4: Evaluation 63 | We use the official CMRC 2018 evaluation script, `cmrc2018_evaluate.py`. 64 | Since DRCD does not provide an official evaluation script, we use `cmrc2018_evaluate.py` for DRCD as well. 65 | 66 | `python cmrc2018_evaluate.py cmrc2018_dev.json predictions.json` 67 | 68 | 69 | ## Baseline Results 70 | We provide both `BERT-Chinese` and `BERT-multilingual` baselines.
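The EM and F1 columns in the tables below are produced by `cmrc2018_evaluate.py`. The sketch below is a simplified Python 3 re-implementation of that scoring logic for illustration only; it is not the official script. It works as follows:

- Every Chinese character counts as one token.
- F1 is computed from the longest common *contiguous* token run (longest common substring) between the prediction and each reference answer.
- EM compares the strings after lowercasing and removing punctuation.
- The best score over all reference answers is kept.

The function names (`segment`, `lcs_length`, `f1_score`, `em_score`) are hypothetical; the official script uses `mixed_segmentation`, `find_lcs`, `calc_f1_score`, and `calc_em_score`. The sketch also differs from the official script in two ways, so scores can differ for answers containing English text:

- It splits non-Chinese text on whitespace instead of calling `nltk.word_tokenize`.
- Its punctuation set only approximates the official `sp_char` list.

```python
import re

# Assumption: an approximation of the official `sp_char` list, not an exact copy.
PUNCT = set("-:_*^/\\~`+=") | set("，。：？！“”；’《》·、「」（）－～『』")
CJK = re.compile(r"[\u4e00-\u9fa5]")  # same CJK range as the official script


def segment(text):
    """Tokenize: each Chinese character is a token; other runs are split on whitespace."""
    tokens, buf = [], ""
    for ch in text.lower().strip():
        if ch in PUNCT:
            continue  # punctuation is dropped, as with rm_punc=True
        if CJK.match(ch):
            tokens.extend(buf.split())  # flush any pending non-Chinese run
            buf = ""
            tokens.append(ch)
        else:
            buf += ch
    tokens.extend(buf.split())
    return tokens


def lcs_length(a, b):
    """Length of the longest common contiguous token run of a and b (dynamic programming)."""
    best = 0
    prev = [0] * (len(b) + 1)
    for i in range(len(a)):
        cur = [0] * (len(b) + 1)
        for j in range(len(b)):
            if a[i] == b[j]:
                cur[j + 1] = prev[j] + 1
                best = max(best, cur[j + 1])
        prev = cur
    return best


def f1_score(answers, prediction):
    """Best LCS-based F1 of the prediction against any reference answer."""
    pred = segment(prediction)
    scores = []
    for ans in answers:
        ref = segment(ans)
        overlap = lcs_length(ref, pred)
        if overlap == 0:
            scores.append(0.0)
            continue
        precision = overlap / len(pred)
        recall = overlap / len(ref)
        scores.append(2 * precision * recall / (precision + recall))
    return max(scores)


def em_score(answers, prediction):
    """1 if the prediction equals any answer after lowercasing and removing punctuation."""
    def norm(s):
        return "".join(ch for ch in s.lower().strip() if ch not in PUNCT)
    return int(any(norm(a) == norm(prediction) for a in answers))


print(f1_score(["哈尔滨工业大学"], "哈尔滨"))  # all 3 predicted tokens match 3 of 7 reference tokens
print(em_score(["哈尔滨。"], "哈尔滨"))
```

The official script then totals the per-question scores and divides by the total number of questions, counting unanswered questions as 0. It multiplies the result by 100 and also reports `AVERAGE = (EM + F1) / 2`.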
71 | 72 | **Note that each baseline is run 10 times and we report the average score for more reliable results (averaging does not apply to the hidden test and challenge sets).** 73 | 74 | ### CMRC 2018 75 | | System | DEV-EM | DEV-F1 | TEST-EM | TEST-F1 | CHALLENGE-EM | CHALLENGE-F1 | Note | 76 | | :------ | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | 77 | | BERT (Multi-lingual) | 63.6 | 84.0 | 68.4 | 86.5 | 18.5 | 43.0 | - | 78 | | BERT (Chinese) | 63.5 | 83.6 | 67.5 | 85.6 | 18.4 | 42.1 | - | 79 | | P-Reader (single model) | 59.894 | 81.499 | 65.189 | 84.386 | 15.079 | 39.583 | - | 80 | | GM-Reader (ensemble) | 58.931 | 80.069 | 64.045 | 83.046 | 15.675 | 37.315 | - | 81 | | MCA-Reader (ensemble) | 66.698 | 85.538 | 71.175 | 88.090 | 15.476 | 37.104 | - | 82 | | Z-Reader (single model) | 79.776 | 92.696 | 74.178 | 88.145 | 13.889 | 37.422 | - | 83 | | [SRC->DS(±) (Yang et al., 2019)](https://arxiv.org/abs/1904.06652) | 49.2 | 65.4 | - | - | - | - | - | 84 | 85 | > Leaderboard: https://hfl-rc.github.io/cmrc2018/open_challenge/
86 | > Note that some of the previous submissions also used the development set for training. 87 | 88 | 89 | ### DRCD 90 | | System | DEV-EM | DEV-F1 | TEST-EM | TEST-F1 | Note | 91 | | :------ | :-----: | :-----: | :-----: | :-----: | :-----: | 92 | | BERT (Multi-lingual) | 83.0 | 89.7 | 83.2 | 89.8 | - | 93 | | BERT (Chinese) | 82.2 | 89.3 | 81.6 | 88.8 | - | 94 | | [SRC + DS(±) (Yang et al., 2019)](https://arxiv.org/abs/1904.06652) | 55.4 | 67.7 | - | - | - | 95 | | r-net (single model) | - | - | 29.1 | 44.4 | - | 96 | 97 | ## Contact 98 | For help or issues, please submit a GitHub issue. 99 | -------------------------------------------------------------------------------- /squad-style-data/cmrc2018_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Evaluation script for CMRC 2018 4 | version: v5 - special 5 | Note: 6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets 7 | v5: formatted output, add usage description 8 | v4: fixed segmentation issues 9 | ''' 10 | from __future__ import print_function 11 | from collections import Counter, OrderedDict 12 | import string 13 | import re 14 | import argparse 15 | import json 16 | import sys 17 | reload(sys) 18 | sys.setdefaultencoding('utf8') 19 | import nltk 20 | import pdb 21 | 22 | # split Chinese with English 23 | def mixed_segmentation(in_str, rm_punc=False): 24 | in_str = str(in_str).decode('utf-8').lower().strip() 25 | segs_out = [] 26 | temp_str = "" 27 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 28 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 29 | '「','」','(',')','-','~','『','』'] 30 | for char in in_str: 31 | if rm_punc and char in sp_char: 32 | continue 33 | if re.search(ur'[\u4e00-\u9fa5]', char) or char in sp_char: 34 | if temp_str != "": 35 | ss = nltk.word_tokenize(temp_str) 36 | segs_out.extend(ss) 37 | temp_str = "" 38 | segs_out.append(char) 39 | else: 40 | temp_str += char 41 |
42 | #handling last part 43 | if temp_str != "": 44 | ss = nltk.word_tokenize(temp_str) 45 | segs_out.extend(ss) 46 | 47 | return segs_out 48 | 49 | 50 | # remove punctuation 51 | def remove_punctuation(in_str): 52 | in_str = str(in_str).decode('utf-8').lower().strip() 53 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 54 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 55 | '「','」','(',')','-','~','『','』'] 56 | out_segs = [] 57 | for char in in_str: 58 | if char in sp_char: 59 | continue 60 | else: 61 | out_segs.append(char) 62 | return ''.join(out_segs) 63 | 64 | 65 | # find longest common string 66 | def find_lcs(s1, s2): 67 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)] 68 | mmax = 0 69 | p = 0 70 | for i in range(len(s1)): 71 | for j in range(len(s2)): 72 | if s1[i] == s2[j]: 73 | m[i+1][j+1] = m[i][j]+1 74 | if m[i+1][j+1] > mmax: 75 | mmax=m[i+1][j+1] 76 | p=i+1 77 | return s1[p-mmax:p], mmax 78 | 79 | # 80 | def evaluate(ground_truth_file, prediction_file): 81 | f1 = 0 82 | em = 0 83 | total_count = 0 84 | skip_count = 0 85 | for instance in ground_truth_file["data"]: 86 | #context_id = instance['context_id'].strip() 87 | #context_text = instance['context_text'].strip() 88 | for para in instance["paragraphs"]: 89 | for qas in para['qas']: 90 | total_count += 1 91 | query_id = qas['id'].strip() 92 | query_text = qas['question'].strip() 93 | answers = [x["text"] for x in qas['answers']] 94 | 95 | if query_id not in prediction_file: 96 | sys.stderr.write('Unanswered question: {}\n'.format(query_id)) 97 | skip_count += 1 98 | continue 99 | 100 | prediction = str(prediction_file[query_id]).decode('utf-8') 101 | f1 += calc_f1_score(answers, prediction) 102 | em += calc_em_score(answers, prediction) 103 | 104 | f1_score = 100.0 * f1 / total_count 105 | em_score = 100.0 * em / total_count 106 | return f1_score, em_score, total_count, skip_count 107 | 108 | 109 | def calc_f1_score(answers, prediction): 110 | f1_scores = [] 111 
| for ans in answers: 112 | ans_segs = mixed_segmentation(ans, rm_punc=True) 113 | prediction_segs = mixed_segmentation(prediction, rm_punc=True) 114 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs) 115 | if lcs_len == 0: 116 | f1_scores.append(0) 117 | continue 118 | precision = 1.0*lcs_len/len(prediction_segs) 119 | recall = 1.0*lcs_len/len(ans_segs) 120 | f1 = (2*precision*recall)/(precision+recall) 121 | f1_scores.append(f1) 122 | return max(f1_scores) 123 | 124 | 125 | def calc_em_score(answers, prediction): 126 | em = 0 127 | for ans in answers: 128 | ans_ = remove_punctuation(ans) 129 | prediction_ = remove_punctuation(prediction) 130 | if ans_ == prediction_: 131 | em = 1 132 | break 133 | return em 134 | 135 | if __name__ == '__main__': 136 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018') 137 | parser.add_argument('dataset_file', help='Official dataset file') 138 | parser.add_argument('prediction_file', help='Your prediction File') 139 | args = parser.parse_args() 140 | ground_truth_file = json.load(open(args.dataset_file, 'rb')) 141 | prediction_file = json.load(open(args.prediction_file, 'rb')) 142 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file) 143 | AVG = (EM+F1)*0.5 144 | output_result = OrderedDict() 145 | output_result['AVERAGE'] = '%.3f' % AVG 146 | output_result['F1'] = '%.3f' % F1 147 | output_result['EM'] = '%.3f' % EM 148 | output_result['TOTAL'] = TOTAL 149 | output_result['SKIP'] = SKIP 150 | output_result['FILE'] = args.prediction_file 151 | print(json.dumps(output_result)) 152 | 153 | -------------------------------------------------------------------------------- /baseline/cmrc2018_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Evaluation script for CMRC 2018 4 | Note: 5 | v6: compatible for both original and SQuAD-style CMRC 2018 datasets. support python3. 
6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets 7 | v5: formatted output, add usage description 8 | v4: fixed segmentation issues 9 | ''' 10 | from __future__ import print_function 11 | from collections import Counter, OrderedDict 12 | import string 13 | import re 14 | import argparse 15 | import json 16 | import sys 17 | import nltk 18 | import pdb 19 | 20 | # split Chinese with English 21 | def mixed_segmentation(in_str, rm_punc=False): 22 | in_str = str(in_str).lower().strip() 23 | segs_out = [] 24 | temp_str = "" 25 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 26 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 27 | '「','」','(',')','-','~','『','』'] 28 | for char in in_str: 29 | if rm_punc and char in sp_char: 30 | continue 31 | if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char: 32 | if temp_str != "": 33 | ss = nltk.word_tokenize(temp_str) 34 | segs_out.extend(ss) 35 | temp_str = "" 36 | segs_out.append(char) 37 | else: 38 | temp_str += char 39 | 40 | #handling last part 41 | if temp_str != "": 42 | ss = nltk.word_tokenize(temp_str) 43 | segs_out.extend(ss) 44 | 45 | return segs_out 46 | 47 | # remove punctuation 48 | def remove_punctuation(in_str): 49 | in_str = str(in_str).lower().strip() 50 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 51 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 52 | '「','」','(',')','-','~','『','』'] 53 | out_segs = [] 54 | for char in in_str: 55 | if char in sp_char: 56 | continue 57 | else: 58 | out_segs.append(char) 59 | return ''.join(out_segs) 60 | 61 | 62 | # find longest common string 63 | def find_lcs(s1, s2): 64 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)] 65 | mmax = 0 66 | p = 0 67 | for i in range(len(s1)): 68 | for j in range(len(s2)): 69 | if s1[i] == s2[j]: 70 | m[i+1][j+1] = m[i][j]+1 71 | if m[i+1][j+1] > mmax: 72 | mmax=m[i+1][j+1] 73 | p=i+1 74 | return s1[p-mmax:p], mmax 75 | 76 | # 77 | def evaluate(ground_truth_file, 
prediction_file): 78 | f1 = 0 79 | em = 0 80 | total_count = 0 81 | skip_count = 0 82 | 83 | data_list = ground_truth_file['data'] if 'data' in ground_truth_file else ground_truth_file 84 | for instance in data_list: 85 | para_list = instance['paragraphs'] if 'paragraphs' in instance else [instance] 86 | for para in para_list: 87 | for qas in para['qas']: 88 | total_count += 1 89 | query_id = qas['id'] if 'id' in qas else qas['query_id'] 90 | answers = [x['text'] if isinstance(x, dict) else x for x in qas['answers']] 91 | 92 | if query_id not in prediction_file: 93 | sys.stderr.write('Unanswered question: {}\n'.format(query_id)) 94 | skip_count += 1 95 | continue 96 | 97 | prediction = str(prediction_file[query_id]) 98 | f1 += calc_f1_score(answers, prediction) 99 | em += calc_em_score(answers, prediction) 100 | 101 | f1_score = 100.0 * f1 / total_count 102 | em_score = 100.0 * em / total_count 103 | return f1_score, em_score, total_count, skip_count 104 | 105 | 106 | def calc_f1_score(answers, prediction): 107 | f1_scores = [] 108 | for ans in answers: 109 | ans_segs = mixed_segmentation(ans, rm_punc=True) 110 | prediction_segs = mixed_segmentation(prediction, rm_punc=True) 111 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs) 112 | if lcs_len == 0: 113 | f1_scores.append(0) 114 | continue 115 | precision = 1.0*lcs_len/len(prediction_segs) 116 | recall = 1.0*lcs_len/len(ans_segs) 117 | f1 = (2*precision*recall)/(precision+recall) 118 | f1_scores.append(f1) 119 | return max(f1_scores) 120 | 121 | 122 | def calc_em_score(answers, prediction): 123 | em = 0 124 | for ans in answers: 125 | ans_ = remove_punctuation(ans) 126 | prediction_ = remove_punctuation(prediction) 127 | if ans_ == prediction_: 128 | em = 1 129 | break 130 | return em 131 | 132 | if __name__ == '__main__': 133 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018') 134 | parser.add_argument('dataset_file', help='Official dataset file') 135 | 
parser.add_argument('prediction_file', help='Your prediction File') 136 | args = parser.parse_args() 137 | ground_truth_file = json.load(open(args.dataset_file, 'rb')) 138 | prediction_file = json.load(open(args.prediction_file, 'rb')) 139 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file) 140 | AVG = (EM+F1)*0.5 141 | output_result = OrderedDict() 142 | output_result['AVERAGE'] = '%.3f' % AVG 143 | output_result['F1'] = '%.3f' % F1 144 | output_result['EM'] = '%.3f' % EM 145 | output_result['TOTAL'] = TOTAL 146 | output_result['SKIP'] = SKIP 147 | output_result['FILE'] = args.prediction_file 148 | print(json.dumps(output_result)) 149 | 150 | -------------------------------------------------------------------------------- /baseline/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs an AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m =
tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update (without bias correction). 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want to decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD.
146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /baseline/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
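The learning-rate schedule built by `create_optimizer` and the per-parameter update in `AdamWeightDecayOptimizer.apply_gradients` (both in `baseline/optimization.py`) can be traced without TensorFlow. The following is a minimal pure-Python sketch, assuming scalar parameters; the function names `bert_style_lr`, `adam_weight_decay_step` and `uses_weight_decay` are hypothetical and not part of the repository. It mirrors linear warmup, linear decay to zero (`tf.train.polynomial_decay` with `power=1.0`, `end_learning_rate=0.0`, `cycle=False`), Adam moments without bias correction, and decoupled weight decay that skips names matching `LayerNorm`, `layer_norm` or `bias`.

```python
import math
import re


def bert_style_lr(step, init_lr, num_train_steps, num_warmup_steps):
    """Hypothetical helper: learning rate at `step` (linear warmup, then linear decay)."""
    if num_warmup_steps and step < num_warmup_steps:
        # Warmup: global_step / num_warmup_steps * init_lr.
        return init_lr * step / num_warmup_steps
    # polynomial_decay with cycle=False clamps the step at num_train_steps.
    decay_step = min(step, num_train_steps)
    # power=1.0 and end_learning_rate=0.0 give a straight line down to zero.
    return init_lr * (1.0 - decay_step / num_train_steps)


def uses_weight_decay(param_name, weight_decay_rate=0.01,
                      exclude=("LayerNorm", "layer_norm", "bias")):
    """Hypothetical helper mirroring `_do_use_weight_decay`."""
    if not weight_decay_rate:
        return False
    # Any regex match against the variable name disables decay for it.
    return not any(re.search(pattern, param_name) is not None
                   for pattern in exclude)


def adam_weight_decay_step(param, grad, m, v, lr, weight_decay_rate=0.01,
                           beta_1=0.9, beta_2=0.999, epsilon=1e-6,
                           use_weight_decay=True):
    """Hypothetical scalar version of one update; returns (param, m, v)."""
    next_m = beta_1 * m + (1.0 - beta_1) * grad
    next_v = beta_2 * v + (1.0 - beta_2) * grad * grad
    # No bias correction, matching the TensorFlow code above.
    update = next_m / (math.sqrt(next_v) + epsilon)
    if use_weight_decay:
        # Decoupled decay: added to the update, never fed into m or v.
        update += weight_decay_rate * param
    next_param = param - lr * update
    return next_param, next_m, next_v
```

Because the decay term joins `update` only after the moment estimates are formed, it is scaled by the learning rate alone and never passes through `m` or `v`, which is the point made in the comment inside `apply_gradients`.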
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-trained. If this error is wrong, please " 74 | "just comment out this check."
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts 
a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenization.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input.
193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia). 207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = []
254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "Chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as are Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and are handled 273 | # like all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenization.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 |
"""Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace-separated tokens. This should have 320 | already been passed through `BasicTokenizer`. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `char` is a whitespace character.""" 364 | # \t, \n, and \r are technically control characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `char` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `char` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyway, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | Attribution-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible.
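The greedy longest-match-first loop in `WordpieceTokenizer.tokenize` from `baseline/tokenization.py` can be exercised on its own. The following minimal sketch re-implements that loop for a single word that has already been through basic tokenization, assuming a plain Python `set` as the vocabulary; the function name `wordpiece_split` and the toy vocabularies are hypothetical and only illustrate the algorithm.

```python
def wordpiece_split(word, vocab, unk_token="[UNK]",
                    max_input_chars_per_word=200):
    """Hypothetical helper: split one word with greedy longest-match-first."""
    chars = list(word)
    if len(chars) > max_input_chars_per_word:
        # Over-long words are mapped straight to the unknown token.
        return [unk_token]
    pieces = []
    start = 0
    while start < len(chars):
        end = len(chars)
        match = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while start < end:
            candidate = "".join(chars[start:end])
            if start > 0:
                # Non-initial pieces carry the "##" continuation prefix.
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # One unmatched position makes the whole word unknown.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


print(wordpiece_split("unaffable", {"un", "##aff", "##able"}))
```

Because the inner loop starts from the longest remaining span, a vocabulary containing both `ab` and `a`/`##b` yields `["ab"]` for the word `ab`, and, as with the `is_bad` flag in the original, a word is never emitted as a partial split.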
14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. 
A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-ShareAlike 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-ShareAlike 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
BY-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. 
Share means to provide material to the public by any means or 126 | process that requires permission under the Licensed Rights, such 127 | as reproduction, public display, public performance, distribution, 128 | dissemination, communication, or importation, and to make material 129 | available to the public including in ways that members of the 130 | public may access the material from a place and at a time 131 | individually chosen by them. 132 | 133 | l. Sui Generis Database Rights means rights other than copyright 134 | resulting from Directive 96/9/EC of the European Parliament and of 135 | the Council of 11 March 1996 on the legal protection of databases, 136 | as amended and/or succeeded, as well as other essentially 137 | equivalent rights anywhere in the world. 138 | 139 | m. You means the individual or entity exercising the Licensed Rights 140 | under this Public License. Your has a corresponding meaning. 141 | 142 | 143 | Section 2 -- Scope. 144 | 145 | a. License grant. 146 | 147 | 1. Subject to the terms and conditions of this Public License, 148 | the Licensor hereby grants You a worldwide, royalty-free, 149 | non-sublicensable, non-exclusive, irrevocable license to 150 | exercise the Licensed Rights in the Licensed Material to: 151 | 152 | a. reproduce and Share the Licensed Material, in whole or 153 | in part; and 154 | 155 | b. produce, reproduce, and Share Adapted Material. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. 
The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. Additional offer from the Licensor -- Adapted Material. 186 | Every recipient of Adapted Material from You 187 | automatically receives an offer from the Licensor to 188 | exercise the Licensed Rights in the Adapted Material 189 | under the conditions of the Adapter's License You apply. 190 | 191 | c. No downstream restrictions. You may not offer or impose 192 | any additional or different terms or conditions on, or 193 | apply any Effective Technological Measures to, the 194 | Licensed Material if doing so restricts exercise of the 195 | Licensed Rights by any recipient of the Licensed 196 | Material. 197 | 198 | 6. No endorsement. Nothing in this Public License constitutes or 199 | may be construed as permission to assert or imply that You 200 | are, or that Your use of the Licensed Material is, connected 201 | with, or sponsored, endorsed, or granted official status by, 202 | the Licensor or others designated to receive attribution as 203 | provided in Section 3(a)(1)(A)(i). 204 | 205 | b. Other rights. 206 | 207 | 1. 
Moral rights, such as the right of integrity, are not 208 | licensed under this Public License, nor are publicity, 209 | privacy, and/or other similar personality rights; however, to 210 | the extent possible, the Licensor waives and/or agrees not to 211 | assert any such rights held by the Licensor to the limited 212 | extent necessary to allow You to exercise the Licensed 213 | Rights, but not otherwise. 214 | 215 | 2. Patent and trademark rights are not licensed under this 216 | Public License. 217 | 218 | 3. To the extent possible, the Licensor waives any right to 219 | collect royalties from You for the exercise of the Licensed 220 | Rights, whether directly or through a collecting society 221 | under any voluntary or waivable statutory or compulsory 222 | licensing scheme. In all other cases the Licensor expressly 223 | reserves any right to collect such royalties. 224 | 225 | 226 | Section 3 -- License Conditions. 227 | 228 | Your exercise of the Licensed Rights is expressly made subject to the 229 | following conditions. 230 | 231 | a. Attribution. 232 | 233 | 1. If You Share the Licensed Material (including in modified 234 | form), You must: 235 | 236 | a. retain the following if it is supplied by the Licensor 237 | with the Licensed Material: 238 | 239 | i. identification of the creator(s) of the Licensed 240 | Material and any others designated to receive 241 | attribution, in any reasonable manner requested by 242 | the Licensor (including by pseudonym if 243 | designated); 244 | 245 | ii. a copyright notice; 246 | 247 | iii. a notice that refers to this Public License; 248 | 249 | iv. a notice that refers to the disclaimer of 250 | warranties; 251 | 252 | v. a URI or hyperlink to the Licensed Material to the 253 | extent reasonably practicable; 254 | 255 | b. indicate if You modified the Licensed Material and 256 | retain an indication of any previous modifications; and 257 | 258 | c. 
indicate the Licensed Material is licensed under this 259 | Public License, and include the text of, or the URI or 260 | hyperlink to, this Public License. 261 | 262 | 2. You may satisfy the conditions in Section 3(a)(1) in any 263 | reasonable manner based on the medium, means, and context in 264 | which You Share the Licensed Material. For example, it may be 265 | reasonable to satisfy the conditions by providing a URI or 266 | hyperlink to a resource that includes the required 267 | information. 268 | 269 | 3. If requested by the Licensor, You must remove any of the 270 | information required by Section 3(a)(1)(A) to the extent 271 | reasonably practicable. 272 | 273 | b. ShareAlike. 274 | 275 | In addition to the conditions in Section 3(a), if You Share 276 | Adapted Material You produce, the following conditions also apply. 277 | 278 | 1. The Adapter's License You apply must be a Creative Commons 279 | license with the same License Elements, this version or 280 | later, or a BY-SA Compatible License. 281 | 282 | 2. You must include the text of, or the URI or hyperlink to, the 283 | Adapter's License You apply. You may satisfy this condition 284 | in any reasonable manner based on the medium, means, and 285 | context in which You Share Adapted Material. 286 | 287 | 3. You may not offer or impose any additional or different terms 288 | or conditions on, or apply any Effective Technological 289 | Measures to, Adapted Material that restrict exercise of the 290 | rights granted under the Adapter's License You apply. 291 | 292 | 293 | Section 4 -- Sui Generis Database Rights. 294 | 295 | Where the Licensed Rights include Sui Generis Database Rights that 296 | apply to Your use of the Licensed Material: 297 | 298 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 299 | to extract, reuse, reproduce, and Share all or a substantial 300 | portion of the contents of the database; 301 | 302 | b. 
if You include all or a substantial portion of the database 303 | contents in a database in which You have Sui Generis Database 304 | Rights, then the database in which You have Sui Generis Database 305 | Rights (but not its individual contents) is Adapted Material, 306 | 307 | including for purposes of Section 3(b); and 308 | c. You must comply with the conditions in Section 3(a) if You Share 309 | all or a substantial portion of the contents of the database. 310 | 311 | For the avoidance of doubt, this Section 4 supplements and does not 312 | replace Your obligations under this Public License where the Licensed 313 | Rights include other Copyright and Similar Rights. 314 | 315 | 316 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 317 | 318 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 319 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 320 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 321 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 322 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 323 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 324 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 325 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 326 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 327 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 328 | 329 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 330 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 331 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 332 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 333 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 334 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 335 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 336 | DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 337 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 338 | 339 | c. The disclaimer of warranties and limitation of liability provided 340 | above shall be interpreted in a manner that, to the extent 341 | possible, most closely approximates an absolute disclaimer and 342 | waiver of all liability. 343 | 344 | 345 | Section 6 -- Term and Termination. 346 | 347 | a. This Public License applies for the term of the Copyright and 348 | Similar Rights licensed here. However, if You fail to comply with 349 | this Public License, then Your rights under this Public License 350 | terminate automatically. 351 | 352 | b. Where Your right to use the Licensed Material has terminated under 353 | Section 6(a), it reinstates: 354 | 355 | 1. automatically as of the date the violation is cured, provided 356 | it is cured within 30 days of Your discovery of the 357 | violation; or 358 | 359 | 2. upon express reinstatement by the Licensor. 360 | 361 | For the avoidance of doubt, this Section 6(b) does not affect any 362 | right the Licensor may have to seek remedies for Your violations 363 | of this Public License. 364 | 365 | c. For the avoidance of doubt, the Licensor may also offer the 366 | Licensed Material under separate terms or conditions or stop 367 | distributing the Licensed Material at any time; however, doing so 368 | will not terminate this Public License. 369 | 370 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 371 | License. 372 | 373 | 374 | Section 7 -- Other Terms and Conditions. 375 | 376 | a. The Licensor shall not be bound by any additional or different 377 | terms or conditions communicated by You unless expressly agreed. 378 | 379 | b. Any arrangements, understandings, or agreements regarding the 380 | Licensed Material not stated herein are separate from and 381 | independent of the terms and conditions of this Public License. 382 | 383 | 384 | Section 8 -- Interpretation. 
385 | 386 | a. For the avoidance of doubt, this Public License does not, and 387 | shall not be interpreted to, reduce, limit, restrict, or impose 388 | conditions on any use of the Licensed Material that could lawfully 389 | be made without permission under this Public License. 390 | 391 | b. To the extent possible, if any provision of this Public License is 392 | deemed unenforceable, it shall be automatically reformed to the 393 | minimum extent necessary to make it enforceable. If the provision 394 | cannot be reformed, it shall be severed from this Public License 395 | without affecting the enforceability of the remaining terms and 396 | conditions. 397 | 398 | c. No term or condition of this Public License will be waived and no 399 | failure to comply consented to unless expressly agreed to by the 400 | Licensor. 401 | 402 | d. Nothing in this Public License constitutes or may be interpreted 403 | as a limitation upon, or waiver of, any privileges and immunities 404 | that apply to the Licensor or You, including from the legal 405 | processes of any jurisdiction or authority. 406 | 407 | 408 | ======================================================================= 409 | 410 | Creative Commons is not a party to its public 411 | licenses. Notwithstanding, Creative Commons may elect to apply one of 412 | its public licenses to material it publishes and in those instances 413 | will be considered the “Licensor.” The text of the Creative Commons 414 | public licenses is dedicated to the public domain under the CC0 Public 415 | Domain Dedication. 
Except for the limited purpose of indicating that 416 | material is shared under a Creative Commons public license or as 417 | otherwise permitted by the Creative Commons policies published at 418 | creativecommons.org/policies, Creative Commons does not authorize the 419 | use of the trademark "Creative Commons" or any other trademark or logo 420 | of Creative Commons without its prior written consent including, 421 | without limitation, in connection with any unauthorized modifications 422 | to any of its public licenses or any other arrangements, 423 | understandings, or agreements concerning use of licensed material. For 424 | the avoidance of doubt, this paragraph does not form part of the 425 | public licenses. 426 | 427 | Creative Commons may be contacted at creativecommons.org. 428 | -------------------------------------------------------------------------------- /baseline/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import six 27 | import tensorflow as tf 28 | 29 | 30 | class BertConfig(object): 31 | """Configuration for `BertModel`.""" 32 | 33 | def __init__(self, 34 | vocab_size, 35 | hidden_size=768, 36 | num_hidden_layers=12, 37 | num_attention_heads=12, 38 | intermediate_size=3072, 39 | hidden_act="gelu", 40 | hidden_dropout_prob=0.1, 41 | attention_probs_dropout_prob=0.1, 42 | max_position_embeddings=512, 43 | type_vocab_size=16, 44 | initializer_range=0.02): 45 | """Constructs BertConfig. 46 | 47 | Args: 48 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 49 | hidden_size: Size of the encoder layers and the pooler layer. 50 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 51 | num_attention_heads: Number of attention heads for each attention layer in 52 | the Transformer encoder. 53 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 54 | layer in the Transformer encoder. 55 | hidden_act: The non-linear activation function (function or string) in the 56 | encoder and pooler. 57 | hidden_dropout_prob: The dropout probability for all fully connected 58 | layers in the embeddings, encoder, and pooler. 59 | attention_probs_dropout_prob: The dropout ratio for the attention 60 | probabilities. 61 | max_position_embeddings: The maximum sequence length that this model might 62 | ever be used with. Typically set this to something large just in case 63 | (e.g., 512 or 1024 or 2048). 64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 65 | `BertModel`. 66 | initializer_range: The stdev of the truncated_normal_initializer for 67 | initializing all weight matrices. 
68 | """ 69 | self.vocab_size = vocab_size 70 | self.hidden_size = hidden_size 71 | self.num_hidden_layers = num_hidden_layers 72 | self.num_attention_heads = num_attention_heads 73 | self.hidden_act = hidden_act 74 | self.intermediate_size = intermediate_size 75 | self.hidden_dropout_prob = hidden_dropout_prob 76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 77 | self.max_position_embeddings = max_position_embeddings 78 | self.type_vocab_size = type_vocab_size 79 | self.initializer_range = initializer_range 80 | 81 | @classmethod 82 | def from_dict(cls, json_object): 83 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 84 | config = BertConfig(vocab_size=None) 85 | for (key, value) in six.iteritems(json_object): 86 | config.__dict__[key] = value 87 | return config 88 | 89 | @classmethod 90 | def from_json_file(cls, json_file): 91 | """Constructs a `BertConfig` from a json file of parameters.""" 92 | with tf.gfile.GFile(json_file, "r") as reader: 93 | text = reader.read() 94 | return cls.from_dict(json.loads(text)) 95 | 96 | def to_dict(self): 97 | """Serializes this instance to a Python dictionary.""" 98 | output = copy.deepcopy(self.__dict__) 99 | return output 100 | 101 | def to_json_string(self): 102 | """Serializes this instance to a JSON string.""" 103 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 104 | 105 | 106 | class BertModel(object): 107 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 
108 | 109 | Example usage: 110 | 111 | ```python 112 | # Already converted into WordPiece token ids 113 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 114 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 115 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 116 | 117 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 118 | num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024) 119 | 120 | model = modeling.BertModel(config=config, is_training=True, 121 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 122 | 123 | label_embeddings = tf.get_variable(...) 124 | pooled_output = model.get_pooled_output() 125 | logits = tf.matmul(pooled_output, label_embeddings) 126 | ... 127 | ``` 128 | """ 129 | 130 | def __init__(self, 131 | config, 132 | is_training, 133 | input_ids, 134 | input_mask=None, 135 | token_type_ids=None, 136 | use_one_hot_embeddings=True, 137 | scope=None): 138 | """Constructor for BertModel. 139 | 140 | Args: 141 | config: `BertConfig` instance. 142 | is_training: bool. true for training model, false for eval model. Controls 143 | whether dropout will be applied. 144 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 145 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 146 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 148 | embeddings or tf.nn.embedding_lookup() for the word embeddings. On the TPU, 149 | it is much faster if this is True; on the CPU or GPU, it is faster if 150 | this is False. 151 | scope: (optional) variable scope. Defaults to "bert". 152 | 153 | Raises: 154 | ValueError: The config is invalid or one of the input tensor shapes 155 | is invalid.
156 | """ 157 | config = copy.deepcopy(config) 158 | if not is_training: 159 | config.hidden_dropout_prob = 0.0 160 | config.attention_probs_dropout_prob = 0.0 161 | 162 | input_shape = get_shape_list(input_ids, expected_rank=2) 163 | batch_size = input_shape[0] 164 | seq_length = input_shape[1] 165 | 166 | if input_mask is None: 167 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 168 | 169 | if token_type_ids is None: 170 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 171 | 172 | with tf.variable_scope(scope, default_name="bert"): 173 | with tf.variable_scope("embeddings"): 174 | # Perform embedding lookup on the word ids. 175 | (self.embedding_output, self.embedding_table) = embedding_lookup( 176 | input_ids=input_ids, 177 | vocab_size=config.vocab_size, 178 | embedding_size=config.hidden_size, 179 | initializer_range=config.initializer_range, 180 | word_embedding_name="word_embeddings", 181 | use_one_hot_embeddings=use_one_hot_embeddings) 182 | 183 | # Add positional embeddings and token type embeddings, then layer 184 | # normalize and perform dropout. 185 | self.embedding_output = embedding_postprocessor( 186 | input_tensor=self.embedding_output, 187 | use_token_type=True, 188 | token_type_ids=token_type_ids, 189 | token_type_vocab_size=config.type_vocab_size, 190 | token_type_embedding_name="token_type_embeddings", 191 | use_position_embeddings=True, 192 | position_embedding_name="position_embeddings", 193 | initializer_range=config.initializer_range, 194 | max_position_embeddings=config.max_position_embeddings, 195 | dropout_prob=config.hidden_dropout_prob) 196 | 197 | with tf.variable_scope("encoder"): 198 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 199 | # mask of shape [batch_size, seq_length, seq_length] which is used 200 | # for the attention scores. 
201 | attention_mask = create_attention_mask_from_input_mask( 202 | input_ids, input_mask) 203 | 204 | # Run the stacked transformer. 205 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 206 | self.all_encoder_layers = transformer_model( 207 | input_tensor=self.embedding_output, 208 | attention_mask=attention_mask, 209 | hidden_size=config.hidden_size, 210 | num_hidden_layers=config.num_hidden_layers, 211 | num_attention_heads=config.num_attention_heads, 212 | intermediate_size=config.intermediate_size, 213 | intermediate_act_fn=get_activation(config.hidden_act), 214 | hidden_dropout_prob=config.hidden_dropout_prob, 215 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 216 | initializer_range=config.initializer_range, 217 | do_return_all_layers=True) 218 | 219 | self.sequence_output = self.all_encoder_layers[-1] 220 | # The "pooler" converts the encoded sequence tensor of shape 221 | # [batch_size, seq_length, hidden_size] to a tensor of shape 222 | # [batch_size, hidden_size]. This is necessary for segment-level 223 | # (or segment-pair-level) classification tasks where we need a fixed 224 | # dimensional representation of the segment. 225 | with tf.variable_scope("pooler"): 226 | # We "pool" the model by simply taking the hidden state corresponding 227 | # to the first token. We assume that this has been pre-trained 228 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 229 | self.pooled_output = tf.layers.dense( 230 | first_token_tensor, 231 | config.hidden_size, 232 | activation=tf.tanh, 233 | kernel_initializer=create_initializer(config.initializer_range)) 234 | 235 | def get_pooled_output(self): 236 | return self.pooled_output 237 | 238 | def get_sequence_output(self): 239 | """Gets final hidden layer of encoder. 240 | 241 | Returns: 242 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 243 | to the final hidden of the transformer encoder. 
244 | """ 245 | return self.sequence_output 246 | 247 | def get_all_encoder_layers(self): 248 | return self.all_encoder_layers 249 | 250 | def get_embedding_output(self): 251 | """Gets output of the embedding lookup (i.e., input to the transformer). 252 | 253 | Returns: 254 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 255 | to the output of the embedding layer, after summing the word 256 | embeddings with the positional embeddings and the token type embeddings, 257 | then performing layer normalization. This is the input to the transformer. 258 | """ 259 | return self.embedding_output 260 | 261 | def get_embedding_table(self): 262 | return self.embedding_table 263 | 264 | 265 | def gelu(input_tensor): 266 | """Gaussian Error Linear Unit. 267 | 268 | This is a smoother version of the ReLU. 269 | Original paper: https://arxiv.org/abs/1606.08415 270 | 271 | Args: 272 | input_tensor: float Tensor to perform activation. 273 | 274 | Returns: 275 | `input_tensor` with the GELU activation applied. 276 | """ 277 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 278 | return input_tensor * cdf 279 | 280 | 281 | def get_activation(activation_string): 282 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 283 | 284 | Args: 285 | activation_string: String name of the activation function. 286 | 287 | Returns: 288 | A Python function corresponding to the activation function. If 289 | `activation_string` is None, empty, or "linear", this will return None. 290 | If `activation_string` is not a string, it will return `activation_string`. 291 | 292 | Raises: 293 | ValueError: The `activation_string` does not correspond to a known 294 | activation. 295 | """ 296 | 297 | # We assume that anything that's not a string is already an activation 298 | # function, so we just return it.
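The `gelu()` function above uses the exact erf form, x times the standard normal CDF. The following plain-Python sketch (an illustration, not part of `modeling.py`) mirrors that formula with `math.erf` so it can be checked numerically without TensorFlow. The `gelu_tanh` helper is a hypothetical addition for comparison only: it is the commonly cited tanh approximation, which this file does not use.

```python
import math


def gelu(x: float) -> float:
    """Exact GELU, x * Phi(x), using the erf form of the standard normal CDF."""
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    return x * cdf


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU (hypothetical helper, for comparison only)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


for v in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"x={v:+.1f}  erf form={gelu(v):+.5f}  tanh approx={gelu_tanh(v):+.5f}")
```

For large positive inputs the CDF saturates at 1 and GELU behaves like the identity; for large negative inputs it goes to 0, which is the sense in which it is a smooth ReLU.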
299 | if not isinstance(activation_string, six.string_types): 300 | return activation_string 301 | 302 | if not activation_string: 303 | return None 304 | 305 | act = activation_string.lower() 306 | if act == "linear": 307 | return None 308 | elif act == "relu": 309 | return tf.nn.relu 310 | elif act == "gelu": 311 | return gelu 312 | elif act == "tanh": 313 | return tf.tanh 314 | else: 315 | raise ValueError("Unsupported activation: %s" % act) 316 | 317 | 318 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 319 | """Compute the intersection of the current variables and checkpoint variables.""" 320 | assignment_map = {} 321 | initialized_variable_names = {} 322 | 323 | name_to_variable = collections.OrderedDict() 324 | for var in tvars: 325 | name = var.name 326 | m = re.match("^(.*):\\d+$", name) 327 | if m is not None: 328 | name = m.group(1) 329 | name_to_variable[name] = var 330 | 331 | init_vars = tf.train.list_variables(init_checkpoint) 332 | 333 | assignment_map = collections.OrderedDict() 334 | for x in init_vars: 335 | (name, var) = (x[0], x[1]) 336 | if name not in name_to_variable: 337 | continue 338 | assignment_map[name] = name 339 | initialized_variable_names[name] = 1 340 | initialized_variable_names[name + ":0"] = 1 341 | 342 | return (assignment_map, initialized_variable_names) 343 | 344 | 345 | def dropout(input_tensor, dropout_prob): 346 | """Perform dropout. 347 | 348 | Args: 349 | input_tensor: float Tensor. 350 | dropout_prob: Python float. The probability of dropping out a value (NOT of 351 | *keeping* a dimension as in `tf.nn.dropout`). 352 | 353 | Returns: 354 | A version of `input_tensor` with dropout applied.
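`get_assignment_map_from_checkpoint()` strips the `:<index>` output suffix from graph variable names and then keeps only checkpoint entries whose names also exist in the graph. The sketch below reproduces that name matching in plain Python so the behaviour is easy to see. It is an illustration under stated assumptions: variables are represented as plain name strings (the real function takes `tf.Variable` objects and reads names via `tf.train.list_variables`), and `assignment_map_sketch` and the variable names in the example are hypothetical.

```python
import re
from collections import OrderedDict


def assignment_map_sketch(graph_var_names, checkpoint_var_names):
    """Plain-Python mirror of the name matching in get_assignment_map_from_checkpoint.

    graph_var_names: graph-side names such as "bert/pooler/dense/kernel:0".
    checkpoint_var_names: checkpoint-side names, which carry no ":<index>" suffix.
    """
    graph_names = set()
    for name in graph_var_names:
        m = re.match(r"^(.*):\d+$", name)  # drop the ":0" output-index suffix
        if m is not None:
            name = m.group(1)
        graph_names.add(name)

    assignment_map = OrderedDict()
    initialized_variable_names = {}
    for name in checkpoint_var_names:
        if name not in graph_names:
            continue  # checkpoint-only entries (e.g. optimizer slots) are skipped
        assignment_map[name] = name
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1
    return assignment_map, initialized_variable_names


# Hypothetical variable names, for illustration only.
graph = ["bert/embeddings/word_embeddings:0", "cls/squad/output_weights:0"]
ckpt = ["bert/embeddings/word_embeddings", "bert/embeddings/word_embeddings/adam_m", "global_step"]
amap, init_names = assignment_map_sketch(graph, ckpt)
print(dict(amap))
```

Graph-only variables (here the hypothetical `cls/squad/output_weights`) never enter the map and keep their fresh initialization, and checkpoint-only entries are ignored, so the result covers only names present on both sides.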
355 | """ 356 | if dropout_prob is None or dropout_prob == 0.0: 357 | return input_tensor 358 | 359 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 360 | return output 361 | 362 | 363 | def layer_norm(input_tensor, name=None): 364 | """Run layer normalization on the last dimension of the tensor.""" 365 | return tf.contrib.layers.layer_norm( 366 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 367 | 368 | 369 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 370 | """Runs layer normalization followed by dropout.""" 371 | output_tensor = layer_norm(input_tensor, name) 372 | output_tensor = dropout(output_tensor, dropout_prob) 373 | return output_tensor 374 | 375 | 376 | def create_initializer(initializer_range=0.02): 377 | """Creates a `truncated_normal_initializer` with the given range.""" 378 | return tf.truncated_normal_initializer(stddev=initializer_range) 379 | 380 | 381 | def embedding_lookup(input_ids, 382 | vocab_size, 383 | embedding_size=128, 384 | initializer_range=0.02, 385 | word_embedding_name="word_embeddings", 386 | use_one_hot_embeddings=False): 387 | """Looks up word embeddings for an id tensor. 388 | 389 | Args: 390 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 391 | ids. 392 | vocab_size: int. Size of the embedding vocabulary. 393 | embedding_size: int. Width of the word embeddings. 394 | initializer_range: float. Embedding initialization range. 395 | word_embedding_name: string. Name of the embedding table. 396 | use_one_hot_embeddings: bool. If True, use one-hot method for word 397 | embeddings. If False, use `tf.nn.embedding_lookup()`. One-hot is faster 398 | on TPUs. 399 | 400 | Returns: 401 | float Tensor of shape [batch_size, seq_length, embedding_size]. 402 | """ 403 | # This function assumes that the input is of shape [batch_size, seq_length, 404 | # num_inputs].
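`embedding_lookup()` offers two paths that should yield identical rows: a direct gather (`tf.nn.embedding_lookup`) or a one-hot encoding of the ids multiplied by the embedding table. The plain-Python sketch below (hypothetical helper names, a toy 4x2 table, not part of `modeling.py`) shows that multiplying a one-hot row by the table simply selects one table row. The claim that one-hot is faster on TPUs is the docstring's; this sketch says nothing about speed.

```python
def lookup_by_gather(table, ids):
    """tf.nn.embedding_lookup-style: index rows of the table directly."""
    return [list(table[i]) for i in ids]


def lookup_by_one_hot(table, ids):
    """One-hot style: one_hot(ids, depth=vocab_size) multiplied by the table."""
    vocab_size, width = len(table), len(table[0])
    out = []
    for i in ids:
        one_hot = [1.0 if v == i else 0.0 for v in range(vocab_size)]
        out.append([sum(one_hot[v] * table[v][d] for v in range(vocab_size))
                    for d in range(width)])
    return out


table = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]  # vocab_size=4, width=2
ids = [2, 0, 3, 3]
print(lookup_by_gather(table, ids))   # [[0.5, 0.6], [0.1, 0.2], [0.7, 0.8], [0.7, 0.8]]
print(lookup_by_one_hot(table, ids))  # same rows
```

The one-hot path does O(vocab_size) work per id instead of a single row read, which is why it only pays off on hardware where dense matmuls are cheap relative to gathers.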
405 | # 406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 407 | # reshape to [batch_size, seq_length, 1]. 408 | if input_ids.shape.ndims == 2: 409 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 410 | 411 | embedding_table = tf.get_variable( 412 | name=word_embedding_name, 413 | shape=[vocab_size, embedding_size], 414 | initializer=create_initializer(initializer_range)) 415 | 416 | if use_one_hot_embeddings: 417 | flat_input_ids = tf.reshape(input_ids, [-1]) 418 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 419 | output = tf.matmul(one_hot_input_ids, embedding_table) 420 | else: 421 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 422 | 423 | input_shape = get_shape_list(input_ids) 424 | 425 | output = tf.reshape(output, 426 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 427 | return (output, embedding_table) 428 | 429 | 430 | def embedding_postprocessor(input_tensor, 431 | use_token_type=False, 432 | token_type_ids=None, 433 | token_type_vocab_size=16, 434 | token_type_embedding_name="token_type_embeddings", 435 | use_position_embeddings=True, 436 | position_embedding_name="position_embeddings", 437 | initializer_range=0.02, 438 | max_position_embeddings=512, 439 | dropout_prob=0.1): 440 | """Performs various post-processing on a word embedding tensor. 441 | 442 | Args: 443 | input_tensor: float Tensor of shape [batch_size, seq_length, 444 | embedding_size]. 445 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 446 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 447 | Must be specified if `use_token_type` is True. 448 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 449 | token_type_embedding_name: string. The name of the embedding table variable 450 | for token type ids. 451 | use_position_embeddings: bool. Whether to add position embeddings for the 452 | position of each token in the sequence. 
453 | position_embedding_name: string. The name of the embedding table variable 454 | for positional embeddings. 455 | initializer_range: float. Range of the weight initialization. 456 | max_position_embeddings: int. Maximum sequence length that might ever be 457 | used with this model. This can be longer than the sequence length of 458 | input_tensor, but cannot be shorter. 459 | dropout_prob: float. Dropout probability applied to the final output tensor. 460 | 461 | Returns: 462 | float tensor with same shape as `input_tensor`. 463 | 464 | Raises: 465 | ValueError: One of the tensor shapes or input values is invalid. 466 | """ 467 | input_shape = get_shape_list(input_tensor, expected_rank=3) 468 | batch_size = input_shape[0] 469 | seq_length = input_shape[1] 470 | width = input_shape[2] 471 | 472 | output = input_tensor 473 | 474 | if use_token_type: 475 | if token_type_ids is None: 476 | raise ValueError("`token_type_ids` must be specified if " 477 | "`use_token_type` is True.") 478 | token_type_table = tf.get_variable( 479 | name=token_type_embedding_name, 480 | shape=[token_type_vocab_size, width], 481 | initializer=create_initializer(initializer_range)) 482 | # This vocab will be small so we always do one-hot here, since it is always 483 | # faster for a small vocabulary.
484 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 485 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 486 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 487 | token_type_embeddings = tf.reshape(token_type_embeddings, 488 | [batch_size, seq_length, width]) 489 | output += token_type_embeddings 490 | 491 | if use_position_embeddings: 492 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 493 | with tf.control_dependencies([assert_op]): 494 | full_position_embeddings = tf.get_variable( 495 | name=position_embedding_name, 496 | shape=[max_position_embeddings, width], 497 | initializer=create_initializer(initializer_range)) 498 | # Since the position embedding table is a learned variable, we create it 499 | # using a (long) sequence length `max_position_embeddings`. The actual 500 | # sequence length might be shorter than this, for faster training of 501 | # tasks that do not have long sequences. 502 | # 503 | # So `full_position_embeddings` is effectively an embedding table 504 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 505 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 506 | # perform a slice. 507 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 508 | [seq_length, -1]) 509 | num_dims = len(output.shape.as_list()) 510 | 511 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 512 | # we broadcast among the first dimensions, which is typically just 513 | # the batch size. 
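The position-embedding step slices the first `seq_length` rows out of a learned table sized for `max_position_embeddings`, then adds those rows to every sequence in the batch via broadcasting. The sketch below mirrors that slice-then-broadcast logic on nested lists. `add_position_embeddings` and the toy table are hypothetical illustrations, and the `ValueError` stands in for the `tf.assert_less_equal` check in the original.

```python
def add_position_embeddings(inputs, full_position_table):
    """inputs: [batch][seq_length][width]; full_position_table: [max_positions][width].

    Takes the first seq_length rows of the learned table and adds the same rows
    to every sequence in the batch (the broadcast over the batch dimension).
    """
    seq_length = len(inputs[0])
    if seq_length > len(full_position_table):
        # modeling.py guards this case with tf.assert_less_equal.
        raise ValueError("seq_length exceeds max_position_embeddings")
    # Like tf.slice(table, [0, 0], [seq_length, -1]).
    position_embeddings = full_position_table[:seq_length]
    return [
        [[x + p for x, p in zip(token, pos)] for token, pos in zip(sequence, position_embeddings)]
        for sequence in inputs
    ]


table = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]  # max_position_embeddings=4, width=2
batch = [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]  # batch_size=2, seq_length=2
print(add_position_embeddings(batch, table))
```

Only rows 0 and 1 of the table are used here because `seq_length` is 2; the unused rows are still trained whenever longer sequences appear.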
514 | position_broadcast_shape = [] 515 | for _ in range(num_dims - 2): 516 | position_broadcast_shape.append(1) 517 | position_broadcast_shape.extend([seq_length, width]) 518 | position_embeddings = tf.reshape(position_embeddings, 519 | position_broadcast_shape) 520 | output += position_embeddings 521 | 522 | output = layer_norm_and_dropout(output, dropout_prob) 523 | return output 524 | 525 | 526 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 527 | """Create 3D attention mask from a 2D tensor mask. 528 | 529 | Args: 530 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 531 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 532 | 533 | Returns: 534 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 535 | """ 536 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 537 | batch_size = from_shape[0] 538 | from_seq_length = from_shape[1] 539 | 540 | to_shape = get_shape_list(to_mask, expected_rank=2) 541 | to_seq_length = to_shape[1] 542 | 543 | to_mask = tf.cast( 544 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 545 | 546 | # We don't assume that `from_tensor` is a mask (although it could be). We 547 | # don't actually care if we attend *from* padding tokens (only *to* padding 548 | # tokens), so we create a tensor of all ones. 549 | # 550 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 551 | broadcast_ones = tf.ones( 552 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 553 | 554 | # Here we broadcast along two dimensions to create the mask.
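`create_attention_mask_from_input_mask()` turns a 2D padding mask `[B, T]` into a 3D mask `[B, F, T]` by multiplying `ones([B, F, 1])` with the mask reshaped to `[B, 1, T]`. The effect is that every "from" position receives a copy of the same "to" row. The sketch below spells that broadcast out with nested lists. `create_attention_mask_sketch` is a hypothetical name; the example mask matches the `input_mask` used in the `BertModel` docstring.

```python
def create_attention_mask_sketch(from_seq_length, to_mask):
    """to_mask: [batch_size][to_seq_length] of 0/1 -> [batch_size][from_seq_length][to_seq_length].

    Equivalent to ones([B, F, 1]) * reshape(to_mask, [B, 1, T]): every "from"
    position gets a copy of the same row of the "to" mask.
    """
    mask = []
    for row in to_mask:
        to_row = [float(v) for v in row]                             # [T], cast to float
        mask.append([list(to_row) for _ in range(from_seq_length)])  # repeat F times
    return mask


input_mask = [[1, 1, 1], [1, 1, 0]]  # second sequence has one padding token
m = create_attention_mask_sketch(3, input_mask)
print(m[1])  # [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
```

Padding positions are masked only as attention *targets* (columns); padded "from" rows are left all ones, matching the comment that attending *from* padding tokens does not matter.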
555 | mask = broadcast_ones * to_mask 556 | 557 | return mask 558 | 559 | 560 | def attention_layer(from_tensor, 561 | to_tensor, 562 | attention_mask=None, 563 | num_attention_heads=1, 564 | size_per_head=512, 565 | query_act=None, 566 | key_act=None, 567 | value_act=None, 568 | attention_probs_dropout_prob=0.0, 569 | initializer_range=0.02, 570 | do_return_2d_tensor=False, 571 | batch_size=None, 572 | from_seq_length=None, 573 | to_seq_length=None): 574 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 575 | 576 | This is an implementation of multi-headed attention based on "Attention 577 | Is All You Need". If `from_tensor` and `to_tensor` are the same, then 578 | this is self-attention. Each timestep in `from_tensor` attends to the 579 | corresponding sequence in `to_tensor`, and returns a fixed-width vector. 580 | 581 | This function first projects `from_tensor` into a "query" tensor and 582 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 583 | of tensors of length `num_attention_heads`, where each tensor is of shape 584 | [batch_size, seq_length, size_per_head]. 585 | 586 | Then, dot products between the query and key tensors are computed and scaled. These are 587 | softmaxed to obtain attention probabilities. The value tensors are then 588 | interpolated by these probabilities, then concatenated back to a single 589 | tensor and returned. 590 | 591 | In practice, the multi-headed attention is done with transposes and 592 | reshapes rather than actual separate tensors. 593 | 594 | Args: 595 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 596 | from_width]. 597 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 598 | attention_mask: (optional) int32 Tensor of shape [batch_size, 599 | from_seq_length, to_seq_length]. The values should be 1 or 0. The
The 600 | attention scores will effectively be set to -infinity for any positions in 601 | the mask that are 0, and will be unchanged for positions that are 1. 602 | num_attention_heads: int. Number of attention heads. 603 | size_per_head: int. Size of each attention head. 604 | query_act: (optional) Activation function for the query transform. 605 | key_act: (optional) Activation function for the key transform. 606 | value_act: (optional) Activation function for the value transform. 607 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 608 | attention probabilities. 609 | initializer_range: float. Range of the weight initializer. 610 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 611 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 612 | output will be of shape [batch_size, from_seq_length, num_attention_heads 613 | * size_per_head]. 614 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 615 | of the 3D version of the `from_tensor` and `to_tensor`. 616 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 617 | of the 3D version of the `from_tensor`. 618 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 619 | of the 3D version of the `to_tensor`. 620 | 621 | Returns: 622 | float Tensor of shape [batch_size, from_seq_length, 623 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 624 | true, this will be of shape [batch_size * from_seq_length, 625 | num_attention_heads * size_per_head]). 626 | 627 | Raises: 628 | ValueError: Any of the arguments or tensor shapes are invalid. 
629 | """ 630 | 631 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 632 | seq_length, width): 633 | output_tensor = tf.reshape( 634 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 635 | 636 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 637 | return output_tensor 638 | 639 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 640 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 641 | 642 | if len(from_shape) != len(to_shape): 643 | raise ValueError( 644 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 645 | 646 | if len(from_shape) == 3: 647 | batch_size = from_shape[0] 648 | from_seq_length = from_shape[1] 649 | to_seq_length = to_shape[1] 650 | elif len(from_shape) == 2: 651 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 652 | raise ValueError( 653 | "When passing in rank 2 tensors to attention_layer, the values " 654 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 655 | "must all be specified.") 656 | 657 | # Scalar dimensions referenced here: 658 | # B = batch size (number of sequences) 659 | # F = `from_tensor` sequence length 660 | # T = `to_tensor` sequence length 661 | # N = `num_attention_heads` 662 | # H = `size_per_head` 663 | 664 | from_tensor_2d = reshape_to_matrix(from_tensor) 665 | to_tensor_2d = reshape_to_matrix(to_tensor) 666 | 667 | # `query_layer` = [B*F, N*H] 668 | query_layer = tf.layers.dense( 669 | from_tensor_2d, 670 | num_attention_heads * size_per_head, 671 | activation=query_act, 672 | name="query", 673 | kernel_initializer=create_initializer(initializer_range)) 674 | 675 | # `key_layer` = [B*T, N*H] 676 | key_layer = tf.layers.dense( 677 | to_tensor_2d, 678 | num_attention_heads * size_per_head, 679 | activation=key_act, 680 | name="key", 681 | kernel_initializer=create_initializer(initializer_range)) 682 | 683 | # `value_layer` = [B*T, N*H] 684 | value_layer = tf.layers.dense( 685 | 
to_tensor_2d, 686 | num_attention_heads * size_per_head, 687 | activation=value_act, 688 | name="value", 689 | kernel_initializer=create_initializer(initializer_range)) 690 | 691 | # `query_layer` = [B, N, F, H] 692 | query_layer = transpose_for_scores(query_layer, batch_size, 693 | num_attention_heads, from_seq_length, 694 | size_per_head) 695 | 696 | # `key_layer` = [B, N, T, H] 697 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 698 | to_seq_length, size_per_head) 699 | 700 | # Take the dot product between "query" and "key" to get the raw 701 | # attention scores. 702 | # `attention_scores` = [B, N, F, T] 703 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 704 | attention_scores = tf.multiply(attention_scores, 705 | 1.0 / math.sqrt(float(size_per_head))) 706 | 707 | if attention_mask is not None: 708 | # `attention_mask` = [B, 1, F, T] 709 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 710 | 711 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 712 | # masked positions, this operation will create a tensor which is 0.0 for 713 | # positions we want to attend and -10000.0 for masked positions. 714 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 715 | 716 | # Since we are adding it to the raw scores before the softmax, this is 717 | # effectively the same as removing these entirely. 718 | attention_scores += adder 719 | 720 | # Normalize the attention scores to probabilities. 721 | # `attention_probs` = [B, N, F, T] 722 | attention_probs = tf.nn.softmax(attention_scores) 723 | 724 | # This is actually dropping out entire tokens to attend to, which might 725 | # seem a bit unusual, but is taken from the original Transformer paper. 
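The core of `attention_layer()` for one head and one query position is: scores are q·k scaled by 1/sqrt(size_per_head), `(1 - mask) * -10000.0` is added, a softmax turns scores into probabilities, and the values are averaged with those weights. The plain-Python sketch below illustrates exactly that sequence. It omits the query/key/value projections, dropout, and the multi-head reshapes, and `masked_attention` is a hypothetical name. Subtracting the maximum score before `exp` is a standard numerical-stability step added here.

```python
import math


def masked_attention(query, keys, values, mask):
    """Single-head, single-query scaled dot-product attention with an additive mask.

    query: [H]; keys, values: [T][H]; mask: [T] with 1 = attend, 0 = masked.
    """
    size_per_head = len(query)
    scores = []
    for k, m in zip(keys, mask):
        s = sum(qi * ki for qi, ki in zip(query, k)) / math.sqrt(size_per_head)
        s += (1.0 - m) * -10000.0  # same additive mask as modeling.py
        scores.append(s)
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # shift by max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    context = [sum(p * v[d] for p, v in zip(probs, values)) for d in range(len(values[0]))]
    return probs, context


query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
values = [[1.0, 0.0], [0.0, 1.0], [100.0, 100.0]]
mask = [1, 1, 0]  # third position is padding
probs, context = masked_attention(query, keys, values, mask)
print(probs, context)
```

Even though the padded key has the highest raw score, the -10000 offset drives its probability to zero, so its large value vector does not leak into the context.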
726 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 727 | 728 | # `value_layer` = [B, T, N, H] 729 | value_layer = tf.reshape( 730 | value_layer, 731 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 732 | 733 | # `value_layer` = [B, N, T, H] 734 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 735 | 736 | # `context_layer` = [B, N, F, H] 737 | context_layer = tf.matmul(attention_probs, value_layer) 738 | 739 | # `context_layer` = [B, F, N, H] 740 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 741 | 742 | if do_return_2d_tensor: 743 | # `context_layer` = [B*F, N*H] 744 | context_layer = tf.reshape( 745 | context_layer, 746 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 747 | else: 748 | # `context_layer` = [B, F, N*H] 749 | context_layer = tf.reshape( 750 | context_layer, 751 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 752 | 753 | return context_layer 754 | 755 | 756 | def transformer_model(input_tensor, 757 | attention_mask=None, 758 | hidden_size=768, 759 | num_hidden_layers=12, 760 | num_attention_heads=12, 761 | intermediate_size=3072, 762 | intermediate_act_fn=gelu, 763 | hidden_dropout_prob=0.1, 764 | attention_probs_dropout_prob=0.1, 765 | initializer_range=0.02, 766 | do_return_all_layers=False): 767 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 768 | 769 | This is almost an exact implementation of the original Transformer encoder. 770 | 771 | See the original paper: 772 | https://arxiv.org/abs/1706.03762 773 | 774 | Also see: 775 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 776 | 777 | Args: 778 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 779 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 780 | seq_length], with 1 for positions that can be attended to and 0 in 781 | positions that should not be. 
782 | hidden_size: int. Hidden size of the Transformer. 783 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 784 | num_attention_heads: int. Number of attention heads in the Transformer. 785 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 786 | forward) layer. 787 | intermediate_act_fn: function. The non-linear activation function to apply 788 | to the output of the intermediate/feed-forward layer. 789 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 790 | attention_probs_dropout_prob: float. Dropout probability of the attention 791 | probabilities. 792 | initializer_range: float. Range of the initializer (stddev of truncated 793 | normal). 794 | do_return_all_layers: Whether to also return all layers or just the final 795 | layer. 796 | 797 | Returns: 798 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 799 | hidden layer of the Transformer. 800 | 801 | Raises: 802 | ValueError: A Tensor shape or parameter is invalid. 803 | """ 804 | if hidden_size % num_attention_heads != 0: 805 | raise ValueError( 806 | "The hidden size (%d) is not a multiple of the number of attention " 807 | "heads (%d)" % (hidden_size, num_attention_heads)) 808 | 809 | attention_head_size = int(hidden_size / num_attention_heads) 810 | input_shape = get_shape_list(input_tensor, expected_rank=3) 811 | batch_size = input_shape[0] 812 | seq_length = input_shape[1] 813 | input_width = input_shape[2] 814 | 815 | # The Transformer performs sum residuals on all layers so the input needs 816 | # to be the same as the hidden size. 817 | if input_width != hidden_size: 818 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 819 | (input_width, hidden_size)) 820 | 821 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 822 | # forth from a 3D tensor to a 2D tensor. 
Re-shapes are normally free on 823 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 824 | # help the optimizer. 825 | prev_output = reshape_to_matrix(input_tensor) 826 | 827 | all_layer_outputs = [] 828 | for layer_idx in range(num_hidden_layers): 829 | with tf.variable_scope("layer_%d" % layer_idx): 830 | layer_input = prev_output 831 | 832 | with tf.variable_scope("attention"): 833 | attention_heads = [] 834 | with tf.variable_scope("self"): 835 | attention_head = attention_layer( 836 | from_tensor=layer_input, 837 | to_tensor=layer_input, 838 | attention_mask=attention_mask, 839 | num_attention_heads=num_attention_heads, 840 | size_per_head=attention_head_size, 841 | attention_probs_dropout_prob=attention_probs_dropout_prob, 842 | initializer_range=initializer_range, 843 | do_return_2d_tensor=True, 844 | batch_size=batch_size, 845 | from_seq_length=seq_length, 846 | to_seq_length=seq_length) 847 | attention_heads.append(attention_head) 848 | 849 | attention_output = None 850 | if len(attention_heads) == 1: 851 | attention_output = attention_heads[0] 852 | else: 853 | # In the case where we have other sequences, we just concatenate 854 | # them to the self-attention head before the projection. 855 | attention_output = tf.concat(attention_heads, axis=-1) 856 | 857 | # Run a linear projection of `hidden_size` then add a residual 858 | # with `layer_input`. 859 | with tf.variable_scope("output"): 860 | attention_output = tf.layers.dense( 861 | attention_output, 862 | hidden_size, 863 | kernel_initializer=create_initializer(initializer_range)) 864 | attention_output = dropout(attention_output, hidden_dropout_prob) 865 | attention_output = layer_norm(attention_output + layer_input) 866 | 867 | # The activation is only applied to the "intermediate" hidden layer. 
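The intermediate and output sub-layers that follow apply a dense layer with a non-linearity, project back to `hidden_size`, and then add the residual before layer normalization. This is a post-LN layout. Below is a minimal NumPy sketch of that data flow. It assumes a tanh approximation of GELU and a layer norm without learned scale or offset. `feed_forward_block` and its weight arguments are hypothetical, and dropout is omitted.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (assumed; may differ slightly from modeling.gelu)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def layer_norm(x, eps=1e-12):
    # Normalize over the last axis; learned gamma/beta are omitted in this sketch.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def feed_forward_block(attention_output, w1, b1, w2, b2):
    # attention_output: [B*S, hidden_size], kept 2D like `prev_output` above.
    intermediate = gelu(attention_output @ w1 + b1)      # [B*S, intermediate_size]
    layer_output = intermediate @ w2 + b2                 # [B*S, hidden_size]
    return layer_norm(layer_output + attention_output)    # residual, then LayerNorm

rng = np.random.default_rng(0)
hidden_size, intermediate_size = 8, 32
x = rng.normal(size=(6, hidden_size))
w1 = rng.normal(scale=0.02, size=(hidden_size, intermediate_size))
b1 = np.zeros(intermediate_size)
w2 = rng.normal(scale=0.02, size=(intermediate_size, hidden_size))
b2 = np.zeros(hidden_size)
out = feed_forward_block(x, w1, b1, w2, b2)
```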
868 | with tf.variable_scope("intermediate"): 869 | intermediate_output = tf.layers.dense( 870 | attention_output, 871 | intermediate_size, 872 | activation=intermediate_act_fn, 873 | kernel_initializer=create_initializer(initializer_range)) 874 | 875 | # Down-project back to `hidden_size` then add the residual. 876 | with tf.variable_scope("output"): 877 | layer_output = tf.layers.dense( 878 | intermediate_output, 879 | hidden_size, 880 | kernel_initializer=create_initializer(initializer_range)) 881 | layer_output = dropout(layer_output, hidden_dropout_prob) 882 | layer_output = layer_norm(layer_output + attention_output) 883 | prev_output = layer_output 884 | all_layer_outputs.append(layer_output) 885 | 886 | if do_return_all_layers: 887 | final_outputs = [] 888 | for layer_output in all_layer_outputs: 889 | final_output = reshape_from_matrix(layer_output, input_shape) 890 | final_outputs.append(final_output) 891 | return final_outputs 892 | else: 893 | final_output = reshape_from_matrix(prev_output, input_shape) 894 | return final_output 895 | 896 | 897 | def get_shape_list(tensor, expected_rank=None, name=None): 898 | """Returns a list of the shape of tensor, preferring static dimensions. 899 | 900 | Args: 901 | tensor: A tf.Tensor object to find the shape of. 902 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 903 | specified and the `tensor` has a different rank, an exception will be 904 | thrown. 905 | name: Optional name of the tensor for the error message. 906 | 907 | Returns: 908 | A list of dimensions of the shape of tensor. All static dimensions will 909 | be returned as python integers, and dynamic dimensions will be returned 910 | as tf.Tensor scalars.
911 | """ 912 | if name is None: 913 | name = tensor.name 914 | 915 | if expected_rank is not None: 916 | assert_rank(tensor, expected_rank, name) 917 | 918 | shape = tensor.shape.as_list() 919 | 920 | non_static_indexes = [] 921 | for (index, dim) in enumerate(shape): 922 | if dim is None: 923 | non_static_indexes.append(index) 924 | 925 | if not non_static_indexes: 926 | return shape 927 | 928 | dyn_shape = tf.shape(tensor) 929 | for index in non_static_indexes: 930 | shape[index] = dyn_shape[index] 931 | return shape 932 | 933 | 934 | def reshape_to_matrix(input_tensor): 935 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 936 | ndims = input_tensor.shape.ndims 937 | if ndims < 2: 938 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 939 | (input_tensor.shape)) 940 | if ndims == 2: 941 | return input_tensor 942 | 943 | width = input_tensor.shape[-1] 944 | output_tensor = tf.reshape(input_tensor, [-1, width]) 945 | return output_tensor 946 | 947 | 948 | def reshape_from_matrix(output_tensor, orig_shape_list): 949 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 950 | if len(orig_shape_list) == 2: 951 | return output_tensor 952 | 953 | output_shape = get_shape_list(output_tensor) 954 | 955 | orig_dims = orig_shape_list[0:-1] 956 | width = output_shape[-1] 957 | 958 | return tf.reshape(output_tensor, orig_dims + [width]) 959 | 960 | 961 | def assert_rank(tensor, expected_rank, name=None): 962 | """Raises an exception if the tensor rank is not of the expected rank. 963 | 964 | Args: 965 | tensor: A tf.Tensor to check the rank of. 966 | expected_rank: Python integer or list of integers, expected rank. 967 | name: Optional name of the tensor for the error message. 968 | 969 | Raises: 970 | ValueError: If the expected shape doesn't match the actual shape. 
971 | """ 972 | if name is None: 973 | name = tensor.name 974 | 975 | expected_rank_dict = {} 976 | if isinstance(expected_rank, six.integer_types): 977 | expected_rank_dict[expected_rank] = True 978 | else: 979 | for x in expected_rank: 980 | expected_rank_dict[x] = True 981 | 982 | actual_rank = tensor.shape.ndims 983 | if actual_rank not in expected_rank_dict: 984 | scope_name = tf.get_variable_scope().name 985 | raise ValueError( 986 | "For the tensor `%s` in scope `%s`, the actual rank " 987 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 988 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 989 | -------------------------------------------------------------------------------- /baseline/run_cmrc2018_drcd_baseline.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Run BERT on SQuAD 1.1 and SQuAD 2.0.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import json 23 | import math 24 | import os 25 | import random 26 | import modeling 27 | import optimization 28 | import tokenization 29 | import six 30 | import tensorflow as tf 31 | import numpy 32 | import pdb 33 | 34 | flags = tf.flags 35 | FLAGS = flags.FLAGS 36 | 37 | ## Required parameters 38 | flags.DEFINE_string( 39 | "bert_config_file", None, 40 | "The config json file corresponding to the pre-trained BERT model. " 41 | "This specifies the model architecture.") 42 | 43 | flags.DEFINE_string("vocab_file", None, 44 | "The vocabulary file that the BERT model was trained on.") 45 | 46 | flags.DEFINE_string( 47 | "output_dir", None, 48 | "The output directory where the model checkpoints will be written.") 49 | 50 | ## Other parameters 51 | flags.DEFINE_string("train_file", None, 52 | "SQuAD json for training. E.g., train-v1.1.json") 53 | 54 | flags.DEFINE_string( 55 | "predict_file", None, 56 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 57 | 58 | flags.DEFINE_string( 59 | "eval_file", None, 60 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 61 | 62 | flags.DEFINE_string( 63 | "init_checkpoint", None, 64 | "Initial checkpoint (usually from a pre-trained BERT model).") 65 | 66 | flags.DEFINE_bool( 67 | "do_lower_case", False, 68 | "Whether to lower case the input text. Should be True for uncased " 69 | "models and False for cased models.") 70 | 71 | flags.DEFINE_integer( 72 | "max_seq_length", 384, 73 | "The maximum total input sequence length after WordPiece tokenization. 
" 74 | "Sequences longer than this will be truncated, and sequences shorter " 75 | "than this will be padded.") 76 | 77 | flags.DEFINE_integer( 78 | "doc_stride", 128, 79 | "When splitting up a long document into chunks, how much stride to " 80 | "take between chunks.") 81 | 82 | flags.DEFINE_integer( 83 | "max_query_length", 64, 84 | "The maximum number of tokens for the question. Questions longer than " 85 | "this will be truncated to this length.") 86 | 87 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 88 | 89 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.") 90 | 91 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 92 | 93 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 94 | 95 | flags.DEFINE_integer("predict_batch_size", 8, 96 | "Total batch size for predictions.") 97 | 98 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 99 | 100 | flags.DEFINE_float("num_train_epochs", 3.0, 101 | "Total number of training epochs to perform.") 102 | 103 | flags.DEFINE_float( 104 | "warmup_proportion", 0.1, 105 | "Proportion of training to perform linear learning rate warmup for. " 106 | "E.g., 0.1 = 10% of training.") 107 | 108 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 109 | "How often to save the model checkpoint.") 110 | 111 | flags.DEFINE_integer("iterations_per_loop", 1000, 112 | "How many steps to make in each estimator call.") 113 | 114 | flags.DEFINE_integer( 115 | "n_best_size", 20, 116 | "The total number of n-best predictions to generate in the " 117 | "nbest_predictions.json output file.") 118 | 119 | flags.DEFINE_integer( 120 | "max_answer_length", 30, 121 | "The maximum length of an answer that can be generated. 
This is needed " 122 | "because the start and end predictions are not conditioned on one another.") 123 | 124 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 125 | 126 | tf.flags.DEFINE_string( 127 | "tpu_name", None, 128 | "The Cloud TPU to use for training. This should be either the name " 129 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 130 | "url.") 131 | 132 | tf.flags.DEFINE_string( 133 | "tpu_zone", None, 134 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 135 | "specified, we will attempt to automatically detect the GCE project from " 136 | "metadata.") 137 | 138 | tf.flags.DEFINE_string( 139 | "gcp_project", None, 140 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 141 | "specified, we will attempt to automatically detect the GCE project from " 142 | "metadata.") 143 | 144 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 145 | 146 | flags.DEFINE_integer( 147 | "num_tpu_cores", 8, 148 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 149 | 150 | flags.DEFINE_bool( 151 | "verbose_logging", False, 152 | "If true, all of the warnings related to data processing will be printed. " 153 | "A number of warnings are expected for a normal SQuAD evaluation.") 154 | 155 | flags.DEFINE_integer("rand_seed", 12345, "Random seed.") 156 | 157 | # Set random seeds. Note: tf.set_random_seed only seeds the current default graph; an Estimator builds its own graph, so this may not make training fully reproducible. 158 | numpy.random.seed(int(FLAGS.rand_seed)) 159 | tf.set_random_seed(int(FLAGS.rand_seed)) 160 | 161 | # 162 | class SquadExample(object): 163 | """A single training/test example for span-extraction reading comprehension. 164 | 165 | For examples without an answer, the start and end position are -1.
166 | """ 167 | 168 | def __init__(self, 169 | qas_id, 170 | question_text, 171 | doc_tokens, 172 | orig_answer_text=None, 173 | start_position=None, 174 | end_position=None): 175 | self.qas_id = qas_id 176 | self.question_text = question_text 177 | self.doc_tokens = doc_tokens 178 | self.orig_answer_text = orig_answer_text 179 | self.start_position = start_position 180 | self.end_position = end_position 181 | 182 | def __str__(self): 183 | return self.__repr__() 184 | 185 | def __repr__(self): 186 | s = "" 187 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) 188 | s += ", question_text: %s" % ( 189 | tokenization.printable_text(self.question_text)) 190 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 191 | if self.start_position is not None: 192 | s += ", start_position: %d" % (self.start_position) 193 | if self.end_position is not None: 194 | s += ", end_position: %d" % (self.end_position) 195 | return s 196 | 197 | 198 | class InputFeatures(object): 199 | """A single set of features of data.""" 200 | 201 | def __init__(self, 202 | unique_id, 203 | example_index, 204 | doc_span_index, 205 | tokens, 206 | token_to_orig_map, 207 | token_is_max_context, 208 | input_ids, 209 | input_mask, 210 | segment_ids, 211 | input_span_mask, 212 | start_position=None, 213 | end_position=None): 214 | self.unique_id = unique_id 215 | self.example_index = example_index 216 | self.doc_span_index = doc_span_index 217 | self.tokens = tokens 218 | self.token_to_orig_map = token_to_orig_map 219 | self.token_is_max_context = token_is_max_context 220 | self.input_ids = input_ids 221 | self.input_mask = input_mask 222 | self.segment_ids = segment_ids 223 | self.input_span_mask = input_span_mask 224 | self.start_position = start_position 225 | self.end_position = end_position 226 | 227 | # 228 | def customize_tokenizer(text, do_lower_case=False): 229 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 230 | temp_x = "" 231 | text = tokenization.convert_to_unicode(text)
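`customize_tokenizer` surrounds every CJK character, punctuation mark, whitespace character, and control character with spaces, then splits on whitespace. The result is one token per Chinese character, while runs of Latin letters or digits stay whole. The standalone sketch below reproduces that behavior under two assumptions: it checks a reduced set of CJK Unicode blocks, and it uses a simplified punctuation test. The repository's `tokenization` helpers cover more cases, including control characters. `char_split` is a hypothetical name.

```python
import unicodedata

def is_cjk_char(cp):
    # Assumed subset of CJK blocks (Unified Ideographs, Extension A, Compatibility).
    return (0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or (0xF900 <= cp <= 0xFAFF)

def is_punctuation(ch):
    cp = ord(ch)
    # Treat all non-alphanumeric ASCII symbols as punctuation, plus Unicode "P*" categories.
    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
        return True
    return unicodedata.category(ch).startswith("P")

def char_split(text, do_lower_case=False):
    pieces = []
    for ch in text:
        if is_cjk_char(ord(ch)) or is_punctuation(ch) or ch.isspace():
            pieces.append(" " + ch + " ")   # isolate the character
        else:
            pieces.append(ch)                # keep Latin/digit runs together
    joined = "".join(pieces)
    if do_lower_case:
        joined = joined.lower()
    return joined.split()

print(char_split("CMRC 2018，很难。"))
```

Splitting Chinese at the character level keeps the downstream WordPiece step simple: each CJK character is typically its own vocabulary entry in Chinese BERT vocabularies.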
232 | for c in text: 233 | if tokenizer._is_chinese_char(ord(c)) or tokenization._is_punctuation(c) or tokenization._is_whitespace(c) or tokenization._is_control(c): 234 | temp_x += " " + c + " " 235 | else: 236 | temp_x += c 237 | if do_lower_case: 238 | temp_x = temp_x.lower() 239 | return temp_x.split() 240 | 241 | # 242 | class ChineseFullTokenizer(object): 243 | """Runs end-to-end tokenization.""" 244 | 245 | def __init__(self, vocab_file, do_lower_case=False): 246 | self.vocab = tokenization.load_vocab(vocab_file) 247 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 248 | self.wordpiece_tokenizer = tokenization.WordpieceTokenizer(vocab=self.vocab) 249 | self.do_lower_case = do_lower_case 250 | def tokenize(self, text): 251 | split_tokens = [] 252 | for token in customize_tokenizer(text, do_lower_case=self.do_lower_case): 253 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 254 | split_tokens.append(sub_token) 255 | 256 | return split_tokens 257 | 258 | def convert_tokens_to_ids(self, tokens): 259 | return tokenization.convert_by_vocab(self.vocab, tokens) 260 | 261 | def convert_ids_to_tokens(self, ids): 262 | return tokenization.convert_by_vocab(self.inv_vocab, ids) 263 | 264 | # 265 | def read_squad_examples(input_file, is_training): 266 | """Read a SQuAD json file into a list of SquadExample.""" 267 | with tf.gfile.Open(input_file, "r") as reader: 268 | input_data = json.load(reader)["data"] 269 | 270 | # 271 | examples = [] 272 | for entry in input_data: 273 | for paragraph in entry["paragraphs"]: 274 | paragraph_text = paragraph["context"] 275 | raw_doc_tokens = customize_tokenizer(paragraph_text, do_lower_case=FLAGS.do_lower_case) 276 | doc_tokens = [] 277 | char_to_word_offset = [] 278 | prev_is_whitespace = True 279 | 280 | k = 0 281 | temp_word = "" 282 | for c in paragraph_text: 283 | if tokenization._is_whitespace(c): 284 | char_to_word_offset.append(k-1) 285 | continue 286 | else: 287 | temp_word += c 288 |
char_to_word_offset.append(k) 289 | if FLAGS.do_lower_case: 290 | temp_word = temp_word.lower() 291 | if temp_word == raw_doc_tokens[k]: 292 | doc_tokens.append(temp_word) 293 | temp_word = "" 294 | k += 1 295 | 296 | assert k==len(raw_doc_tokens) 297 | 298 | for qa in paragraph["qas"]: 299 | qas_id = qa["id"] 300 | question_text = qa["question"] 301 | start_position = None 302 | end_position = None 303 | orig_answer_text = None 304 | 305 | if is_training: 306 | answer = qa["answers"][0] 307 | orig_answer_text = answer["text"] 308 | 309 | if orig_answer_text not in paragraph_text: 310 | tf.logging.warning("Could not find answer") 311 | else: 312 | answer_offset = paragraph_text.index(orig_answer_text) 313 | answer_length = len(orig_answer_text) 314 | start_position = char_to_word_offset[answer_offset] 315 | end_position = char_to_word_offset[answer_offset + answer_length - 1] 316 | 317 | # Only add answers where the text can be exactly recovered from the 318 | # document. If this CAN'T happen it's likely due to weird Unicode 319 | # stuff so we will just skip the example. 320 | # 321 | # Note that this means for training mode, every example is NOT 322 | # guaranteed to be preserved. 323 | actual_text = "".join( 324 | doc_tokens[start_position:(end_position + 1)]) 325 | cleaned_answer_text = "".join( 326 | tokenization.whitespace_tokenize(orig_answer_text)) 327 | if FLAGS.do_lower_case: 328 | cleaned_answer_text = cleaned_answer_text.lower() 329 | if actual_text.find(cleaned_answer_text) == -1: 330 | 331 | tf.logging.warning("Could not find answer: '%s' vs.
'%s'", actual_text, cleaned_answer_text) 332 | continue 333 | 334 | example = SquadExample( 335 | qas_id=qas_id, 336 | question_text=question_text, 337 | doc_tokens=doc_tokens, 338 | orig_answer_text=orig_answer_text, 339 | start_position=start_position, 340 | end_position=end_position) 341 | examples.append(example) 342 | tf.logging.info("**********read_squad_examples complete!**********") 343 | 344 | return examples 345 | 346 | 347 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 348 | doc_stride, max_query_length, is_training, 349 | output_fn): 350 | """Loads a data file into a list of `InputBatch`s.""" 351 | 352 | unique_id = 1000000000 353 | tokenizer = ChineseFullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 354 | 355 | for (example_index, example) in enumerate(examples): 356 | query_tokens = tokenizer.tokenize(example.question_text) 357 | 358 | if len(query_tokens) > max_query_length: 359 | query_tokens = query_tokens[0:max_query_length] 360 | 361 | tok_to_orig_index = [] 362 | orig_to_tok_index = [] 363 | all_doc_tokens = [] 364 | for (i, token) in enumerate(example.doc_tokens): 365 | orig_to_tok_index.append(len(all_doc_tokens)) 366 | sub_tokens = tokenizer.tokenize(token) 367 | for sub_token in sub_tokens: 368 | tok_to_orig_index.append(i) 369 | all_doc_tokens.append(sub_token) 370 | 371 | tok_start_position = None 372 | tok_end_position = None 373 | if is_training: 374 | tok_start_position = orig_to_tok_index[example.start_position] 375 | if example.end_position < len(example.doc_tokens) - 1: 376 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 377 | else: 378 | tok_end_position = len(all_doc_tokens) - 1 379 | (tok_start_position, tok_end_position) = _improve_answer_span( 380 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, 381 | example.orig_answer_text) 382 | 383 | # The -3 accounts for [CLS], [SEP] and [SEP] 384 | max_tokens_for_doc = max_seq_length - 
len(query_tokens) - 3 385 | 386 | # We can have documents that are longer than the maximum sequence length. 387 | # To deal with this we do a sliding window approach, where we take chunks 388 | # of up to our max length with a stride of `doc_stride`. 389 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 390 | "DocSpan", ["start", "length"]) 391 | doc_spans = [] 392 | start_offset = 0 393 | while start_offset < len(all_doc_tokens): 394 | length = len(all_doc_tokens) - start_offset 395 | if length > max_tokens_for_doc: 396 | length = max_tokens_for_doc 397 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 398 | if start_offset + length == len(all_doc_tokens): 399 | break 400 | start_offset += min(length, doc_stride) 401 | 402 | for (doc_span_index, doc_span) in enumerate(doc_spans): 403 | tokens = [] 404 | token_to_orig_map = {} 405 | token_is_max_context = {} 406 | segment_ids = [] 407 | input_span_mask = [] 408 | tokens.append("[CLS]") 409 | segment_ids.append(0) 410 | input_span_mask.append(1) 411 | for token in query_tokens: 412 | tokens.append(token) 413 | segment_ids.append(0) 414 | input_span_mask.append(0) 415 | tokens.append("[SEP]") 416 | segment_ids.append(0) 417 | input_span_mask.append(0) 418 | 419 | for i in range(doc_span.length): 420 | split_token_index = doc_span.start + i 421 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 422 | 423 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 424 | split_token_index) 425 | token_is_max_context[len(tokens)] = is_max_context 426 | tokens.append(all_doc_tokens[split_token_index]) 427 | segment_ids.append(1) 428 | input_span_mask.append(1) 429 | tokens.append("[SEP]") 430 | segment_ids.append(1) 431 | input_span_mask.append(0) 432 | 433 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 434 | 435 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 436 | # tokens are attended to.
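You can reproduce the sliding-window loop above outside TensorFlow to see which chunks a long document produces. The following is a minimal sketch that mirrors that loop. `make_doc_spans` is a hypothetical helper. The values 384 and 128 are the `max_seq_length` and `doc_stride` flag defaults defined earlier in this script, while the 20-token question and 600-token document are made-up sizes.

```python
import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    """Split [0, num_doc_tokens) into overlapping windows, as in the loop above."""
    doc_spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break  # the last window reaches the end of the document
        start_offset += min(length, doc_stride)
    return doc_spans

# max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 ([CLS], [SEP], [SEP])
max_tokens_for_doc = 384 - 20 - 3   # 361 for a 20-token question
spans = make_doc_spans(600, max_tokens_for_doc, doc_stride=128)
```

Consecutive windows overlap by `max_tokens_for_doc - doc_stride` tokens, so most tokens appear in more than one span. `_check_is_max_context` later resolves this overlap.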
437 | input_mask = [1] * len(input_ids) 438 | 439 | # Zero-pad up to the sequence length. 440 | while len(input_ids) < max_seq_length: 441 | input_ids.append(0) 442 | input_mask.append(0) 443 | segment_ids.append(0) 444 | input_span_mask.append(0) 445 | 446 | assert len(input_ids) == max_seq_length 447 | assert len(input_mask) == max_seq_length 448 | assert len(segment_ids) == max_seq_length 449 | assert len(input_span_mask) == max_seq_length 450 | 451 | start_position = None 452 | end_position = None 453 | if is_training: 454 | # For training, if our document chunk does not contain an annotation 455 | # we throw it out, since there is nothing to predict. 456 | doc_start = doc_span.start 457 | doc_end = doc_span.start + doc_span.length - 1 458 | out_of_span = False 459 | if not (tok_start_position >= doc_start and 460 | tok_end_position <= doc_end): 461 | out_of_span = True 462 | if out_of_span: 463 | start_position = 0 464 | end_position = 0 465 | else: 466 | doc_offset = len(query_tokens) + 2 467 | start_position = tok_start_position - doc_start + doc_offset 468 | end_position = tok_end_position - doc_start + doc_offset 469 | 470 | if example_index < 3: 471 | tf.logging.info("*** Example ***") 472 | tf.logging.info("unique_id: %s" % (unique_id)) 473 | tf.logging.info("example_index: %s" % (example_index)) 474 | tf.logging.info("doc_span_index: %s" % (doc_span_index)) 475 | tf.logging.info("tokens: %s" % " ".join( 476 | [tokenization.printable_text(x) for x in tokens])) 477 | tf.logging.info("token_to_orig_map: %s" % " ".join( 478 | ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) 479 | tf.logging.info("token_is_max_context: %s" % " ".join([ 480 | "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) 481 | ])) 482 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 483 | tf.logging.info( 484 | "input_mask: %s" % " ".join([str(x) for x in input_mask])) 485 | tf.logging.info( 486 | "segment_ids: %s" % " 
".join([str(x) for x in segment_ids])) 487 | tf.logging.info( 488 | "input_span_mask: %s" % " ".join([str(x) for x in input_span_mask])) 489 | if is_training: 490 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 491 | tf.logging.info("start_position: %d" % (start_position)) 492 | tf.logging.info("end_position: %d" % (end_position)) 493 | tf.logging.info( 494 | "answer: %s" % (tokenization.printable_text(answer_text))) 495 | 496 | 497 | feature = InputFeatures( 498 | unique_id=unique_id, 499 | example_index=example_index, 500 | doc_span_index=doc_span_index, 501 | tokens=tokens, 502 | token_to_orig_map=token_to_orig_map, 503 | token_is_max_context=token_is_max_context, 504 | input_ids=input_ids, 505 | input_mask=input_mask, 506 | segment_ids=segment_ids, 507 | input_span_mask=input_span_mask, 508 | start_position=start_position, 509 | end_position=end_position) 510 | 511 | # Run callback 512 | output_fn(feature) 513 | 514 | unique_id += 1 515 | 516 | 517 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 518 | orig_answer_text): 519 | """Returns tokenized answer spans that better match the annotated answer.""" 520 | 521 | # The SQuAD annotations are character based. We first project them to 522 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 523 | # often find a "better match". For example: 524 | # 525 | # Question: What year was John Smith born? 526 | # Context: The leader was John Smith (1895-1943). 527 | # Answer: 1895 528 | # 529 | # The original whitespace-tokenized answer will be "(1895-1943).". However 530 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 531 | # the exact answer, 1895. 532 | # 533 | # However, this is not always possible. Consider the following: 534 | # 535 | # Question: What country is the top exporter of electronics? 536 | # Context: The Japanese electronics industry is the largest in the world.
537 | # Answer: Japan 538 | # 539 | # In this case, the annotator chose "Japan" as a character sub-span of 540 | # the word "Japanese". Since our WordPiece tokenizer does not split 541 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 542 | # in SQuAD, but does happen. 543 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 544 | 545 | for new_start in range(input_start, input_end + 1): 546 | for new_end in range(input_end, new_start - 1, -1): 547 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 548 | if text_span == tok_answer_text: 549 | return (new_start, new_end) 550 | 551 | return (input_start, input_end) 552 | 553 | 554 | def _check_is_max_context(doc_spans, cur_span_index, position): 555 | """Check if this is the 'max context' doc span for the token.""" 556 | 557 | # Because of the sliding window approach taken to scoring documents, a single 558 | # token can appear in multiple documents. E.g. 559 | # Doc: the man went to the store and bought a gallon of milk 560 | # Span A: the man went to the 561 | # Span B: to the store and bought 562 | # Span C: and bought a gallon of 563 | # ... 564 | # 565 | # Now the word 'bought' will have two scores from spans B and C. We only 566 | # want to consider the score with "maximum context", which we define as 567 | # the *minimum* of its left and right context (the *sum* of left and 568 | # right context will always be the same, of course). 569 | # 570 | # In the example the maximum context for 'bought' would be span C since 571 | # it has 1 left context and 3 right context, while span B has 4 left context 572 | # and 0 right context. 
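You can check the worked example in the comment above directly. The sketch below re-implements the same scoring rule on spans A, B, and C: the minimum of the left and right context, plus a small bonus proportional to span length. `is_max_context` is a hypothetical standalone name for the logic in `_check_is_max_context`.

```python
import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def is_max_context(doc_spans, cur_span_index, position):
    best_score, best_index = None, None
    for index, span in enumerate(doc_spans):
        end = span.start + span.length - 1
        if position < span.start or position > end:
            continue  # this span does not contain the token
        left = position - span.start
        right = end - position
        # Favor balanced context; the length term breaks ties toward longer spans.
        score = min(left, right) + 0.01 * span.length
        if best_score is None or score > best_score:
            best_score, best_index = score, index
    return cur_span_index == best_index

doc = "the man went to the store and bought a gallon of milk".split()
spans = [DocSpan(0, 5), DocSpan(3, 5), DocSpan(6, 5)]   # spans A, B, C
position = doc.index("bought")
```

For "bought", span B scores min(4, 0) + 0.05 = 0.05 and span C scores min(1, 3) + 0.05 = 1.05, so only span C counts it as max-context.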
573 | best_score = None 574 | best_span_index = None 575 | for (span_index, doc_span) in enumerate(doc_spans): 576 | end = doc_span.start + doc_span.length - 1 577 | if position < doc_span.start: 578 | continue 579 | if position > end: 580 | continue 581 | num_left_context = position - doc_span.start 582 | num_right_context = end - position 583 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 584 | if best_score is None or score > best_score: 585 | best_score = score 586 | best_span_index = span_index 587 | 588 | return cur_span_index == best_span_index 589 | 590 | 591 | # 592 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, input_span_mask, 593 | use_one_hot_embeddings): 594 | """Creates a classification model.""" 595 | model = modeling.BertModel( 596 | config=bert_config, 597 | is_training=is_training, 598 | input_ids=input_ids, 599 | input_mask=input_mask, 600 | token_type_ids=segment_ids, 601 | use_one_hot_embeddings=use_one_hot_embeddings) 602 | 603 | 604 | final_hidden = model.get_sequence_output() 605 | 606 | final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3) 607 | batch_size = final_hidden_shape[0] 608 | seq_length = final_hidden_shape[1] 609 | hidden_size = final_hidden_shape[2] 610 | 611 | output_weights = tf.get_variable( 612 | "cls/squad/output_weights", [2, hidden_size], 613 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 614 | 615 | output_bias = tf.get_variable( 616 | "cls/squad/output_bias", [2], initializer=tf.zeros_initializer()) 617 | 618 | final_hidden_matrix = tf.reshape(final_hidden, 619 | [batch_size * seq_length, hidden_size]) 620 | logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True) 621 | logits = tf.nn.bias_add(logits, output_bias) 622 | 623 | logits = tf.reshape(logits, [batch_size, seq_length, 2]) 624 | logits = tf.transpose(logits, [2, 0, 1]) 625 | 626 | unstacked_logits = tf.unstack(logits, axis=0) 627 | 628 | 
(start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1]) 629 | 630 | # apply output mask 631 | adder = (1.0 - tf.cast(input_span_mask, tf.float32)) * -10000.0 632 | start_logits += adder 633 | end_logits += adder 634 | 635 | return (start_logits, end_logits) 636 | 637 | 638 | # 639 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 640 | num_train_steps, num_warmup_steps, use_tpu, 641 | use_one_hot_embeddings): 642 | """Returns `model_fn` closure for TPUEstimator.""" 643 | 644 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 645 | """The `model_fn` for TPUEstimator.""" 646 | 647 | tf.logging.info("*** Features ***") 648 | for name in sorted(features.keys()): 649 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 650 | 651 | unique_ids = features["unique_ids"] 652 | input_ids = features["input_ids"] 653 | input_mask = features["input_mask"] 654 | segment_ids = features["segment_ids"] 655 | input_span_mask = features["input_span_mask"] 656 | 657 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 658 | 659 | (start_logits, end_logits) = create_model( 660 | bert_config=bert_config, 661 | is_training=is_training, 662 | input_ids=input_ids, 663 | input_mask=input_mask, 664 | segment_ids=segment_ids, 665 | input_span_mask=input_span_mask, 666 | use_one_hot_embeddings=use_one_hot_embeddings) 667 | 668 | tvars = tf.trainable_variables() 669 | 670 | initialized_variable_names = {} 671 | scaffold_fn = None 672 | if init_checkpoint: 673 | (assignment_map, initialized_variable_names 674 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 675 | if use_tpu: 676 | 677 | def tpu_scaffold(): 678 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 679 | return tf.train.Scaffold() 680 | 681 | scaffold_fn = tpu_scaffold 682 | else: 683 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 684 | 685 | tf.logging.info("**** Trainable Variables ****") 
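Two details in `create_model` and `compute_loss` work together. First, logits at positions where `input_span_mask` is 0 (question tokens, [SEP], and padding) receive -10000.0. Second, the training loss is the mean of the start and end cross-entropies against one-hot gold positions. [CLS] keeps a mask value of 1, so it remains a valid target for out-of-span training features. The following NumPy sketch shows that combination. The helper names and inputs are hypothetical, and the TPUEstimator plumbing is left out.

```python
import numpy as np

def log_softmax(x):
    # Numerically stable log-softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def span_loss(start_logits, end_logits, input_span_mask, start_positions, end_positions):
    # Masked positions get -10000.0 added, as in create_model.
    adder = (1.0 - input_span_mask.astype(np.float32)) * -10000.0
    seq_length = start_logits.shape[-1]

    def nll(logits, positions):
        one_hot_pos = np.eye(seq_length)[positions]           # [B, S]
        log_probs = log_softmax(logits + adder)
        return -np.mean(np.sum(one_hot_pos * log_probs, axis=-1))

    return (nll(start_logits, start_positions) + nll(end_logits, end_positions)) / 2

# Batch of 2, sequence length 6; only the first 4 positions are allowed answer tokens.
mask = np.array([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]])
zeros = np.zeros((2, 6))
loss = span_loss(zeros, zeros, mask, np.array([0, 3]), np.array([2, 3]))
```

With uniform logits, the probability mass is spread evenly over the 4 unmasked positions, so the loss equals log 4. A large logit at a masked position changes nothing, because the -10000.0 offset still dominates it.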
686 | for var in tvars: 687 | init_string = "" 688 | if var.name in initialized_variable_names: 689 | init_string = ", *INIT_FROM_CKPT*" 690 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string) 691 | 692 | output_spec = None 693 | if mode == tf.estimator.ModeKeys.TRAIN: 694 | seq_length = modeling.get_shape_list(input_ids)[1] 695 | 696 | def compute_loss(logits, positions): 697 | one_hot_pos = tf.one_hot(positions, depth=seq_length, dtype=tf.float32) 698 | log_probs = tf.nn.log_softmax(logits, axis=-1) 699 | loss = -tf.reduce_mean(tf.reduce_sum(one_hot_pos * log_probs, axis=-1)) 700 | return loss 701 | 702 | start_positions = features["start_positions"] 703 | end_positions = features["end_positions"] 704 | 705 | start_loss = compute_loss(start_logits, start_positions) 706 | end_loss = compute_loss(end_logits, end_positions) 707 | total_loss = (start_loss + end_loss) / 2 708 | 709 | train_op = optimization.create_optimizer( 710 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 711 | 712 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 713 | mode=mode, 714 | loss=total_loss, 715 | train_op=train_op, 716 | scaffold_fn=scaffold_fn) 717 | elif mode == tf.estimator.ModeKeys.PREDICT: 718 | start_logits = tf.nn.log_softmax(start_logits, axis=-1) 719 | end_logits = tf.nn.log_softmax(end_logits, axis=-1) 720 | 721 | predictions = { 722 | "unique_ids": unique_ids, 723 | "start_logits": start_logits, 724 | "end_logits": end_logits, 725 | } 726 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 727 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 728 | else: 729 | raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % (mode)) 730 | 731 | return output_spec 732 | 733 | return model_fn 734 | 735 | 736 | def input_fn_builder(input_file, seq_length, is_training, drop_remainder): 737 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 738 | 739 | name_to_features = { 740 | "unique_ids":
tf.FixedLenFeature([], tf.int64), 741 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 742 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 743 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 744 | "input_span_mask": tf.FixedLenFeature([seq_length], tf.int64), 745 | } 746 | 747 | if is_training: 748 | name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64) 749 | name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64) 750 | 751 | def _decode_record(record, name_to_features): 752 | """Decodes a record to a TensorFlow example.""" 753 | example = tf.parse_single_example(record, name_to_features) 754 | 755 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 756 | # So cast all int64 to int32. 757 | for name in list(example.keys()): 758 | t = example[name] 759 | if t.dtype == tf.int64: 760 | t = tf.to_int32(t) 761 | example[name] = t 762 | 763 | return example 764 | 765 | def input_fn(params): 766 | """The actual input function.""" 767 | batch_size = params["batch_size"] 768 | 769 | # For training, we want a lot of parallel reading and shuffling. 770 | # For eval, we want no shuffling and parallel reading doesn't matter. 
771 | d = tf.data.TFRecordDataset(input_file) 772 | if is_training: 773 | d = d.repeat() 774 | d = d.shuffle(buffer_size=100) 775 | 776 | d = d.apply( 777 | tf.contrib.data.map_and_batch( 778 | lambda record: _decode_record(record, name_to_features), 779 | batch_size=batch_size, 780 | drop_remainder=drop_remainder)) 781 | 782 | return d 783 | 784 | return input_fn 785 | 786 | 787 | RawResult = collections.namedtuple("RawResult", 788 | ["unique_id", "start_logits", "end_logits"]) 789 | 790 | def write_predictions(all_examples, all_features, all_results, n_best_size, 791 | max_answer_length, do_lower_case, output_prediction_file, 792 | output_nbest_file): 793 | """Write the final predictions and the n-best lists to JSON files.""" 794 | tf.logging.info("Writing predictions to: %s" % (output_prediction_file)) 795 | tf.logging.info("Writing nbest to: %s" % (output_nbest_file)) 796 | 797 | example_index_to_features = collections.defaultdict(list) 798 | for feature in all_features: 799 | example_index_to_features[feature.example_index].append(feature) 800 | 801 | unique_id_to_result = {} 802 | for result in all_results: 803 | unique_id_to_result[result.unique_id] = result 804 | 805 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 806 | "PrelimPrediction", 807 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 808 | 809 | all_predictions = collections.OrderedDict() 810 | all_nbest_json = collections.OrderedDict() 811 | 812 | for (example_index, example) in enumerate(all_examples): 813 | features = example_index_to_features[example_index] 814 | prelim_predictions = [] 815 | 816 | for (feature_index, feature) in enumerate(features): # one example may be split into several doc-span features 817 | result = unique_id_to_result[feature.unique_id] 818 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 819 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 820 | for start_index in start_indexes: 821 | for end_index in
end_indexes: 822 | # We could hypothetically create invalid predictions, e.g., predict 823 | # that the start of the span is in the question. We throw out all 824 | # invalid predictions. 825 | if start_index >= len(feature.tokens): 826 | continue 827 | if end_index >= len(feature.tokens): 828 | continue 829 | if start_index not in feature.token_to_orig_map: 830 | continue 831 | if end_index not in feature.token_to_orig_map: 832 | continue 833 | if not feature.token_is_max_context.get(start_index, False): 834 | continue 835 | if end_index < start_index: 836 | continue 837 | length = end_index - start_index + 1 838 | if length > max_answer_length: 839 | continue 840 | prelim_predictions.append( 841 | _PrelimPrediction( 842 | feature_index=feature_index, 843 | start_index=start_index, 844 | end_index=end_index, 845 | start_logit=result.start_logits[start_index], 846 | end_logit=result.end_logits[end_index])) 847 | 848 | prelim_predictions = sorted( 849 | prelim_predictions, 850 | key=lambda x: (x.start_logit + x.end_logit), 851 | reverse=True) 852 | 853 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 854 | "NbestPrediction", ["text", "start_logit", "end_logit", "start_index", "end_index"]) 855 | 856 | seen_predictions = {} 857 | nbest = [] 858 | for pred in prelim_predictions: 859 | if len(nbest) >= n_best_size: 860 | break 861 | feature = features[pred.feature_index] 862 | if pred.start_index > 0: # this is a non-null prediction 863 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 864 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 865 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 866 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 867 | tok_text = " ".join(tok_tokens) 868 | 869 | # De-tokenize WordPieces that have been split off. 
870 | tok_text = tok_text.replace(" ##", "") 871 | tok_text = tok_text.replace("##", "") 872 | 873 | # Clean whitespace 874 | tok_text = tok_text.strip() 875 | tok_text = " ".join(tok_text.split()) 876 | orig_text = " ".join(orig_tokens) 877 | 878 | final_text = get_final_text(tok_text, orig_text, do_lower_case) 879 | final_text = final_text.replace(' ','') 880 | if final_text in seen_predictions: 881 | continue 882 | 883 | seen_predictions[final_text] = True 884 | else: 885 | final_text = "" 886 | seen_predictions[final_text] = True 887 | 888 | nbest.append( 889 | _NbestPrediction( 890 | text=final_text, 891 | start_logit=pred.start_logit, 892 | end_logit=pred.end_logit, 893 | start_index=pred.start_index, 894 | end_index=pred.end_index)) 895 | 896 | # In very rare edge cases we could have no valid predictions. So we 897 | # just create a nonce prediction in this case to avoid failure. 898 | if not nbest: 899 | nbest.append( 900 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0, start_index=0, end_index=0)) 901 | 902 | assert len(nbest) >= 1 903 | 904 | total_scores = [] 905 | best_non_null_entry = None 906 | for entry in nbest: 907 | total_scores.append(entry.start_logit + entry.end_logit) 908 | if not best_non_null_entry: 909 | if entry.text: 910 | best_non_null_entry = entry 911 | 912 | probs = _compute_softmax(total_scores) 913 | 914 | nbest_json = [] 915 | for (i, entry) in enumerate(nbest): 916 | output = collections.OrderedDict() 917 | output["text"] = entry.text 918 | output["probability"] = probs[i] 919 | output["start_logit"] = entry.start_logit 920 | output["end_logit"] = entry.end_logit 921 | output["start_index"] = entry.start_index 922 | output["end_index"] = entry.end_index 923 | nbest_json.append(output) 924 | 925 | assert len(nbest_json) >= 1 926 | 927 | all_predictions[example.qas_id] = best_non_null_entry.text 928 | all_nbest_json[example.qas_id] = nbest_json 929 | 930 | with tf.gfile.GFile(output_prediction_file, "w") as writer: 
931 | writer.write(json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n") 932 | 933 | with tf.gfile.GFile(output_nbest_file, "w") as writer: 934 | writer.write(json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + "\n") 935 | 936 | 937 | def get_final_text(pred_text, orig_text, do_lower_case): 938 | """Project the tokenized prediction back to the original text.""" 939 | 940 | # When we created the data, we kept track of the alignment between original 941 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 942 | # now `orig_text` contains the span of our original text corresponding to the 943 | # span that we predicted. 944 | # 945 | # However, `orig_text` may contain extra characters that we don't want in 946 | # our prediction. 947 | # 948 | # For example, let's say: 949 | # pred_text = steve smith 950 | # orig_text = Steve Smith's 951 | # 952 | # We don't want to return `orig_text` because it contains the extra "'s". 953 | # 954 | # We don't want to return `pred_text` because it's already been normalized 955 | # (the SQuAD eval script also does punctuation stripping/lower casing but 956 | # our tokenizer does additional normalization like stripping accent 957 | # characters). 958 | # 959 | # What we really want to return is "Steve Smith". 960 | # 961 | # Therefore, we have to apply a semi-complicated alignment heuristic between 962 | # `pred_text` and `orig_text` to get a character-to-character alignment. This 963 | # can fail in certain cases, in which case we just return `orig_text`. 964 | 965 | def _strip_spaces(text): 966 | ns_chars = [] 967 | ns_to_s_map = collections.OrderedDict() 968 | for (i, c) in enumerate(text): 969 | if c == " ": 970 | continue 971 | ns_to_s_map[len(ns_chars)] = i 972 | ns_chars.append(c) 973 | ns_text = "".join(ns_chars) 974 | return (ns_text, ns_to_s_map) 975 | 976 | # We first tokenize `orig_text`, strip whitespace from the result 977 | # and from `orig_text`, and check if they are the same length.
If they are 978 | # NOT the same length, the heuristic has failed. If they are the same 979 | # length, we assume the characters are one-to-one aligned. 980 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 981 | 982 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 983 | 984 | start_position = tok_text.find(pred_text) 985 | if start_position == -1: 986 | if FLAGS.verbose_logging: 987 | tf.logging.info( 988 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 989 | return orig_text 990 | end_position = start_position + len(pred_text) - 1 991 | 992 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 993 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 994 | 995 | if len(orig_ns_text) != len(tok_ns_text): 996 | if FLAGS.verbose_logging: 997 | tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'", 998 | orig_ns_text, tok_ns_text) 999 | return orig_text 1000 | 1001 | # We then project the characters in `pred_text` back to `orig_text` using 1002 | # the character-to-character alignment. 
1003 | tok_s_to_ns_map = {} 1004 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 1005 | tok_s_to_ns_map[tok_index] = i 1006 | 1007 | orig_start_position = None 1008 | if start_position in tok_s_to_ns_map: 1009 | ns_start_position = tok_s_to_ns_map[start_position] 1010 | if ns_start_position in orig_ns_to_s_map: 1011 | orig_start_position = orig_ns_to_s_map[ns_start_position] 1012 | 1013 | if orig_start_position is None: 1014 | if FLAGS.verbose_logging: 1015 | tf.logging.info("Couldn't map start position") 1016 | return orig_text 1017 | 1018 | orig_end_position = None 1019 | if end_position in tok_s_to_ns_map: 1020 | ns_end_position = tok_s_to_ns_map[end_position] 1021 | if ns_end_position in orig_ns_to_s_map: 1022 | orig_end_position = orig_ns_to_s_map[ns_end_position] 1023 | 1024 | if orig_end_position is None: 1025 | if FLAGS.verbose_logging: 1026 | tf.logging.info("Couldn't map end position") 1027 | return orig_text 1028 | 1029 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 1030 | return output_text 1031 | 1032 | 1033 | def _get_best_indexes(logits, n_best_size): 1034 | """Get the n-best logits from a list.""" 1035 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 1036 | 1037 | best_indexes = [] 1038 | for i in range(len(index_and_score)): 1039 | if i >= n_best_size: 1040 | break 1041 | best_indexes.append(index_and_score[i][0]) 1042 | return best_indexes 1043 | 1044 | 1045 | def _compute_softmax(scores): 1046 | """Compute softmax probability over raw logits.""" 1047 | if not scores: 1048 | return [] 1049 | 1050 | max_score = None 1051 | for score in scores: 1052 | if max_score is None or score > max_score: 1053 | max_score = score 1054 | 1055 | exp_scores = [] 1056 | total_sum = 0.0 1057 | for score in scores: 1058 | x = math.exp(score - max_score) 1059 | exp_scores.append(x) 1060 | total_sum += x 1061 | 1062 | probs = [] 1063 | for score in exp_scores: 1064 | probs.append(score / total_sum) 1065 | 
return probs 1066 | 1067 | 1068 | class FeatureWriter(object): 1069 | """Writes InputFeatures to a TF example file.""" 1070 | 1071 | def __init__(self, filename, is_training): 1072 | self.filename = filename 1073 | self.is_training = is_training 1074 | self.num_features = 0 1075 | self._writer = tf.python_io.TFRecordWriter(filename) 1076 | 1077 | def process_feature(self, feature): 1078 | """Write an InputFeature to the TFRecordWriter as a tf.train.Example.""" 1079 | self.num_features += 1 1080 | 1081 | def create_int_feature(values): 1082 | feature = tf.train.Feature( 1083 | int64_list=tf.train.Int64List(value=list(values))) 1084 | return feature 1085 | 1086 | features = collections.OrderedDict() 1087 | features["unique_ids"] = create_int_feature([feature.unique_id]) 1088 | features["input_ids"] = create_int_feature(feature.input_ids) 1089 | features["input_mask"] = create_int_feature(feature.input_mask) 1090 | features["segment_ids"] = create_int_feature(feature.segment_ids) 1091 | features["input_span_mask"] = create_int_feature(feature.input_span_mask) 1092 | 1093 | if self.is_training: 1094 | features["start_positions"] = create_int_feature([feature.start_position]) 1095 | features["end_positions"] = create_int_feature([feature.end_position]) 1096 | 1097 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 1098 | self._writer.write(tf_example.SerializeToString()) 1099 | 1100 | def close(self): 1101 | self._writer.close() 1102 | 1103 | 1104 | def validate_flags_or_throw(bert_config): 1105 | """Validate the input FLAGS or throw an exception.""" 1106 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 1107 | FLAGS.init_checkpoint) 1108 | 1109 | if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.do_eval: 1110 | raise ValueError("At least one of `do_train`, `do_predict` or `do_eval` must be True.") 1111 | 1112 | if FLAGS.do_train: 1113 | if not FLAGS.train_file: 1114 | raise ValueError( 1115 | "If `do_train` is True, then `train_file` must be
specified.") 1116 | if FLAGS.do_predict: 1117 | if not FLAGS.predict_file: 1118 | raise ValueError( 1119 | "If `do_predict` is True, then `predict_file` must be specified.") 1120 | if FLAGS.do_eval: 1121 | if not FLAGS.eval_file: 1122 | raise ValueError( 1123 | "If `do_eval` is True, then `eval_file` must be specified.") 1124 | 1125 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 1126 | raise ValueError( 1127 | "Cannot use sequence length %d because the BERT model " 1128 | "was only trained up to sequence length %d" % 1129 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 1130 | 1131 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3: 1132 | raise ValueError( 1133 | "The max_seq_length (%d) must be greater than max_query_length " 1134 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)) 1135 | 1136 | 1137 | def main(_): 1138 | tf.logging.set_verbosity(tf.logging.INFO) 1139 | 1140 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 1141 | 1142 | validate_flags_or_throw(bert_config) 1143 | 1144 | tf.gfile.MakeDirs(FLAGS.output_dir) 1145 | 1146 | tokenizer = tokenization.FullTokenizer( 1147 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 1148 | 1149 | tpu_cluster_resolver = None 1150 | if FLAGS.use_tpu and FLAGS.tpu_name: 1151 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 1152 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 1153 | 1154 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 1155 | run_config = tf.contrib.tpu.RunConfig( 1156 | cluster=tpu_cluster_resolver, 1157 | master=FLAGS.master, 1158 | model_dir=FLAGS.output_dir, 1159 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 1160 | keep_checkpoint_max=2, 1161 | tpu_config=tf.contrib.tpu.TPUConfig( 1162 | iterations_per_loop=FLAGS.iterations_per_loop, 1163 | num_shards=FLAGS.num_tpu_cores, 1164 | per_host_input_for_training=is_per_host)) 1165 | 1166 | train_examples 
= None 1167 | num_train_steps = None 1168 | num_warmup_steps = None 1169 | if FLAGS.do_train: 1170 | train_examples = read_squad_examples(input_file=FLAGS.train_file, is_training=True) 1171 | 1172 | # Pre-shuffle the input to avoid having to make a very large shuffle 1173 | # buffer in the `input_fn`. 1174 | rng = random.Random(int(FLAGS.rand_seed)) 1175 | rng.shuffle(train_examples) 1176 | 1177 | # We write to a temporary file to avoid storing very large constant tensors 1178 | # in memory. 1179 | train_writer = FeatureWriter( 1180 | filename=os.path.join(FLAGS.output_dir, "train.tf_record"), 1181 | is_training=True) 1182 | convert_examples_to_features( 1183 | examples=train_examples, 1184 | tokenizer=tokenizer, 1185 | max_seq_length=FLAGS.max_seq_length, 1186 | doc_stride=FLAGS.doc_stride, 1187 | max_query_length=FLAGS.max_query_length, 1188 | is_training=True, 1189 | output_fn=train_writer.process_feature) 1190 | train_writer.close() 1191 | num_features = train_writer.num_features 1192 | train_examples_len = len(train_examples) 1193 | del train_examples 1194 | 1195 | num_train_steps = int(num_features / FLAGS.train_batch_size * FLAGS.num_train_epochs) 1196 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 1197 | 1198 | tf.logging.info("***** Running training *****") 1199 | tf.logging.info(" Num orig examples = %d", train_examples_len) 1200 | tf.logging.info(" Num split examples = %d", num_features) 1201 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 1202 | tf.logging.info(" Num steps = %d", num_train_steps) 1203 | 1204 | model_fn = model_fn_builder( 1205 | bert_config=bert_config, 1206 | init_checkpoint=FLAGS.init_checkpoint, 1207 | learning_rate=FLAGS.learning_rate, 1208 | num_train_steps=num_train_steps, 1209 | num_warmup_steps=num_warmup_steps, 1210 | use_tpu=FLAGS.use_tpu, 1211 | use_one_hot_embeddings=FLAGS.use_tpu) 1212 | 1213 | # If TPU is not available, this will fall back to normal Estimator on CPU 1214 | # or GPU.
1215 | estimator = tf.contrib.tpu.TPUEstimator( 1216 | use_tpu=FLAGS.use_tpu, 1217 | model_fn=model_fn, 1218 | config=run_config, 1219 | train_batch_size=FLAGS.train_batch_size, 1220 | predict_batch_size=FLAGS.predict_batch_size) 1221 | 1222 | # do training 1223 | if FLAGS.do_train: 1224 | train_writer_filename = train_writer.filename 1225 | 1226 | train_input_fn = input_fn_builder( 1227 | input_file=train_writer_filename, 1228 | seq_length=FLAGS.max_seq_length, 1229 | is_training=True, 1230 | drop_remainder=True) 1231 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 1232 | 1233 | # do predictions 1234 | if FLAGS.do_predict: 1235 | eval_examples = read_squad_examples( 1236 | input_file=FLAGS.predict_file, is_training=False) 1237 | 1238 | eval_writer = FeatureWriter( 1239 | filename=os.path.join(FLAGS.output_dir, "predict.tf_record"), 1240 | is_training=False) 1241 | eval_features = [] 1242 | 1243 | def append_feature(feature): 1244 | eval_features.append(feature) 1245 | eval_writer.process_feature(feature) 1246 | 1247 | convert_examples_to_features( 1248 | examples=eval_examples, 1249 | tokenizer=tokenizer, 1250 | max_seq_length=FLAGS.max_seq_length, 1251 | doc_stride=FLAGS.doc_stride, 1252 | max_query_length=FLAGS.max_query_length, 1253 | is_training=False, 1254 | output_fn=append_feature) 1255 | eval_writer.close() 1256 | 1257 | tf.logging.info("***** Running predictions *****") 1258 | tf.logging.info(" Num orig examples = %d", len(eval_examples)) 1259 | tf.logging.info(" Num split examples = %d", len(eval_features)) 1260 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 1261 | 1262 | all_results = [] 1263 | 1264 | predict_input_fn = input_fn_builder( 1265 | input_file=eval_writer.filename, 1266 | seq_length=FLAGS.max_seq_length, 1267 | is_training=False, 1268 | drop_remainder=False) 1269 | 1270 | # If running eval on the TPU, you will need to specify the number of 1271 | # steps. 
1272 | all_results = [] 1273 | for result in estimator.predict( 1274 | predict_input_fn, yield_single_examples=True): 1275 | if len(all_results) % 1000 == 0: 1276 | tf.logging.info("Processing example: %d" % (len(all_results))) 1277 | unique_id = int(result["unique_ids"]) 1278 | start_logits = [float(x) for x in result["start_logits"].flat] 1279 | end_logits = [float(x) for x in result["end_logits"].flat] 1280 | all_results.append( 1281 | RawResult( 1282 | unique_id=unique_id, 1283 | start_logits=start_logits, 1284 | end_logits=end_logits)) 1285 | 1286 | output_json_name = "dev_predictions.json" 1287 | output_nbest_name = "dev_nbest_predictions.json" 1288 | 1289 | output_prediction_file = os.path.join(FLAGS.output_dir, output_json_name) 1290 | output_nbest_file = os.path.join(FLAGS.output_dir, output_nbest_name) 1291 | 1292 | write_predictions(eval_examples, eval_features, all_results, 1293 | FLAGS.n_best_size, FLAGS.max_answer_length, 1294 | FLAGS.do_lower_case, output_prediction_file, 1295 | output_nbest_file) 1296 | 1297 | # do evaluation (predict on `eval_file`) 1298 | if FLAGS.do_eval: 1299 | eval_examples = read_squad_examples( 1300 | input_file=FLAGS.eval_file, is_training=False) 1301 | 1302 | eval_writer = FeatureWriter( 1303 | filename=os.path.join(FLAGS.output_dir, "eval.tf_record"), 1304 | is_training=False) 1305 | eval_features = [] 1306 | 1307 | def append_feature(feature): 1308 | eval_features.append(feature) 1309 | eval_writer.process_feature(feature) 1310 | 1311 | convert_examples_to_features( 1312 | examples=eval_examples, 1313 | tokenizer=tokenizer, 1314 | max_seq_length=FLAGS.max_seq_length, 1315 | doc_stride=FLAGS.doc_stride, 1316 | max_query_length=FLAGS.max_query_length, 1317 | is_training=False, 1318 | output_fn=append_feature) 1319 | eval_writer.close() 1320 | 1321 | tf.logging.info("***** Running evals *****") 1322 | tf.logging.info(" Num orig examples = %d", len(eval_examples)) 1323 | tf.logging.info(" Num split examples = %d", len(eval_features)) 1324 |
tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 1325 | 1326 | all_results = [] 1327 | 1328 | predict_input_fn = input_fn_builder( 1329 | input_file=eval_writer.filename, 1330 | seq_length=FLAGS.max_seq_length, 1331 | is_training=False, 1332 | drop_remainder=False) 1333 | 1334 | # If running eval on the TPU, you will need to specify the number of 1335 | # steps. 1336 | all_results = [] 1337 | for result in estimator.predict( 1338 | predict_input_fn, yield_single_examples=True): 1339 | if len(all_results) % 1000 == 0: 1340 | tf.logging.info("Processing example: %d" % (len(all_results))) 1341 | unique_id = int(result["unique_ids"]) 1342 | start_logits = [float(x) for x in result["start_logits"].flat] 1343 | end_logits = [float(x) for x in result["end_logits"].flat] 1344 | all_results.append( 1345 | RawResult( 1346 | unique_id=unique_id, 1347 | start_logits=start_logits, 1348 | end_logits=end_logits)) 1349 | 1350 | output_json_name = "test_predictions.json" 1351 | output_nbest_name = "test_nbest_predictions.json" 1352 | 1353 | output_prediction_file = os.path.join(FLAGS.output_dir, output_json_name) 1354 | output_nbest_file = os.path.join(FLAGS.output_dir, output_nbest_name) 1355 | 1356 | write_predictions(eval_examples, eval_features, all_results, 1357 | FLAGS.n_best_size, FLAGS.max_answer_length, 1358 | FLAGS.do_lower_case, output_prediction_file, 1359 | output_nbest_file) 1360 | 1361 | 1362 | if __name__ == "__main__": 1363 | flags.mark_flag_as_required("vocab_file") 1364 | flags.mark_flag_as_required("bert_config_file") 1365 | flags.mark_flag_as_required("output_dir") 1366 | tf.app.run() 1367 | 1368 | 1369 | 1370 | --------------------------------------------------------------------------------
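The span-decoding logic spread across `_check_is_max_context` (the `min(left, right) + 0.01 * length` score), `create_model` (the `-10000.0` span-mask adder) and `write_predictions` (top-n start/end filtering) can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, assuming per-token logits are plain lists and the span mask is a 0/1 list; `mask_logits`, `log_softmax`, `max_context_score` and `best_span` are hypothetical helpers, not functions from this repository, and the sketch omits the `token_to_orig_map` and max-context checks that the real `write_predictions` also applies.

```python
import math
from typing import List, Optional, Tuple


# Hypothetical helpers sketching the span decoding used by the baseline script.

def mask_logits(logits: List[float], span_mask: List[int]) -> List[float]:
    """Push positions outside the answer span towards -inf (like `adder` in create_model)."""
    return [x + (1.0 - m) * -10000.0 for x, m in zip(logits, span_mask)]


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax, as applied in the PREDICT branch."""
    max_logit = max(logits)
    log_sum = math.log(sum(math.exp(x - max_logit) for x in logits))
    return [x - max_logit - log_sum for x in logits]


def max_context_score(position: int, span_start: int,
                      span_length: int) -> Optional[float]:
    """Score from `_check_is_max_context`: min(left, right) context + 0.01 * length.

    Returns None when `position` falls outside the span.
    """
    span_end = span_start + span_length - 1
    if position < span_start or position > span_end:
        return None
    left, right = position - span_start, span_end - position
    return min(left, right) + 0.01 * span_length


def best_span(start_logits: List[float], end_logits: List[float],
              n_best_size: int = 20,
              max_answer_length: int = 30) -> Optional[Tuple[int, int, float]]:
    """Pick the (start, end, score) pair maximizing start_logit + end_logit.

    Only the top `n_best_size` start and end indexes are considered (like
    `_get_best_indexes`); spans with end < start or longer than
    `max_answer_length` tokens are discarded. Returns None if nothing is valid.
    """
    def top_k(logits: List[float]) -> List[int]:
        order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
        return order[:n_best_size]

    best = None
    for s in top_k(start_logits):
        for e in top_k(end_logits):
            if e < s or e - s + 1 > max_answer_length:
                continue
            score = start_logits[s] + end_logits[e]
            if best is None or score > best[2]:
                best = (s, e, score)
    return best


if __name__ == "__main__":
    # Toy example: position 0 plays the role of [CLS] and is masked out.
    mask = [0, 1, 1, 1, 1]
    start = mask_logits([5.0, 1.0, 0.2, 3.0, 0.1], mask)
    end = mask_logits([4.0, 0.1, 2.5, 0.3, 1.0], mask)
    print(best_span(start, end))                                # (3, 4, 4.0)
    print(best_span(log_softmax(start), log_softmax(end))[:2])  # (3, 4)
```

Because log-softmax subtracts the same constant from every entry of a logit vector, applying it (as the PREDICT branch of `model_fn` does) shifts every candidate's `start_logit + end_logit` by the same amount, so the ranking of candidate spans is unchanged.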