├── .gitignore
├── requirements.txt
├── start.sh
├── input.json
├── README.md
├── data-chinese.json
├── interactive.py
├── utils_squad_evaluate.py
├── bert_qa.py
└── utils_squad.py
/.gitignore:
--------------------------------------------------------------------------------
1 | lib/
2 | bin/
3 | include/
4 | .vscode
5 | __pycache__
6 | cache*
7 | .idea/*
8 | output/*
9 | runs/*
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-transformers==1.2.0
2 | tensorboardX==1.8
3 | tensorflow==1.14.0
4 | torch==1.0.1.post2
5 | tqdm==4.31.1
6 |
--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=`pwd`
2 | export DATA_DIR=$CURRENT_DIR/data
3 |
4 | python bert_qa.py \
5 | --model_type bert \
6 | --model_name_or_path bert-base-chinese \
7 | --do_train \
8 | --do_eval \
9 | --do_lower_case \
10 | --train_file $DATA_DIR/cmrc2018_train.json \
11 | --predict_file $DATA_DIR/cmrc2018_dev.json \
12 | --per_gpu_train_batch_size 32 \
13 | --learning_rate 3e-5 \
14 | --num_train_epochs 3.0 \
15 | --max_seq_length 384 \
16 | --doc_stride 128 \
17 | --output_dir output
--------------------------------------------------------------------------------
/input.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "v1.0",
3 | "data": [
4 | {
5 | "paragraphs": [
6 | {
7 | "id": "",
8 | "context": "fd",
9 | "qas": [
10 | {
11 | "question": "df",
12 | "id": "",
13 | "answers": []
14 | },
15 | {
16 | "question": "df",
17 | "id": "",
18 | "answers": []
19 | },
20 | {
21 | "question": "df",
22 | "id": "",
23 | "answers": []
24 | }
25 | ]
26 | }
27 | ],
28 | "id": "",
29 | "title": ""
30 | }
31 | ]
32 | }
33 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
基于预训练模型BERT的阅读理解
2 |
3 |
4 | Here we are going to building a machine reading comprehension system using pretrained model bert from google, the latest advances in deep learning for NLP.
5 |
6 | Stanford Question Answering Dataset (SQuAD) is one of the first large reading comprehension datasets in English. From the perspective of model, the inputs come in the form of a Context / Question pair, and the outputs are Answers: pairs of integers, indexing the start and the end of the answer's text contained inside the Context.
7 | [The 2nd Evaluation Workshop on Chinese Machine Reading Comprehension(2018)](https://github.com/ymcui/cmrc2018) release part of the datasets similar to SQuAD, which we used in this example.
8 |
9 | The model is built on top of [pytorch-transformers](https://github.com/huggingface/pytorch-transformers) which help to use pretrained model like BERT, GPT, GPT2 to downstream tasks. The repository includes various utilities and training scripts for multiple NLP tasks, including Question Answering. Below are two relate post about QA using bert:
10 |
11 | >-[Understanding text with BERT](https://blog.scaleway.com/2019/understanding-text-with-bert/)
12 |
13 | >-[Extending Google-BERT as Question and Answering model and Chatbot](https://medium.com/datadriveninvestor/extending-google-bert-as-question-and-answering-model-and-chatbot-e3e7b47b721a)
14 |
15 | Getting Started
16 |
17 | #### 1. Prepare data, the virtual python environment and install the package in requirements.txt
18 |
19 | #### 2. Run the command below to fine tune
20 |
21 | ```bash
22 | python bert_qa.py \
23 | --model_type bert \
24 | --model_name_or_path bert-base-chinese \
25 | --do_train \
26 | --do_eval \
27 | --do_lower_case \
28 | --train_file ~/cmrc2018_train.json \
29 | --predict_file ~/cmrc2018_trial.json \
30 | --per_gpu_train_batch_size 12 \
31 | --learning_rate 3e-5 \
32 | --num_train_epochs 2.0 \
33 | --max_seq_length 384 \
34 | --doc_stride 128 \
35 | --output_dir ~/chinese-qa-with-bert/output \
36 | --save_steps 200
37 | ```
38 | #### 4. Load the fine-tuned params and run this command to interactive
39 |
40 | ```bash
41 | python interactive.py \
42 | --output_dir ~/chinese-qa-with-bert/output \
43 | --model_type bert \
44 | --predict_file ~/chinese-qa-with-bert/input.json \
45 | --state_dict ~/chinese-qa-with-bert/output/pytorch_model.bin \
46 | --model_name_or_path bert-base-chinese
47 | ```
48 |
49 | the input json string contains context and question is like this:
50 |
51 | ```json
52 | {"context": "海梅·雷耶斯()是DC漫画公司的一个虚拟人物。该人物首先出现于《无限危机 #3》(2006年二月),是第三代蓝甲虫,由作家基斯·吉芬和约翰·罗杰斯创作,屈伊·哈姆纳作画。海梅与他的父母妹妹生活在得克萨斯州的艾尔帕索。他的父亲拥有一间汽车修理厂。海梅建议自己帮助父亲在汽车修理厂中干活,然而他的父亲迄今为止未答应他,觉得海梅应该花更多的功夫在学习上并享受他自己的童年生活。海梅对他的家庭和朋友们有强烈的责任感,可是他经常抱怨做一个仅解决琐事的人。第二代蓝甲虫(泰德·科德)死前派遣圣甲虫去给惊奇队长送信,圣甲虫也因此留在了惊奇队长的永恒之岩里。之后惊奇队长被杀,蓝甲虫降落到了得克萨斯州的艾尔帕索,被少年海梅捡到,后来蓝甲虫在海梅睡觉时融合进他的脊椎,海梅从此成为第三代蓝甲虫,此时的蓝甲虫具备了随意变武器的能力。蓝甲虫原本是宇宙侵略组织Reach的战争工具,具有自己的思想。它曾被作为礼物赐予某星球的勇士,实际上是暗中控制他们。“卡基达(Kaji Dha)”是它的启动口令,第一代蓝甲虫加勒特就在它的影响下攻击过科德。蓝甲虫在无限危机后受到外界强大能量的影响沉睡了一年,苏醒后又想控制海梅,但被海梅的意念克制了,加上它本身存在程序故障,久而久之他与海梅成了好友。","qas": [{"question": "该漫画是由谁创作的?"}]}
53 | ```
54 | ### Below are some results for show
55 |
56 | ```json
57 | {"AVERAGE": "27.276", "F1": "54.552", "EM": "0.000", "TOTAL": 1002, "SKIP": 0, "FILE": "~/chinese-qa-with-bert/output/predictions_.
58 | json"}
59 | ```
60 |
61 | ```json
62 | Please Enter:{"context": "江苏路街道是中国上海市长宁区下辖的一个街道办事处,位于长宁区最东部,东到镇宁路接邻静安区的静安寺街道,北到武定西路接邻静安区的曹家渡 街道,南到华山路接邻徐汇区的湖南路街道,西部界限为长宁路、安西路、武夷路等。面积1.52平方公里,户籍人口5.26万人,下辖13个居委会。长宁区区政府设在该街道辖区内的 愚园路(近安西路)。江苏路街道的主要街道江苏路、愚园路、长宁路、武夷路、武定西路、延安西路,均为上海公共租界越界筑路,是花园洋房集中的街区,是愚园路历史文化风 貌区的主体部分。较著名的近代建筑有中西女中(今上海市第三女子中学)、王伯群及汪精卫公馆(今长宁区少年宫)、西园公寓等。江苏路、长宁路、延安西路(高架)等经过拓 宽,已经形成上海市的交通干道。上海市轨道交通二号线经过该街道辖区,设有江苏路站。江苏路街道下辖歧山居委会、江苏居委会、万村居委会、南汪居委会、东浜居委会、愚三 居委会、曹家堰居委会、西浜居委会、福世居委会、长新居委会、华山居委会、利西居委会、北汪居委会。","qas": [{"question": "江苏路街道在上海市的什么地方?"}]}
63 | Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2.61it/s]
64 |
65 | 江 苏 路 街 道 是 中 国 上 海 市 长 宁
66 | ```
67 |
68 |
69 | Just For Learn, More Optimization to Do
--------------------------------------------------------------------------------
/data-chinese.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "v1.0",
3 | "data": [
4 | {
5 | "paragraphs": [
6 | {
7 | "id": "TRAIN_54",
8 | "context": "安雅·罗素法(,),来自俄罗斯圣彼得堡的模特儿。她是《全美超级模特儿新秀大赛》第十季的亚军。2008年,安雅宣布改回出生时的名字:安雅·罗素法(Anya Rozova),在此之前是使用安雅·冈()。安雅于俄罗斯出生,后来被一个居住在美国夏威夷群岛欧胡岛檀香山的家庭领养。安雅十七岁时曾参与香奈儿、路易·威登及芬迪(Fendi)等品牌的非正式时装秀。2007年,她于瓦伊帕胡高级中学毕业。毕业后,她当了一名售货员。她曾为Russell Tanoue拍摄照片,Russell Tanoue称赞她是「有前途的新面孔」。安雅在半准决赛面试时说她对模特儿行业充满热诚,所以参加全美超级模特儿新秀大赛。她于比赛中表现出色,曾五次首名入围,平均入围顺序更拿下历届以来最优异的成绩(2.64),另外胜出三次小挑战,分别获得与评判尼祖·百克拍照、为柠檬味道的七喜拍摄广告的机会及十万美元、和盖马蒂洛(Gai Mattiolo)设计的晚装。在最后两强中,安雅与另一名参赛者惠妮·汤姆森为范思哲走秀,但评判认为她在台上不够惠妮突出,所以选了惠妮当冠军,安雅屈居亚军(但就整体表现来说,部份网友认为安雅才是第十季名副其实的冠军。)安雅在比赛拿五次第一,也胜出多次小挑战。安雅赛后再次与Russell Tanoue合作,为2008年4月30日出版的MidWeek杂志拍摄封面及内页照。其后她参加了V杂志与Supreme模特儿公司合办的模特儿选拔赛2008。她其后更与Elite签约。最近她与香港的模特儿公司 Style International Management 签约,并在香港发展其模特儿事业。她曾在很多香港的时装杂志中任模特儿,《Jet》、《东方日报》、《Elle》等。",
9 | "qas": [
10 | {
11 | "question": "安雅·罗素法参加了什么比赛获得了亚军?",
12 | "id": "TRAIN_54_QUERY_0",
13 | "answers": [
14 | {
15 | "text": "《全美超级模特儿新秀大赛》第十季",
16 | "answer_start": 26
17 | }
18 | ]
19 | },
20 | {
21 | "question": "Russell Tanoue对安雅·罗素法的评价是什么?",
22 | "id": "TRAIN_54_QUERY_1",
23 | "answers": [
24 | {
25 | "text": "有前途的新面孔",
26 | "answer_start": 247
27 | }
28 | ]
29 | },
30 | {
31 | "question": "安雅·罗素法合作过的香港杂志有哪些?",
32 | "id": "TRAIN_54_QUERY_2",
33 | "answers": [
34 | {
35 | "text": "《Jet》、《东方日报》、《Elle》等",
36 | "answer_start": 706
37 | }
38 | ]
39 | },
40 | {
41 | "question": "毕业后的安雅·罗素法职业是什么?",
42 | "id": "TRAIN_54_QUERY_3",
43 | "answers": [
44 | {
45 | "text": "售货员",
46 | "answer_start": 202
47 | }
48 | ]
49 | }
50 | ]
51 | }
52 | ],
53 | "id": "TRAIN_54",
54 | "title": "安雅·罗素法"
55 | },
56 | {
57 | "paragraphs": [
58 | {
59 | "id": "TRAIN_756",
60 | "context": "为日本漫画足球小将翼的一个角色,自小父母离异,与父亲一起四处为家,每个地方也是待一会便离开,但他仍然能够保持优秀的学业成绩。在第一次南葛市生活时,与同样就读于南葛小学的大空翼为黄金拍档,曾效力球队包括南葛小学、南葛高中、日本少年队、日本青年军、日本奥运队。效力日本青年军期间,因救同母异父的妹妹导致被车撞至断脚,在决赛周只在决赛的下半场十五分钟开始上场,成为日本队夺得世青冠军的其中一名功臣。基本资料绰号:球场上的艺术家出身地:日本南葛市诞生日:5月5日星座:金牛座球衣号码:11担任位置:中场、攻击中场、右中场擅长脚:右脚所属队伍:盘田山叶故事发展岬太郎在小学期间不断转换学校,在南葛小学就读时在全国大赛中夺得冠军;国中三年随父亲孤单地在法国留学;回国后三年的高中生涯一直输给日本王牌射手日向小次郎率领的东邦学院。在【Golden 23】年代,大空翼、日向小次郎等名将均转战海外,他与松山光、三杉淳组成了「3M」组合(松山光Hikaru Matsuyama、岬太郎Taro Misaki、三杉淳Jyun Misugi)。必杀技1. 回力刀射门2. S. S. S. 射门3. 双人射门(与大空翼合作)",
61 | "qas": [
62 | {
63 | "question": "岬太郎在第一次南葛市生活时的搭档是谁?",
64 | "id": "TRAIN_756_QUERY_0",
65 | "answers": [
66 | {
67 | "text": "大空翼",
68 | "answer_start": 84
69 | }
70 | ]
71 | },
72 | {
73 | "question": "日本队夺得世青冠军,岬太郎发挥了什么作用?",
74 | "id": "TRAIN_756_QUERY_1",
75 | "answers": [
76 | {
77 | "text": "在决赛周只在决赛的下半场十五分钟开始上场,成为日本队夺得世青冠军的其中一名功臣。",
78 | "answer_start": 156
79 | }
80 | ]
81 | },
82 | {
83 | "question": "岬太郎与谁一起组成了「3M」组合?",
84 | "id": "TRAIN_756_QUERY_2",
85 | "answers": [
86 | {
87 | "text": "他与松山光、三杉淳组成了「3M」组合(松山光Hikaru Matsuyama、岬太郎Taro Misaki、三杉淳Jyun Misugi)。",
88 | "answer_start": 391
89 | }
90 | ]
91 | }
92 | ]
93 | }
94 | ],
95 | "id": "TRAIN_756",
96 | "title": "岬太郎"
97 | }
98 | ]
99 | }
--------------------------------------------------------------------------------
/interactive.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import argparse
4 | from pytorch_transformers import (BertConfig, BertForQuestionAnswering,
5 | BertTokenizer)
6 | from bert_qa import evaluate
7 | import os
8 |
9 | parser = argparse.ArgumentParser()
10 | ## Required parameters
11 | parser.add_argument(
12 | "--train_file",
13 | default=None,
14 | type=str,
15 | required=False,
16 | help="SQuAD json for training. E.g., train-v1.1.json")
17 | parser.add_argument(
18 | "--predict_file",
19 | default=None,
20 | type=str,
21 | required=True,
22 | help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
23 | parser.add_argument(
24 | "--model_type",
25 | default=None,
26 | type=str,
27 | required=True,
28 | help="Model type selected in the list: ")
29 | parser.add_argument(
30 | "--model_name_or_path",
31 | default=None,
32 | type=str,
33 | required=True,
34 | help="Path to pre-trained model or shortcut name selected in the list: ")
35 | parser.add_argument(
36 | "--output_dir",
37 | default=None,
38 | type=str,
39 | required=True,
40 | help=
41 | "The output directory where the model checkpoints and predictions will be written."
42 | )
43 |
44 | ## Other parameters
45 | parser.add_argument(
46 | "--config_name",
47 | default="",
48 | type=str,
49 | help="Pretrained config name or path if not the same as model_name")
50 | parser.add_argument(
51 | "--tokenizer_name",
52 | default="",
53 | type=str,
54 | help="Pretrained tokenizer name or path if not the same as model_name")
55 | parser.add_argument(
56 | "--cache_dir",
57 | default="",
58 | type=str,
59 | help="Where do you want to store the pre-trained models downloaded from s3"
60 | )
61 |
62 | parser.add_argument(
63 | '--version_2_with_negative',
64 | action='store_true',
65 | help='If true, the SQuAD examples contain some that do not have an answer.'
66 | )
67 | parser.add_argument(
68 | '--null_score_diff_threshold',
69 | type=float,
70 | default=0.0,
71 | help=
72 | "If null_score - best_non_null is greater than the threshold predict null."
73 | )
74 |
75 | parser.add_argument(
76 | "--max_seq_length",
77 | default=384,
78 | type=int,
79 | help=
80 | "The maximum total input sequence length after WordPiece tokenization. Sequences "
81 | "longer than this will be truncated, and sequences shorter than this will be padded."
82 | )
83 | parser.add_argument(
84 | "--doc_stride",
85 | default=128,
86 | type=int,
87 | help=
88 | "When splitting up a long document into chunks, how much stride to take between chunks."
89 | )
90 | parser.add_argument(
91 | "--max_query_length",
92 | default=64,
93 | type=int,
94 | help=
95 | "The maximum number of tokens for the question. Questions longer than this will "
96 | "be truncated to this length.")
97 | parser.add_argument(
98 | "--do_train", action='store_true', help="Whether to run training.")
99 | parser.add_argument(
100 | "--do_eval",
101 | action='store_true',
102 | help="Whether to run eval on the dev set.")
103 | parser.add_argument(
104 | "--evaluate_during_training",
105 | action='store_true',
106 | help="Rul evaluation during training at each logging step.")
107 | parser.add_argument(
108 | "--do_lower_case",
109 | action='store_true',
110 | help="Set this flag if you are using an uncased model.")
111 |
112 | parser.add_argument(
113 | "--per_gpu_train_batch_size",
114 | default=8,
115 | type=int,
116 | help="Batch size per GPU/CPU for training.")
117 | parser.add_argument(
118 | "--per_gpu_eval_batch_size",
119 | default=8,
120 | type=int,
121 | help="Batch size per GPU/CPU for evaluation.")
122 | parser.add_argument(
123 | "--learning_rate",
124 | default=5e-5,
125 | type=float,
126 | help="The initial learning rate for Adam.")
127 | parser.add_argument(
128 | '--gradient_accumulation_steps',
129 | type=int,
130 | default=1,
131 | help=
132 | "Number of updates steps to accumulate before performing a backward/update pass."
133 | )
134 | parser.add_argument(
135 | "--weight_decay",
136 | default=0.0,
137 | type=float,
138 | help="Weight deay if we apply some.")
139 | parser.add_argument(
140 | "--adam_epsilon",
141 | default=1e-8,
142 | type=float,
143 | help="Epsilon for Adam optimizer.")
144 | parser.add_argument(
145 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
146 | parser.add_argument(
147 | "--num_train_epochs",
148 | default=3.0,
149 | type=float,
150 | help="Total number of training epochs to perform.")
151 | parser.add_argument(
152 | "--max_steps",
153 | default=-1,
154 | type=int,
155 | help=
156 | "If > 0: set total number of training steps to perform. Override num_train_epochs."
157 | )
158 | parser.add_argument(
159 | "--warmup_steps",
160 | default=0,
161 | type=int,
162 | help="Linear warmup over warmup_steps.")
163 | parser.add_argument(
164 | "--n_best_size",
165 | default=20,
166 | type=int,
167 | help=
168 | "The total number of n-best predictions to generate in the nbest_predictions.json output file."
169 | )
170 | parser.add_argument(
171 | "--max_answer_length",
172 | default=30,
173 | type=int,
174 | help=
175 | "The maximum length of an answer that can be generated. This is needed because the start "
176 | "and end predictions are not conditioned on one another.")
177 | parser.add_argument(
178 | "--verbose_logging",
179 | action='store_true',
180 | help=
181 | "If true, all of the warnings related to data processing will be printed. "
182 | "A number of warnings are expected for a normal SQuAD evaluation.")
183 |
184 | parser.add_argument(
185 | '--logging_steps', type=int, default=50, help="Log every X updates steps.")
186 | parser.add_argument(
187 | '--save_steps',
188 | type=int,
189 | default=50,
190 | help="Save checkpoint every X updates steps.")
191 | parser.add_argument(
192 | "--eval_all_checkpoints",
193 | action='store_true',
194 | help=
195 | "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
196 | )
197 | parser.add_argument(
198 | "--no_cuda",
199 | action='store_true',
200 | help="Whether not to use CUDA when available")
201 | parser.add_argument(
202 | '--overwrite_output_dir',
203 | action='store_true',
204 | help="Overwrite the content of the output directory")
205 | parser.add_argument(
206 | '--overwrite_cache',
207 | action='store_true',
208 | help="Overwrite the cached training and evaluation sets")
209 | parser.add_argument(
210 | '--seed', type=int, default=42, help="random seed for initialization")
211 |
212 | parser.add_argument(
213 | "--local_rank",
214 | type=int,
215 | default=-1,
216 | help="local_rank for distributed training on gpus")
217 | parser.add_argument(
218 | '--fp16',
219 | action='store_true',
220 | help=
221 | "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
222 | )
223 | parser.add_argument(
224 | '--fp16_opt_level',
225 | type=str,
226 | default='O1',
227 | help=
228 | "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
229 | "See details at https://nvidia.github.io/apex/amp.html")
230 | parser.add_argument(
231 | "--state_dict",
232 | default=None,
233 | type=str,
234 | required=True,
235 | help="model para after pretrained")
236 |
237 | args = parser.parse_args()
238 | args.n_gpu = torch.cuda.device_count()
239 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
240 | device = torch.device(
241 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
242 | args.device = device
243 | tokenizer = BertTokenizer.from_pretrained(
244 | 'bert-base-chinese', do_lower_case=False)
245 | config = BertConfig.from_pretrained('bert-base-chinese')
246 | model = BertForQuestionAnswering(config)
247 | model_state_dict = args.state_dict
248 | model.load_state_dict(torch.load(model_state_dict))
249 | model.to(args.device)
250 | model.eval()
251 | input_file = args.predict_file
252 |
253 |
254 | def handle_file(input_file, context, question):
255 | with open(input_file, "r") as reader:
256 | orig_data = json.load(reader)
257 | orig_data["data"][0]['paragraphs'][0]['context'] = context
258 | for i in range(len(question)):
259 | orig_data["data"][0]['paragraphs'][0]['qas'][i][
260 | 'question'] = question[i]
261 | with open(input_file, "w") as writer:
262 | writer.write(json.dumps(orig_data, indent=4) + "\n")
263 |
264 |
265 | def run():
266 | while True:
267 | raw_text = input("Please Enter:")
268 | while not raw_text:
269 | print('Input should not be empty!')
270 | raw_text = input("Please Enter:")
271 | context = ''
272 | question = []
273 | try:
274 | raw_json = json.loads(raw_text)
275 | context = raw_json['context']
276 | if not context:
277 | continue
278 | raw_qas = raw_json['qas']
279 | if not raw_qas:
280 | continue
281 | for i in range(len(raw_qas)):
282 | question.append(raw_qas[i]['question'])
283 | except Exception as identifier:
284 | print(identifier)
285 | continue
286 | handle_file(input_file, context, question)
287 | evaluate(args, model, tokenizer)
288 |
289 | predict_file = os.path.join(args.output_dir, "predictions_.json")
290 | with open(predict_file, "r") as reader:
291 | orig_data = json.load(reader)
292 | print(orig_data[""])
293 | # clean input file
294 | handle_file(input_file, "", ["", "", ""])
295 |
296 |
297 | if __name__ == "__main__":
298 | run()
299 |
--------------------------------------------------------------------------------
/utils_squad_evaluate.py:
--------------------------------------------------------------------------------
1 | """ Official evaluation script for SQuAD version 2.0.
2 | Modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0
3 | In addition to basic functionality, we also compute additional statistics and
4 | plot precision-recall curves if an additional na_prob.json file is provided.
5 | This file is expected to map question ID's to the model's predicted probability
6 | that a question is unanswerable.
7 | """
8 | import argparse
9 | import collections
10 | import json
11 | import numpy as np
12 | import os
13 | import re
14 | import string
15 | import sys
16 |
17 |
18 | class EVAL_OPTS():
19 | def __init__(self,
20 | data_file,
21 | pred_file,
22 | out_file="",
23 | na_prob_file="na_prob.json",
24 | na_prob_thresh=1.0,
25 | out_image_dir=None,
26 | verbose=False):
27 | self.data_file = data_file
28 | self.pred_file = pred_file
29 | self.out_file = out_file
30 | self.na_prob_file = na_prob_file
31 | self.na_prob_thresh = na_prob_thresh
32 | self.out_image_dir = out_image_dir
33 | self.verbose = verbose
34 |
35 |
36 | OPTS = None
37 |
38 |
39 | def parse_args():
40 | parser = argparse.ArgumentParser(
41 | 'Official evaluation script for SQuAD version 2.0.')
42 | parser.add_argument(
43 | 'data_file', metavar='data.json', help='Input data JSON file.')
44 | parser.add_argument(
45 | 'pred_file', metavar='pred.json', help='Model predictions.')
46 | parser.add_argument(
47 | '--out-file',
48 | '-o',
49 | metavar='eval.json',
50 | help='Write accuracy metrics to file (default is stdout).')
51 | parser.add_argument(
52 | '--na-prob-file',
53 | '-n',
54 | metavar='na_prob.json',
55 | help='Model estimates of probability of no answer.')
56 | parser.add_argument(
57 | '--na-prob-thresh',
58 | '-t',
59 | type=float,
60 | default=1.0,
61 | help='Predict "" if no-answer probability exceeds this (default = 1.0).'
62 | )
63 | parser.add_argument(
64 | '--out-image-dir',
65 | '-p',
66 | metavar='out_images',
67 | default=None,
68 | help='Save precision-recall curves to directory.')
69 | parser.add_argument('--verbose', '-v', action='store_true')
70 | if len(sys.argv) == 1:
71 | parser.print_help()
72 | sys.exit(1)
73 | return parser.parse_args()
74 |
75 |
76 | def make_qid_to_has_ans(dataset):
77 | qid_to_has_ans = {}
78 | for article in dataset:
79 | for p in article['paragraphs']:
80 | for qa in p['qas']:
81 | qid_to_has_ans[qa['id']] = bool(qa['answers'])
82 | return qid_to_has_ans
83 |
84 |
85 | def normalize_answer(s):
86 | """Lower text and remove punctuation, articles and extra whitespace."""
87 |
88 | def remove_articles(text):
89 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
90 | return re.sub(regex, ' ', text)
91 |
92 | def white_space_fix(text):
93 | return ' '.join(text.split())
94 |
95 | def remove_punc(text):
96 | exclude = set(string.punctuation)
97 | return ''.join(ch for ch in text if ch not in exclude)
98 |
99 | def lower(text):
100 | return text.lower()
101 |
102 | return white_space_fix(remove_articles(remove_punc(lower(s))))
103 |
104 |
105 | def get_tokens(s):
106 | if not s: return []
107 | return normalize_answer(s).split()
108 |
109 |
110 | def compute_exact(a_gold, a_pred):
111 | return int(normalize_answer(a_gold) == normalize_answer(a_pred))
112 |
113 |
114 | def compute_f1(a_gold, a_pred):
115 | gold_toks = get_tokens(a_gold)
116 | pred_toks = get_tokens(a_pred)
117 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
118 | num_same = sum(common.values())
119 | if len(gold_toks) == 0 or len(pred_toks) == 0:
120 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
121 | return int(gold_toks == pred_toks)
122 | if num_same == 0:
123 | return 0
124 | precision = 1.0 * num_same / len(pred_toks)
125 | recall = 1.0 * num_same / len(gold_toks)
126 | f1 = (2 * precision * recall) / (precision + recall)
127 | return f1
128 |
129 |
130 | def get_raw_scores(dataset, preds):
131 | exact_scores = {}
132 | f1_scores = {}
133 | for article in dataset:
134 | for p in article['paragraphs']:
135 | for qa in p['qas']:
136 | qid = qa['id']
137 | gold_answers = [
138 | a['text'] for a in qa['answers']
139 | if normalize_answer(a['text'])
140 | ]
141 | if not gold_answers:
142 | # For unanswerable questions, only correct answer is empty string
143 | gold_answers = ['']
144 | if qid not in preds:
145 | print('Missing prediction for %s' % qid)
146 | continue
147 | a_pred = preds[qid]
148 | # Take max over all gold answers
149 | exact_scores[qid] = max(
150 | compute_exact(a, a_pred) for a in gold_answers)
151 | f1_scores[qid] = max(
152 | compute_f1(a, a_pred) for a in gold_answers)
153 | return exact_scores, f1_scores
154 |
155 |
156 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
157 | new_scores = {}
158 | for qid, s in scores.items():
159 | pred_na = na_probs[qid] > na_prob_thresh
160 | if pred_na:
161 | new_scores[qid] = float(not qid_to_has_ans[qid])
162 | else:
163 | new_scores[qid] = s
164 | return new_scores
165 |
166 |
167 | def make_eval_dict(exact_scores, f1_scores, qid_list=None):
168 | if not qid_list:
169 | total = len(exact_scores)
170 | return collections.OrderedDict([
171 | ('exact', 100.0 * sum(exact_scores.values()) / total),
172 | ('f1', 100.0 * sum(f1_scores.values()) / total),
173 | ('total', total),
174 | ])
175 | else:
176 | total = len(qid_list)
177 | return collections.OrderedDict([
178 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
179 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
180 | ('total', total),
181 | ])
182 |
183 |
184 | def merge_eval(main_eval, new_eval, prefix):
185 | for k in new_eval:
186 | main_eval['%s_%s' % (prefix, k)] = new_eval[k]
187 |
188 |
189 | def plot_pr_curve(precisions, recalls, out_image, title):
190 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
191 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
192 | plt.xlabel('Recall')
193 | plt.ylabel('Precision')
194 | plt.xlim([0.0, 1.05])
195 | plt.ylim([0.0, 1.05])
196 | plt.title(title)
197 | plt.savefig(out_image)
198 | plt.clf()
199 |
200 |
201 | def make_precision_recall_eval(scores,
202 | na_probs,
203 | num_true_pos,
204 | qid_to_has_ans,
205 | out_image=None,
206 | title=None):
207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
208 | true_pos = 0.0
209 | cur_p = 1.0
210 | cur_r = 0.0
211 | precisions = [1.0]
212 | recalls = [0.0]
213 | avg_prec = 0.0
214 | for i, qid in enumerate(qid_list):
215 | if qid_to_has_ans[qid]:
216 | true_pos += scores[qid]
217 | cur_p = true_pos / float(i + 1)
218 | cur_r = true_pos / float(num_true_pos)
219 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i +
220 | 1]]:
221 | # i.e., if we can put a threshold after this point
222 | avg_prec += cur_p * (cur_r - recalls[-1])
223 | precisions.append(cur_p)
224 | recalls.append(cur_r)
225 | if out_image:
226 | plot_pr_curve(precisions, recalls, out_image, title)
227 | return {'ap': 100.0 * avg_prec}
228 |
229 |
230 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
231 | qid_to_has_ans, out_image_dir):
232 | if out_image_dir and not os.path.exists(out_image_dir):
233 | os.makedirs(out_image_dir)
234 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
235 | if num_true_pos == 0:
236 | return
237 | pr_exact = make_precision_recall_eval(
238 | exact_raw,
239 | na_probs,
240 | num_true_pos,
241 | qid_to_has_ans,
242 | out_image=os.path.join(out_image_dir, 'pr_exact.png'),
243 | title='Precision-Recall curve for Exact Match score')
244 | pr_f1 = make_precision_recall_eval(
245 | f1_raw,
246 | na_probs,
247 | num_true_pos,
248 | qid_to_has_ans,
249 | out_image=os.path.join(out_image_dir, 'pr_f1.png'),
250 | title='Precision-Recall curve for F1 score')
251 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
252 | pr_oracle = make_precision_recall_eval(
253 | oracle_scores,
254 | na_probs,
255 | num_true_pos,
256 | qid_to_has_ans,
257 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
258 | title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)'
259 | )
260 | merge_eval(main_eval, pr_exact, 'pr_exact')
261 | merge_eval(main_eval, pr_f1, 'pr_f1')
262 | merge_eval(main_eval, pr_oracle, 'pr_oracle')
263 |
264 |
265 | def histogram_na_prob(na_probs, qid_list, image_dir, name):
266 | if not qid_list:
267 | return
268 | x = [na_probs[k] for k in qid_list]
269 | weights = np.ones_like(x) / float(len(x))
270 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
271 | plt.xlabel('Model probability of no-answer')
272 | plt.ylabel('Proportion of dataset')
273 | plt.title('Histogram of no-answer probability: %s' % name)
274 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
275 | plt.clf()
276 |
277 |
278 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
279 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
280 | cur_score = num_no_ans
281 | best_score = cur_score
282 | best_thresh = 0.0
283 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
284 | for i, qid in enumerate(qid_list):
285 | if qid not in scores: continue
286 | if qid_to_has_ans[qid]:
287 | diff = scores[qid]
288 | else:
289 | if preds[qid]:
290 | diff = -1
291 | else:
292 | diff = 0
293 | cur_score += diff
294 | if cur_score > best_score:
295 | best_score = cur_score
296 | best_thresh = na_probs[qid]
297 | return 100.0 * best_score / len(scores), best_thresh
298 |
299 |
300 | def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
301 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
302 | cur_score = num_no_ans
303 | best_score = cur_score
304 | best_thresh = 0.0
305 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
306 | for i, qid in enumerate(qid_list):
307 | if qid not in scores: continue
308 | if qid_to_has_ans[qid]:
309 | diff = scores[qid]
310 | else:
311 | if preds[qid]:
312 | diff = -1
313 | else:
314 | diff = 0
315 | cur_score += diff
316 | if cur_score > best_score:
317 | best_score = cur_score
318 | best_thresh = na_probs[qid]
319 |
320 | has_ans_score, has_ans_cnt = 0, 0
321 | for qid in qid_list:
322 | if not qid_to_has_ans[qid]: continue
323 | has_ans_cnt += 1
324 |
325 | if qid not in scores: continue
326 | has_ans_score += scores[qid]
327 |
328 | return 100.0 * best_score / len(
329 | scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
330 |
331 |
332 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs,
333 | qid_to_has_ans):
334 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs,
335 | qid_to_has_ans)
336 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs,
337 | qid_to_has_ans)
338 | main_eval['best_exact'] = best_exact
339 | main_eval['best_exact_thresh'] = exact_thresh
340 | main_eval['best_f1'] = best_f1
341 | main_eval['best_f1_thresh'] = f1_thresh
342 |
343 |
344 | def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs,
345 | qid_to_has_ans):
346 | best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(
347 | preds, exact_raw, na_probs, qid_to_has_ans)
348 | best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(
349 | preds, f1_raw, na_probs, qid_to_has_ans)
350 | main_eval['best_exact'] = best_exact
351 | main_eval['best_exact_thresh'] = exact_thresh
352 | main_eval['best_f1'] = best_f1
353 | main_eval['best_f1_thresh'] = f1_thresh
354 | main_eval['has_ans_exact'] = has_ans_exact
355 | main_eval['has_ans_f1'] = has_ans_f1
356 |
357 |
358 | def main(OPTS):
359 | with open(OPTS.data_file) as f:
360 | dataset_json = json.load(f)
361 | dataset = dataset_json['data']
362 | with open(OPTS.pred_file) as f:
363 | preds = json.load(f)
364 | if OPTS.na_prob_file:
365 | with open(OPTS.na_prob_file) as f:
366 | na_probs = json.load(f)
367 | else:
368 | na_probs = {k: 0.0 for k in preds}
369 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
370 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
371 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
372 | exact_raw, f1_raw = get_raw_scores(dataset, preds)
373 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
374 | OPTS.na_prob_thresh)
375 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
376 | OPTS.na_prob_thresh)
377 | out_eval = make_eval_dict(exact_thresh, f1_thresh)
378 | if has_ans_qids:
379 | has_ans_eval = make_eval_dict(
380 | exact_thresh, f1_thresh, qid_list=has_ans_qids)
381 | merge_eval(out_eval, has_ans_eval, 'HasAns')
382 | if no_ans_qids:
383 | no_ans_eval = make_eval_dict(
384 | exact_thresh, f1_thresh, qid_list=no_ans_qids)
385 | merge_eval(out_eval, no_ans_eval, 'NoAns')
386 | if OPTS.na_prob_file:
387 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs,
388 | qid_to_has_ans)
389 | if OPTS.na_prob_file and OPTS.out_image_dir:
390 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
391 | qid_to_has_ans, OPTS.out_image_dir)
392 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
393 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
394 | if OPTS.out_file:
395 | with open(OPTS.out_file, 'w') as f:
396 | json.dump(out_eval, f)
397 | else:
398 | print(json.dumps(out_eval, indent=2))
399 | return out_eval
400 |
401 |
402 | if __name__ == '__main__':
403 | OPTS = parse_args()
404 | if OPTS.out_image_dir:
405 | import matplotlib
406 | matplotlib.use('Agg')
407 | import matplotlib.pyplot as plt
408 | main(OPTS)
409 |
--------------------------------------------------------------------------------
/bert_qa.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet)."""
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import argparse
21 | import logging
22 | import os
23 | import random
24 | import glob
25 |
26 | import numpy as np
27 | import torch
28 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
29 | TensorDataset)
30 | from torch.utils.data.distributed import DistributedSampler
31 | from tqdm import tqdm, trange
32 | from tensorboardX import SummaryWriter
33 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
34 | BertForQuestionAnswering, BertTokenizer,
35 | XLMConfig, XLMForQuestionAnswering,
36 | XLMTokenizer, XLNetConfig,
37 | XLNetForQuestionAnswering,
38 | XLNetTokenizer)
39 |
40 | from pytorch_transformers import AdamW, WarmupLinearSchedule
41 |
42 | from utils_squad import (read_squad_examples, convert_examples_to_features,
43 | RawResult, write_predictions, write_predictions_extended)
44 |
45 | # The follwing import is the official SQuAD evaluation script (2.0).
46 | # You can remove it from the dependencies if you are using this script outside of the library
47 | # We've added it here for automated tests (see examples/test_examples.py file)
48 | from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
49 |
50 | logger = logging.getLogger(__name__)
51 |
52 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
53 | for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
54 |
55 | MODEL_CLASSES = {
56 | 'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
57 | 'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
58 | 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
59 | }
60 |
61 | def set_seed(args):
62 | random.seed(args.seed)
63 | np.random.seed(args.seed)
64 | torch.manual_seed(args.seed)
65 | if args.n_gpu > 0:
66 | torch.cuda.manual_seed_all(args.seed)
67 |
68 | def to_list(tensor):
69 | return tensor.detach().cpu().tolist()
70 |
71 | def train(args, train_dataset, model, tokenizer):
72 | """ Train the model """
73 | if args.local_rank in [-1, 0]:
74 | tb_writer = SummaryWriter()
75 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
76 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
77 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
78 |
79 | if args.max_steps > 0:
80 | t_total = args.max_steps
81 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
82 | else:
83 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
84 |
85 | # Prepare optimizer and schedule (linear warmup and decay)
86 | no_decay = ['bias', 'LayerNorm.weight']
87 | optimizer_grouped_parameters = [
88 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
89 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
90 | ]
91 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
92 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
93 | if args.fp16:
94 | try:
95 | from apex import amp
96 | except ImportError:
97 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
98 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
99 |
100 | # multi-gpu training (should be after apex fp16 initialization)
101 | if args.n_gpu > 1:
102 | model = torch.nn.DataParallel(model)
103 |
104 | # Distributed training (should be after apex fp16 initialization)
105 | if args.local_rank != -1:
106 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
107 | output_device=args.local_rank,
108 | find_unused_parameters=True)
109 |
110 | # Train!
111 | logger.info("***** Running training *****")
112 | logger.info(" Num examples = %d", len(train_dataset))
113 | logger.info(" Num Epochs = %d", args.num_train_epochs)
114 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
115 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
116 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
117 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
118 | logger.info(" Total optimization steps = %d", t_total)
119 |
120 | global_step = 0
121 | tr_loss, logging_loss = 0.0, 0.0
122 | model.zero_grad()
123 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
124 | set_seed(args) # Added here for reproductibility (even between python 2 and 3)
125 | for _ in train_iterator:
126 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
127 | for step, batch in enumerate(epoch_iterator):
128 | model.train()
129 | batch = tuple(t.to(args.device) for t in batch)
130 | inputs = {'input_ids': batch[0],
131 | 'attention_mask': batch[1],
132 | 'token_type_ids': None if args.model_type == 'xlm' else batch[2],
133 | 'start_positions': batch[3],
134 | 'end_positions': batch[4]}
135 | if args.model_type in ['xlnet', 'xlm']:
136 | inputs.update({'cls_index': batch[5],
137 | 'p_mask': batch[6]})
138 | outputs = model(**inputs)
139 | loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
140 |
141 | if args.n_gpu > 1:
142 | loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
143 | if args.gradient_accumulation_steps > 1:
144 | loss = loss / args.gradient_accumulation_steps
145 |
146 | if args.fp16:
147 | with amp.scale_loss(loss, optimizer) as scaled_loss:
148 | scaled_loss.backward()
149 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
150 | else:
151 | loss.backward()
152 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
153 |
154 | tr_loss += loss.item()
155 | if (step + 1) % args.gradient_accumulation_steps == 0:
156 | optimizer.step()
157 | scheduler.step() # Update learning rate schedule
158 | model.zero_grad()
159 | global_step += 1
160 |
161 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
162 | # Log metrics
163 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
164 | results = evaluate(args, model, tokenizer)
165 | for key, value in results.items():
166 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
167 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
168 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
169 | logging_loss = tr_loss
170 |
171 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
172 | # Save model checkpoint
173 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
174 | if not os.path.exists(output_dir):
175 | os.makedirs(output_dir)
176 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
177 | model_to_save.save_pretrained(output_dir)
178 | torch.save(args, os.path.join(output_dir, 'training_args.bin'))
179 | logger.info("Saving model checkpoint to %s", output_dir)
180 |
181 | if args.max_steps > 0 and global_step > args.max_steps:
182 | epoch_iterator.close()
183 | break
184 | if args.max_steps > 0 and global_step > args.max_steps:
185 | train_iterator.close()
186 | break
187 |
188 | if args.local_rank in [-1, 0]:
189 | tb_writer.close()
190 |
191 | return global_step, tr_loss / global_step
192 |
193 |
194 | def evaluate(args, model, tokenizer, prefix=""):
195 | dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
196 |
197 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
198 | os.makedirs(args.output_dir)
199 |
200 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
201 | # Note that DistributedSampler samples randomly
202 | eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
203 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
204 |
205 | # Eval!
206 | logger.info("***** Running evaluation {} *****".format(prefix))
207 | logger.info(" Num examples = %d", len(dataset))
208 | logger.info(" Batch size = %d", args.eval_batch_size)
209 | all_results = []
210 | for batch in tqdm(eval_dataloader, desc="Evaluating"):
211 | model.eval()
212 | batch = tuple(t.to(args.device) for t in batch)
213 | with torch.no_grad():
214 | inputs = {'input_ids': batch[0],
215 | 'attention_mask': batch[1],
216 | 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
217 | }
218 | example_indices = batch[3]
219 | if args.model_type in ['xlnet', 'xlm']:
220 | inputs.update({'cls_index': batch[4],
221 | 'p_mask': batch[5]})
222 | outputs = model(**inputs)
223 |
224 | for i, example_index in enumerate(example_indices):
225 | eval_feature = features[example_index.item()]
226 | unique_id = int(eval_feature.unique_id)
227 | if args.model_type in ['xlnet', 'xlm']:
228 | # XLNet uses a more complex post-processing procedure
229 | result = RawResultExtended(unique_id = unique_id,
230 | start_top_log_probs = to_list(outputs[0][i]),
231 | start_top_index = to_list(outputs[1][i]),
232 | end_top_log_probs = to_list(outputs[2][i]),
233 | end_top_index = to_list(outputs[3][i]),
234 | cls_logits = to_list(outputs[4][i]))
235 | else:
236 | result = RawResult(unique_id = unique_id,
237 | start_logits = to_list(outputs[0][i]),
238 | end_logits = to_list(outputs[1][i]))
239 | all_results.append(result)
240 |
241 | # Compute predictions
242 | output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
243 | output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
244 | if args.version_2_with_negative:
245 | output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
246 | else:
247 | output_null_log_odds_file = None
248 |
249 | if args.model_type in ['xlnet', 'xlm']:
250 | # XLNet uses a more complex post-processing procedure
251 | write_predictions_extended(examples, features, all_results, args.n_best_size,
252 | args.max_answer_length, output_prediction_file,
253 | output_nbest_file, output_null_log_odds_file, args.predict_file,
254 | model.config.start_n_top, model.config.end_n_top,
255 | args.version_2_with_negative, tokenizer, args.verbose_logging)
256 | else:
257 | write_predictions(examples, features, all_results, args.n_best_size,
258 | args.max_answer_length, args.do_lower_case, output_prediction_file,
259 | output_nbest_file, output_null_log_odds_file, args.verbose_logging,
260 | args.version_2_with_negative, args.null_score_diff_threshold)
261 |
262 | # Evaluate with the official SQuAD script
263 | evaluate_options = EVAL_OPTS(data_file=args.predict_file,
264 | pred_file=output_prediction_file,
265 | na_prob_file=output_null_log_odds_file)
266 | results = evaluate_on_squad(evaluate_options)
267 | return results
268 |
269 |
270 | def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
271 | if args.local_rank not in [-1, 0] and not evaluate:
272 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
273 |
274 | # Load data features from cache or dataset file
275 | input_file = args.predict_file if evaluate else args.train_file
276 | cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
277 | 'dev' if evaluate else 'train',
278 | list(filter(None, args.model_name_or_path.split('/'))).pop(),
279 | str(args.max_seq_length)))
280 | if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
281 | logger.info("Loading features from cached file %s", cached_features_file)
282 | features = torch.load(cached_features_file)
283 | else:
284 | logger.info("Creating features from dataset file at %s", input_file)
285 | examples = read_squad_examples(input_file=input_file,
286 | is_training=not evaluate,
287 | version_2_with_negative=args.version_2_with_negative)
288 | features = convert_examples_to_features(examples=examples,
289 | tokenizer=tokenizer,
290 | max_seq_length=args.max_seq_length,
291 | doc_stride=args.doc_stride,
292 | max_query_length=args.max_query_length,
293 | is_training=not evaluate)
294 | if args.local_rank in [-1, 0]:
295 | logger.info("Saving features into cached file %s", cached_features_file)
296 | torch.save(features, cached_features_file)
297 |
298 | if args.local_rank == 0 and not evaluate:
299 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
300 |
301 | # Convert to Tensors and build dataset
302 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
303 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
304 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
305 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
306 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
307 | if evaluate:
308 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
309 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
310 | all_example_index, all_cls_index, all_p_mask)
311 | else:
312 | all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
313 | all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
314 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
315 | all_start_positions, all_end_positions,
316 | all_cls_index, all_p_mask)
317 |
318 | if output_examples:
319 | return dataset, examples, features
320 | return dataset
321 |
322 |
323 | def main():
324 | parser = argparse.ArgumentParser()
325 |
326 | ## Required parameters
327 | parser.add_argument("--train_file", default=None, type=str, required=True,
328 | help="SQuAD json for training. E.g., train-v1.1.json")
329 | parser.add_argument("--predict_file", default=None, type=str, required=True,
330 | help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
331 | parser.add_argument("--model_type", default=None, type=str, required=True,
332 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
333 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
334 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
335 | parser.add_argument("--output_dir", default=None, type=str, required=True,
336 | help="The output directory where the model checkpoints and predictions will be written.")
337 |
338 | ## Other parameters
339 | parser.add_argument("--config_name", default="", type=str,
340 | help="Pretrained config name or path if not the same as model_name")
341 | parser.add_argument("--tokenizer_name", default="", type=str,
342 | help="Pretrained tokenizer name or path if not the same as model_name")
343 | parser.add_argument("--cache_dir", default="", type=str,
344 | help="Where do you want to store the pre-trained models downloaded from s3")
345 |
346 | parser.add_argument('--version_2_with_negative', action='store_true',
347 | help='If true, the SQuAD examples contain some that do not have an answer.')
348 | parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
349 | help="If null_score - best_non_null is greater than the threshold predict null.")
350 |
351 | parser.add_argument("--max_seq_length", default=384, type=int,
352 | help="The maximum total input sequence length after WordPiece tokenization. Sequences "
353 | "longer than this will be truncated, and sequences shorter than this will be padded.")
354 | parser.add_argument("--doc_stride", default=128, type=int,
355 | help="When splitting up a long document into chunks, how much stride to take between chunks.")
356 | parser.add_argument("--max_query_length", default=64, type=int,
357 | help="The maximum number of tokens for the question. Questions longer than this will "
358 | "be truncated to this length.")
359 | parser.add_argument("--do_train", action='store_true',
360 | help="Whether to run training.")
361 | parser.add_argument("--do_eval", action='store_true',
362 | help="Whether to run eval on the dev set.")
363 | parser.add_argument("--evaluate_during_training", action='store_true',
364 | help="Rul evaluation during training at each logging step.")
365 | parser.add_argument("--do_lower_case", action='store_true',
366 | help="Set this flag if you are using an uncased model.")
367 |
368 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
369 | help="Batch size per GPU/CPU for training.")
370 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
371 | help="Batch size per GPU/CPU for evaluation.")
372 | parser.add_argument("--learning_rate", default=5e-5, type=float,
373 | help="The initial learning rate for Adam.")
374 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
375 | help="Number of updates steps to accumulate before performing a backward/update pass.")
376 | parser.add_argument("--weight_decay", default=0.0, type=float,
377 | help="Weight deay if we apply some.")
378 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
379 | help="Epsilon for Adam optimizer.")
380 | parser.add_argument("--max_grad_norm", default=1.0, type=float,
381 | help="Max gradient norm.")
382 | parser.add_argument("--num_train_epochs", default=3.0, type=float,
383 | help="Total number of training epochs to perform.")
384 | parser.add_argument("--max_steps", default=-1, type=int,
385 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
386 | parser.add_argument("--warmup_steps", default=0, type=int,
387 | help="Linear warmup over warmup_steps.")
388 | parser.add_argument("--n_best_size", default=20, type=int,
389 | help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
390 | parser.add_argument("--max_answer_length", default=30, type=int,
391 | help="The maximum length of an answer that can be generated. This is needed because the start "
392 | "and end predictions are not conditioned on one another.")
393 | parser.add_argument("--verbose_logging", action='store_true',
394 | help="If true, all of the warnings related to data processing will be printed. "
395 | "A number of warnings are expected for a normal SQuAD evaluation.")
396 |
397 | parser.add_argument('--logging_steps', type=int, default=50,
398 | help="Log every X updates steps.")
399 | parser.add_argument('--save_steps', type=int, default=50,
400 | help="Save checkpoint every X updates steps.")
401 | parser.add_argument("--eval_all_checkpoints", action='store_true',
402 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
403 | parser.add_argument("--no_cuda", action='store_true',
404 | help="Whether not to use CUDA when available")
405 | parser.add_argument('--overwrite_output_dir', action='store_true',
406 | help="Overwrite the content of the output directory")
407 | parser.add_argument('--overwrite_cache', action='store_true',
408 | help="Overwrite the cached training and evaluation sets")
409 | parser.add_argument('--seed', type=int, default=42,
410 | help="random seed for initialization")
411 |
412 | parser.add_argument("--local_rank", type=int, default=-1,
413 | help="local_rank for distributed training on gpus")
414 | parser.add_argument('--fp16', action='store_true',
415 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
416 | parser.add_argument('--fp16_opt_level', type=str, default='O1',
417 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
418 | "See details at https://nvidia.github.io/apex/amp.html")
419 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
420 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
421 | args = parser.parse_args()
422 |
423 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
424 | raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
425 |
426 | # Setup distant debugging if needed
427 | if args.server_ip and args.server_port:
428 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
429 | import ptvsd
430 | print("Waiting for debugger attach")
431 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
432 | ptvsd.wait_for_attach()
433 |
434 | # Setup CUDA, GPU & distributed training
435 | if args.local_rank == -1 or args.no_cuda:
436 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
437 | args.n_gpu = torch.cuda.device_count()
438 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
439 | torch.cuda.set_device(args.local_rank)
440 | device = torch.device("cuda", args.local_rank)
441 | torch.distributed.init_process_group(backend='nccl')
442 | args.n_gpu = 1
443 | args.device = device
444 |
445 | # Setup logging
446 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
447 | datefmt = '%m/%d/%Y %H:%M:%S',
448 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
449 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
450 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
451 |
452 | # Set seed
453 | set_seed(args)
454 |
455 | # Load pretrained model and tokenizer
456 | if args.local_rank not in [-1, 0]:
457 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
458 |
459 | args.model_type = args.model_type.lower()
460 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
461 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
462 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
463 |
464 | model = model_class.from_pretrained("bert-base-chinese", config=config)
465 | if args.local_rank == 0:
466 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
467 |
468 | model.to(args.device)
469 |
470 | logger.info("Training/evaluation parameters %s", args)
471 |
472 | # Training
473 | if args.do_train:
474 | train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
475 | global_step, tr_loss = train(args, train_dataset, model, tokenizer)
476 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
477 |
478 |
479 | # Save the trained model and the tokenizer
480 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
481 | # Create output directory if needed
482 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
483 | os.makedirs(args.output_dir)
484 |
485 | logger.info("Saving model checkpoint to %s", args.output_dir)
486 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
487 | model_to_save.save_pretrained(args.output_dir)
488 | tokenizer.save_pretrained(args.output_dir)
489 |
490 | # Good practice: save your training arguments together with the trained model
491 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
492 |
493 | # Load a trained model and vocabulary that you have fine-tuned
494 | model = model_class.from_pretrained(args.output_dir)
495 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
496 | model.to(args.device)
497 |
498 |
499 | # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
500 | results = {}
501 | if args.do_eval and args.local_rank in [-1, 0]:
502 | checkpoints = [args.output_dir]
503 | if args.eval_all_checkpoints:
504 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
505 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
506 |
507 | logger.info("Evaluate the following checkpoints: %s", checkpoints)
508 |
509 | for checkpoint in checkpoints:
510 | # Reload the model
511 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
512 | model = model_class.from_pretrained(checkpoint)
513 | model.to(args.device)
514 |
515 | # Evaluate
516 | result = evaluate(args, model, tokenizer, prefix=global_step)
517 |
518 | result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
519 | results.update(result)
520 |
521 | logger.info("Results: {}".format(results))
522 |
523 | return results
524 |
525 |
526 | if __name__ == "__main__":
527 | main()
--------------------------------------------------------------------------------
/utils_squad.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Load SQuAD dataset. """
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import json
21 | import logging
22 | import math
23 | import collections
24 | from io import open
25 |
26 | from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
27 |
28 | # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
29 | from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | class SquadExample(object):
35 | """
36 | A single training/test example for the Squad dataset.
37 | For examples without an answer, the start and end position are -1.
38 | """
39 |
40 | def __init__(self,
41 | qas_id,
42 | question_text,
43 | doc_tokens,
44 | orig_answer_text=None,
45 | start_position=None,
46 | end_position=None,
47 | is_impossible=None):
48 | self.qas_id = qas_id
49 | self.question_text = question_text
50 | self.doc_tokens = doc_tokens
51 | self.orig_answer_text = orig_answer_text
52 | self.start_position = start_position
53 | self.end_position = end_position
54 | self.is_impossible = is_impossible
55 |
56 | def __str__(self):
57 | return self.__repr__()
58 |
59 | def __repr__(self):
60 | s = ""
61 | s += "qas_id: %s" % (self.qas_id)
62 | s += ", question_text: %s" % (self.question_text)
63 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
64 | if self.start_position:
65 | s += ", start_position: %d" % (self.start_position)
66 | if self.end_position:
67 | s += ", end_position: %d" % (self.end_position)
68 | if self.is_impossible:
69 | s += ", is_impossible: %r" % (self.is_impossible)
70 | return s
71 |
72 |
73 | class InputFeatures(object):
74 | """A single set of features of data."""
75 |
76 | def __init__(self,
77 | unique_id,
78 | example_index,
79 | doc_span_index,
80 | tokens,
81 | token_to_orig_map,
82 | token_is_max_context,
83 | input_ids,
84 | input_mask,
85 | segment_ids,
86 | cls_index,
87 | p_mask,
88 | paragraph_len,
89 | start_position=None,
90 | end_position=None,
91 | is_impossible=None):
92 | self.unique_id = unique_id
93 | self.example_index = example_index
94 | self.doc_span_index = doc_span_index
95 | self.tokens = tokens
96 | self.token_to_orig_map = token_to_orig_map
97 | self.token_is_max_context = token_is_max_context
98 | self.input_ids = input_ids
99 | self.input_mask = input_mask
100 | self.segment_ids = segment_ids
101 | self.cls_index = cls_index
102 | self.p_mask = p_mask
103 | self.paragraph_len = paragraph_len
104 | self.start_position = start_position
105 | self.end_position = end_position
106 | self.is_impossible = is_impossible
107 |
108 |
109 | def read_squad_examples(input_file, is_training, version_2_with_negative):
110 | """Read a SQuAD json file into a list of SquadExample."""
111 | with open(input_file, "r", encoding='utf-8') as reader:
112 | input_data = json.load(reader)["data"]
113 |
114 | def is_whitespace(c):
115 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
116 | return True
117 | return False
118 |
119 | def is_english_or_number(c):
120 | return (ord(c) > 64 and ord(c) < 91) or (ord(c) < 123 and ord(c) > 96)
121 |
122 | examples = []
123 | for entry in input_data:
124 | for paragraph in entry["paragraphs"]:
125 | paragraph_text = paragraph["context"]
126 | doc_tokens = []
127 | char_to_word_offset = []
128 | prev_is_whitespace = True
129 | for c in paragraph_text:
130 | if is_whitespace(c):
131 | continue
132 | doc_tokens.append(c)
133 | char_to_word_offset.append(len(doc_tokens) - 1)
134 |
135 | for qa in paragraph["qas"]:
136 | qas_id = qa["id"]
137 | question_text = qa["question"]
138 | start_position = None
139 | end_position = None
140 | orig_answer_text = None
141 | is_impossible = False
142 | if is_training:
143 | if (len(qa["answers"]) != 1) and (not is_impossible):
144 | raise ValueError(
145 | "For training, each question should have exactly 1 answer."
146 | )
147 | answer = qa["answers"][0]
148 | orig_answer_text = answer["text"]
149 | answer_offset = answer["answer_start"]
150 | answer_length = len(orig_answer_text)
151 | if answer_offset > len(char_to_word_offset) - 1:
152 | logger.warning("样本错误: '%s' offfset vs. length'%s'",
153 | answer_offset, len(char_to_word_offset))
154 | continue
155 | start_position = char_to_word_offset[answer_offset]
156 | end_position = answer_offset + answer_length - 1
157 | if end_position > len(char_to_word_offset) - 1:
158 | logger.warning("样本错误: '%s' vs. '%s'", end_position, len(char_to_word_offset))
159 | continue
160 | end_position = char_to_word_offset[answer_offset +
161 | answer_length - 1]
162 | # Only add answers where the text can be exactly recovered from the
163 | # document. If this CAN'T happen it's likely due to weird Unicode
164 | # stuff so we will just skip the example.
165 | #
166 | # Note that this means for training mode, every example is NOT
167 | # guaranteed to be preserved.
168 | actual_text = "".join(
169 | doc_tokens[start_position:(end_position + 1)])
170 | cleaned_answer_text = "".join(
171 | whitespace_tokenize(orig_answer_text))
172 | if actual_text.find(cleaned_answer_text) == -1:
173 | logger.warning("样本错误: '%s' vs. '%s'", actual_text,
174 | cleaned_answer_text)
175 | continue
176 |
177 | example = SquadExample(
178 | qas_id=qas_id,
179 | question_text=question_text,
180 | doc_tokens=doc_tokens,
181 | orig_answer_text=orig_answer_text,
182 | start_position=start_position,
183 | end_position=end_position,
184 | is_impossible=is_impossible)
185 | examples.append(example)
186 | return examples
187 |
188 |
189 | def convert_examples_to_features(examples,
190 | tokenizer,
191 | max_seq_length,
192 | doc_stride,
193 | max_query_length,
194 | is_training,
195 | cls_token_at_end=False,
196 | cls_token='[CLS]',
197 | sep_token='[SEP]',
198 | pad_token=0,
199 | sequence_a_segment_id=0,
200 | sequence_b_segment_id=1,
201 | cls_token_segment_id=0,
202 | pad_token_segment_id=0,
203 | mask_padding_with_zero=True):
204 | """Loads a data file into a list of `InputBatch`s."""
205 |
206 | unique_id = 1000000000
207 | # cnt_pos, cnt_neg = 0, 0
208 | # max_N, max_M = 1024, 1024
209 | # f = np.zeros((max_N, max_M), dtype=np.float32)
210 |
211 | features = []
212 | for (example_index, example) in enumerate(examples):
213 |
214 | # if example_index % 100 == 0:
215 | # logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)
216 |
217 | query_tokens = tokenizer.tokenize(example.question_text)
218 |
219 | if len(query_tokens) > max_query_length:
220 | query_tokens = query_tokens[0:max_query_length]
221 |
222 | tok_to_orig_index = []
223 | orig_to_tok_index = []
224 | all_doc_tokens = []
225 | for (i, token) in enumerate(example.doc_tokens):
226 | orig_to_tok_index.append(len(all_doc_tokens))
227 | sub_tokens = tokenizer.tokenize(token)
228 | for sub_token in sub_tokens:
229 | tok_to_orig_index.append(i)
230 | all_doc_tokens.append(sub_token)
231 |
232 | tok_start_position = None
233 | tok_end_position = None
234 | if is_training and example.is_impossible:
235 | tok_start_position = -1
236 | tok_end_position = -1
237 | if is_training and not example.is_impossible:
238 | tok_start_position = orig_to_tok_index[example.start_position]
239 | if example.end_position < len(example.doc_tokens) - 1:
240 | tok_end_position = orig_to_tok_index[example.end_position +
241 | 1] - 1
242 | else:
243 | tok_end_position = len(all_doc_tokens) - 1
244 | (tok_start_position, tok_end_position) = _improve_answer_span(
245 | all_doc_tokens, tok_start_position, tok_end_position,
246 | tokenizer, example.orig_answer_text)
247 |
248 | # The -3 accounts for [CLS], [SEP] and [SEP]
249 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
250 |
251 | # We can have documents that are longer than the maximum sequence length.
252 | # To deal with this we do a sliding window approach, where we take chunks
253 | # of the up to our max length with a stride of `doc_stride`.
254 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
255 | "DocSpan", ["start", "length"])
256 | doc_spans = []
257 | start_offset = 0
258 | while start_offset < len(all_doc_tokens):
259 | length = len(all_doc_tokens) - start_offset
260 | if length > max_tokens_for_doc:
261 | length = max_tokens_for_doc
262 | doc_spans.append(_DocSpan(start=start_offset, length=length))
263 | if start_offset + length == len(all_doc_tokens):
264 | break
265 | start_offset += min(length, doc_stride)
266 |
267 | for (doc_span_index, doc_span) in enumerate(doc_spans):
268 | tokens = []
269 | token_to_orig_map = {}
270 | token_is_max_context = {}
271 | segment_ids = []
272 |
273 | # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
274 | # Original TF implem also keep the classification token (set to 0) (not sure why...)
275 | p_mask = []
276 |
277 | # CLS token at the beginning
278 | if not cls_token_at_end:
279 | tokens.append(cls_token)
280 | segment_ids.append(cls_token_segment_id)
281 | p_mask.append(0)
282 | cls_index = 0
283 |
284 | # Query
285 | for token in query_tokens:
286 | tokens.append(token)
287 | segment_ids.append(sequence_a_segment_id)
288 | p_mask.append(1)
289 |
290 | # SEP token
291 | tokens.append(sep_token)
292 | segment_ids.append(sequence_a_segment_id)
293 | p_mask.append(1)
294 |
295 | # Paragraph
296 | for i in range(doc_span.length):
297 | split_token_index = doc_span.start + i
298 | token_to_orig_map[len(
299 | tokens)] = tok_to_orig_index[split_token_index]
300 |
301 | is_max_context = _check_is_max_context(
302 | doc_spans, doc_span_index, split_token_index)
303 | token_is_max_context[len(tokens)] = is_max_context
304 | tokens.append(all_doc_tokens[split_token_index])
305 | segment_ids.append(sequence_b_segment_id)
306 | p_mask.append(0)
307 | paragraph_len = doc_span.length
308 |
309 | # SEP token
310 | tokens.append(sep_token)
311 | segment_ids.append(sequence_b_segment_id)
312 | p_mask.append(1)
313 |
314 | # CLS token at the end
315 | if cls_token_at_end:
316 | tokens.append(cls_token)
317 | segment_ids.append(cls_token_segment_id)
318 | p_mask.append(0)
319 | cls_index = len(tokens) - 1 # Index of classification token
320 |
321 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
322 |
323 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
324 | # tokens are attended to.
325 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
326 |
327 | # Zero-pad up to the sequence length.
328 | while len(input_ids) < max_seq_length:
329 | input_ids.append(pad_token)
330 | input_mask.append(0 if mask_padding_with_zero else 1)
331 | segment_ids.append(pad_token_segment_id)
332 | p_mask.append(1)
333 |
334 | assert len(input_ids) == max_seq_length
335 | assert len(input_mask) == max_seq_length
336 | assert len(segment_ids) == max_seq_length
337 |
338 | span_is_impossible = example.is_impossible
339 | start_position = None
340 | end_position = None
341 | if is_training and not span_is_impossible:
342 | # For training, if our document chunk does not contain an annotation
343 | # we throw it out, since there is nothing to predict.
344 | doc_start = doc_span.start
345 | doc_end = doc_span.start + doc_span.length - 1
346 | out_of_span = False
347 | if not (tok_start_position >= doc_start
348 | and tok_end_position <= doc_end):
349 | out_of_span = True
350 | if out_of_span:
351 | start_position = 0
352 | end_position = 0
353 | span_is_impossible = True
354 | else:
355 | doc_offset = len(query_tokens) + 2
356 | start_position = tok_start_position - doc_start + doc_offset
357 | end_position = tok_end_position - doc_start + doc_offset
358 |
359 | if is_training and span_is_impossible:
360 | start_position = cls_index
361 | end_position = cls_index
362 |
363 | if example_index < 20:
364 | logger.info("*** Example ***")
365 | logger.info("unique_id: %s" % (unique_id))
366 | logger.info("example_index: %s" % (example_index))
367 | logger.info("doc_span_index: %s" % (doc_span_index))
368 | logger.info("tokens: %s" % " ".join(tokens))
369 | logger.info("token_to_orig_map: %s" % " ".join(
370 | ["%d:%d" % (x, y)
371 | for (x, y) in token_to_orig_map.items()]))
372 | logger.info("token_is_max_context: %s" % " ".join([
373 | "%d:%s" % (x, y)
374 | for (x, y) in token_is_max_context.items()
375 | ]))
376 | logger.info(
377 | "input_ids: %s" % " ".join([str(x) for x in input_ids]))
378 | logger.info(
379 | "input_mask: %s" % " ".join([str(x) for x in input_mask]))
380 | logger.info(
381 | "segment_ids: %s" % " ".join([str(x)
382 | for x in segment_ids]))
383 | if is_training and span_is_impossible:
384 | logger.info("impossible example")
385 | if is_training and not span_is_impossible:
386 | answer_text = " ".join(
387 | tokens[start_position:(end_position + 1)])
388 | logger.info("start_position: %d" % (start_position))
389 | logger.info("end_position: %d" % (end_position))
390 | logger.info("answer: %s" % (answer_text))
391 |
392 | features.append(
393 | InputFeatures(
394 | unique_id=unique_id,
395 | example_index=example_index,
396 | doc_span_index=doc_span_index,
397 | tokens=tokens,
398 | token_to_orig_map=token_to_orig_map,
399 | token_is_max_context=token_is_max_context,
400 | input_ids=input_ids,
401 | input_mask=input_mask,
402 | segment_ids=segment_ids,
403 | cls_index=cls_index,
404 | p_mask=p_mask,
405 | paragraph_len=paragraph_len,
406 | start_position=start_position,
407 | end_position=end_position,
408 | is_impossible=span_is_impossible))
409 | unique_id += 1
410 |
411 | return features
412 |
413 |
414 | # convert_examples_to_features(examp)
415 |
416 |
417 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
418 | orig_answer_text):
419 | """Returns tokenized answer spans that better match the annotated answer."""
420 |
421 | # The SQuAD annotations are character based. We first project them to
422 | # whitespace-tokenized words. But then after WordPiece tokenization, we can
423 | # often find a "better match". For example:
424 | #
425 | # Question: What year was John Smith born?
426 | # Context: The leader was John Smith (1895-1943).
427 | # Answer: 1895
428 | #
429 | # The original whitespace-tokenized answer will be "(1895-1943).". However
430 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
431 | # the exact answer, 1895.
432 | #
433 | # However, this is not always possible. Consider the following:
434 | #
435 | # Question: What country is the top exporter of electornics?
436 | # Context: The Japanese electronics industry is the lagest in the world.
437 | # Answer: Japan
438 | #
439 | # In this case, the annotator chose "Japan" as a character sub-span of
440 | # the word "Japanese". Since our WordPiece tokenizer does not split
441 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
442 | # in SQuAD, but does happen.
443 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
444 |
445 | for new_start in range(input_start, input_end + 1):
446 | for new_end in range(input_end, new_start - 1, -1):
447 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
448 | if text_span == tok_answer_text:
449 | return (new_start, new_end)
450 |
451 | return (input_start, input_end)
452 |
453 |
454 | def _check_is_max_context(doc_spans, cur_span_index, position):
455 | """Check if this is the 'max context' doc span for the token."""
456 |
457 | # Because of the sliding window approach taken to scoring documents, a single
458 | # token can appear in multiple documents. E.g.
459 | # Doc: the man went to the store and bought a gallon of milk
460 | # Span A: the man went to the
461 | # Span B: to the store and bought
462 | # Span C: and bought a gallon of
463 | # ...
464 | #
465 | # Now the word 'bought' will have two scores from spans B and C. We only
466 | # want to consider the score with "maximum context", which we define as
467 | # the *minimum* of its left and right context (the *sum* of left and
468 | # right context will always be the same, of course).
469 | #
470 | # In the example the maximum context for 'bought' would be span C since
471 | # it has 1 left context and 3 right context, while span B has 4 left context
472 | # and 0 right context.
473 | best_score = None
474 | best_span_index = None
475 | for (span_index, doc_span) in enumerate(doc_spans):
476 | end = doc_span.start + doc_span.length - 1
477 | if position < doc_span.start:
478 | continue
479 | if position > end:
480 | continue
481 | num_left_context = position - doc_span.start
482 | num_right_context = end - position
483 | score = min(num_left_context,
484 | num_right_context) + 0.01 * doc_span.length
485 | if best_score is None or score > best_score:
486 | best_score = score
487 | best_span_index = span_index
488 |
489 | return cur_span_index == best_span_index
490 |
491 |
492 | RawResult = collections.namedtuple("RawResult",
493 | ["unique_id", "start_logits", "end_logits"])
494 |
495 |
496 | def write_predictions(all_examples, all_features, all_results, n_best_size,
497 | max_answer_length, do_lower_case, output_prediction_file,
498 | output_nbest_file, output_null_log_odds_file,
499 | verbose_logging, version_2_with_negative,
500 | null_score_diff_threshold):
501 | """Write final predictions to the json file and log-odds of null if needed."""
502 | logger.info("Writing predictions to: %s" % (output_prediction_file))
503 | logger.info("Writing nbest to: %s" % (output_nbest_file))
504 |
505 | example_index_to_features = collections.defaultdict(list)
506 | for feature in all_features:
507 | example_index_to_features[feature.example_index].append(feature)
508 |
509 | unique_id_to_result = {}
510 | for result in all_results:
511 | unique_id_to_result[result.unique_id] = result
512 |
513 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
514 | "PrelimPrediction", [
515 | "feature_index", "start_index", "end_index", "start_logit",
516 | "end_logit"
517 | ])
518 |
519 | all_predictions = collections.OrderedDict()
520 | all_nbest_json = collections.OrderedDict()
521 | scores_diff_json = collections.OrderedDict()
522 |
523 | for (example_index, example) in enumerate(all_examples):
524 | features = example_index_to_features[example_index]
525 |
526 | prelim_predictions = []
527 | # keep track of the minimum score of null start+end of position 0
528 | score_null = 1000000 # large and positive
529 | min_null_feature_index = 0 # the paragraph slice with min null score
530 | null_start_logit = 0 # the start logit at the slice with min null score
531 | null_end_logit = 0 # the end logit at the slice with min null score
532 | for (feature_index, feature) in enumerate(features):
533 | result = unique_id_to_result[feature.unique_id]
534 | start_indexes = _get_best_indexes(result.start_logits, n_best_size)
535 | end_indexes = _get_best_indexes(result.end_logits, n_best_size)
536 | # if we could have irrelevant answers, get the min score of irrelevant
537 | if version_2_with_negative:
538 | feature_null_score = result.start_logits[
539 | 0] + result.end_logits[0]
540 | if feature_null_score < score_null:
541 | score_null = feature_null_score
542 | min_null_feature_index = feature_index
543 | null_start_logit = result.start_logits[0]
544 | null_end_logit = result.end_logits[0]
545 | for start_index in start_indexes:
546 | for end_index in end_indexes:
547 | # We could hypothetically create invalid predictions, e.g., predict
548 | # that the start of the span is in the question. We throw out all
549 | # invalid predictions.
550 | if start_index >= len(feature.tokens):
551 | continue
552 | if end_index >= len(feature.tokens):
553 | continue
554 | if start_index not in feature.token_to_orig_map:
555 | continue
556 | if end_index not in feature.token_to_orig_map:
557 | continue
558 | if not feature.token_is_max_context.get(
559 | start_index, False):
560 | continue
561 | if end_index < start_index:
562 | continue
563 | length = end_index - start_index + 1
564 | if length > max_answer_length:
565 | continue
566 | prelim_predictions.append(
567 | _PrelimPrediction(
568 | feature_index=feature_index,
569 | start_index=start_index,
570 | end_index=end_index,
571 | start_logit=result.start_logits[start_index],
572 | end_logit=result.end_logits[end_index]))
573 | if version_2_with_negative:
574 | prelim_predictions.append(
575 | _PrelimPrediction(
576 | feature_index=min_null_feature_index,
577 | start_index=0,
578 | end_index=0,
579 | start_logit=null_start_logit,
580 | end_logit=null_end_logit))
581 | prelim_predictions = sorted(
582 | prelim_predictions,
583 | key=lambda x: (x.start_logit + x.end_logit),
584 | reverse=True)
585 |
586 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
587 | "NbestPrediction", ["text", "start_logit", "end_logit"])
588 |
589 | seen_predictions = {}
590 | nbest = []
591 | for pred in prelim_predictions:
592 | if len(nbest) >= n_best_size:
593 | break
594 | feature = features[pred.feature_index]
595 | if pred.start_index > 0: # this is a non-null prediction
596 | tok_tokens = feature.tokens[pred.start_index:(
597 | pred.end_index + 1)]
598 | orig_doc_start = feature.token_to_orig_map[pred.start_index]
599 | orig_doc_end = feature.token_to_orig_map[pred.end_index]
600 | orig_tokens = example.doc_tokens[orig_doc_start:(
601 | orig_doc_end + 1)]
602 | tok_text = " ".join(tok_tokens)
603 |
604 | # De-tokenize WordPieces that have been split off.
605 | tok_text = tok_text.replace(" ##", "")
606 | tok_text = tok_text.replace("##", "")
607 |
608 | # Clean whitespace
609 | tok_text = tok_text.strip()
610 | tok_text = " ".join(tok_text.split())
611 | orig_text = " ".join(orig_tokens)
612 |
613 | final_text = get_final_text(tok_text, orig_text, do_lower_case,
614 | verbose_logging)
615 | if final_text in seen_predictions:
616 | continue
617 |
618 | seen_predictions[final_text] = True
619 | else:
620 | final_text = ""
621 | seen_predictions[final_text] = True
622 |
623 | nbest.append(
624 | _NbestPrediction(
625 | text=final_text,
626 | start_logit=pred.start_logit,
627 | end_logit=pred.end_logit))
628 | # if we didn't include the empty option in the n-best, include it
629 | if version_2_with_negative:
630 | if "" not in seen_predictions:
631 | nbest.append(
632 | _NbestPrediction(
633 | text="",
634 | start_logit=null_start_logit,
635 | end_logit=null_end_logit))
636 |
637 | # In very rare edge cases we could only have single null prediction.
638 | # So we just create a nonce prediction in this case to avoid failure.
639 | if len(nbest) == 1:
640 | nbest.insert(
641 | 0,
642 | _NbestPrediction(
643 | text="empty", start_logit=0.0, end_logit=0.0))
644 |
645 | # In very rare edge cases we could have no valid predictions. So we
646 | # just create a nonce prediction in this case to avoid failure.
647 | if not nbest:
648 | nbest.append(
649 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
650 |
651 | assert len(nbest) >= 1
652 |
653 | total_scores = []
654 | best_non_null_entry = None
655 | for entry in nbest:
656 | total_scores.append(entry.start_logit + entry.end_logit)
657 | if not best_non_null_entry:
658 | if entry.text:
659 | best_non_null_entry = entry
660 |
661 | probs = _compute_softmax(total_scores)
662 |
663 | nbest_json = []
664 | for (i, entry) in enumerate(nbest):
665 | output = collections.OrderedDict()
666 | output["text"] = entry.text
667 | output["probability"] = probs[i]
668 | output["start_logit"] = entry.start_logit
669 | output["end_logit"] = entry.end_logit
670 | nbest_json.append(output)
671 |
672 | assert len(nbest_json) >= 1
673 |
674 | if not version_2_with_negative:
675 | all_predictions[example.qas_id] = nbest_json[0]["text"]
676 | else:
677 | # predict "" iff the null score - the score of best non-null > threshold
678 | score_diff = score_null - best_non_null_entry.start_logit - (
679 | best_non_null_entry.end_logit)
680 | scores_diff_json[example.qas_id] = score_diff
681 | if score_diff > null_score_diff_threshold:
682 | all_predictions[example.qas_id] = ""
683 | else:
684 | all_predictions[example.qas_id] = best_non_null_entry.text
685 | all_nbest_json[example.qas_id] = nbest_json
686 |
687 | with open(output_prediction_file, "w") as writer:
688 | writer.write(json.dumps(all_predictions, indent=4) + "\n")
689 |
690 | with open(output_nbest_file, "w") as writer:
691 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
692 |
693 | if version_2_with_negative:
694 | with open(output_null_log_odds_file, "w") as writer:
695 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
696 |
697 | return all_predictions
698 |
699 |
700 | def write_predictions_extended(
701 | all_examples, all_features, all_results, n_best_size,
702 | max_answer_length, output_prediction_file, output_nbest_file,
703 | output_null_log_odds_file, orig_data_file, start_n_top, end_n_top,
704 | version_2_with_negative, tokenizer, verbose_logging):
705 | """ XLNet write prediction logic (more complex than Bert's).
706 | Write final predictions to the json file and log-odds of null if needed.
707 | Requires utils_squad_evaluate.py
708 | """
709 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
710 | "PrelimPrediction", [
711 | "feature_index", "start_index", "end_index", "start_log_prob",
712 | "end_log_prob"
713 | ])
714 |
715 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
716 | "NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
717 |
718 | logger.info("Writing predictions to: %s", output_prediction_file)
719 | # logger.info("Writing nbest to: %s" % (output_nbest_file))
720 |
721 | example_index_to_features = collections.defaultdict(list)
722 | for feature in all_features:
723 | example_index_to_features[feature.example_index].append(feature)
724 |
725 | unique_id_to_result = {}
726 | for result in all_results:
727 | unique_id_to_result[result.unique_id] = result
728 |
729 | all_predictions = collections.OrderedDict()
730 | all_nbest_json = collections.OrderedDict()
731 | scores_diff_json = collections.OrderedDict()
732 |
733 | for (example_index, example) in enumerate(all_examples):
734 | features = example_index_to_features[example_index]
735 |
736 | prelim_predictions = []
737 | # keep track of the minimum score of null start+end of position 0
738 | score_null = 1000000 # large and positive
739 |
740 | for (feature_index, feature) in enumerate(features):
741 | result = unique_id_to_result[feature.unique_id]
742 |
743 | cur_null_score = result.cls_logits
744 |
745 | # if we could have irrelevant answers, get the min score of irrelevant
746 | score_null = min(score_null, cur_null_score)
747 |
748 | for i in range(start_n_top):
749 | for j in range(end_n_top):
750 | start_log_prob = result.start_top_log_probs[i]
751 | start_index = result.start_top_index[i]
752 |
753 | j_index = i * end_n_top + j
754 |
755 | end_log_prob = result.end_top_log_probs[j_index]
756 | end_index = result.end_top_index[j_index]
757 |
758 | # We could hypothetically create invalid predictions, e.g., predict
759 | # that the start of the span is in the question. We throw out all
760 | # invalid predictions.
761 | if start_index >= feature.paragraph_len - 1:
762 | continue
763 | if end_index >= feature.paragraph_len - 1:
764 | continue
765 |
766 | if not feature.token_is_max_context.get(
767 | start_index, False):
768 | continue
769 | if end_index < start_index:
770 | continue
771 | length = end_index - start_index + 1
772 | if length > max_answer_length:
773 | continue
774 |
775 | prelim_predictions.append(
776 | _PrelimPrediction(
777 | feature_index=feature_index,
778 | start_index=start_index,
779 | end_index=end_index,
780 | start_log_prob=start_log_prob,
781 | end_log_prob=end_log_prob))
782 |
783 | prelim_predictions = sorted(
784 | prelim_predictions,
785 | key=lambda x: (x.start_log_prob + x.end_log_prob),
786 | reverse=True)
787 |
788 | seen_predictions = {}
789 | nbest = []
790 | for pred in prelim_predictions:
791 | if len(nbest) >= n_best_size:
792 | break
793 | feature = features[pred.feature_index]
794 |
795 | # XLNet un-tokenizer
796 | # Let's keep it simple for now and see if we need all this later.
797 | #
798 | # tok_start_to_orig_index = feature.tok_start_to_orig_index
799 | # tok_end_to_orig_index = feature.tok_end_to_orig_index
800 | # start_orig_pos = tok_start_to_orig_index[pred.start_index]
801 | # end_orig_pos = tok_end_to_orig_index[pred.end_index]
802 | # paragraph_text = example.paragraph_text
803 | # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
804 |
805 | # Previously used Bert untokenizer
806 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
807 | orig_doc_start = feature.token_to_orig_map[pred.start_index]
808 | orig_doc_end = feature.token_to_orig_map[pred.end_index]
809 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
810 | tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
811 |
812 | # Clean whitespace
813 | tok_text = tok_text.strip()
814 | tok_text = " ".join(tok_text.split())
815 | orig_text = " ".join(orig_tokens)
816 |
817 | final_text = get_final_text(
818 | tok_text, orig_text, tokenizer.do_lower_case, verbose_logging)
819 |
820 | if final_text in seen_predictions:
821 | continue
822 |
823 | seen_predictions[final_text] = True
824 |
825 | nbest.append(
826 | _NbestPrediction(
827 | text=final_text,
828 | start_log_prob=pred.start_log_prob,
829 | end_log_prob=pred.end_log_prob))
830 |
831 | # In very rare edge cases we could have no valid predictions. So we
832 | # just create a nonce prediction in this case to avoid failure.
833 | if not nbest:
834 | nbest.append(
835 | _NbestPrediction(
836 | text="", start_log_prob=-1e6, end_log_prob=-1e6))
837 |
838 | total_scores = []
839 | best_non_null_entry = None
840 | for entry in nbest:
841 | total_scores.append(entry.start_log_prob + entry.end_log_prob)
842 | if not best_non_null_entry:
843 | best_non_null_entry = entry
844 |
845 | probs = _compute_softmax(total_scores)
846 |
847 | nbest_json = []
848 | for (i, entry) in enumerate(nbest):
849 | output = collections.OrderedDict()
850 | output["text"] = entry.text
851 | output["probability"] = probs[i]
852 | output["start_log_prob"] = entry.start_log_prob
853 | output["end_log_prob"] = entry.end_log_prob
854 | nbest_json.append(output)
855 |
856 | assert len(nbest_json) >= 1
857 | assert best_non_null_entry is not None
858 |
859 | score_diff = score_null
860 | scores_diff_json[example.qas_id] = score_diff
861 | # note(zhiliny): always predict best_non_null_entry
862 | # and the evaluation script will search for the best threshold
863 | all_predictions[example.qas_id] = best_non_null_entry.text
864 |
865 | all_nbest_json[example.qas_id] = nbest_json
866 |
867 | with open(output_prediction_file, "w") as writer:
868 | writer.write(json.dumps(all_predictions, indent=4) + "\n")
869 |
870 | with open(output_nbest_file, "w") as writer:
871 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
872 |
873 | if version_2_with_negative:
874 | with open(output_null_log_odds_file, "w") as writer:
875 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
876 |
877 | with open(orig_data_file, "r", encoding='utf-8') as reader:
878 | orig_data = json.load(reader)["data"]
879 |
880 | qid_to_has_ans = make_qid_to_has_ans(orig_data)
881 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
882 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
883 | exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
884 | out_eval = {}
885 |
886 | find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw,
887 | scores_diff_json, qid_to_has_ans)
888 |
889 | return out_eval
890 |
891 |
892 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
893 | """Project the tokenized prediction back to the original text."""
894 |
895 | # When we created the data, we kept track of the alignment between original
896 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
897 | # now `orig_text` contains the span of our original text corresponding to the
898 | # span that we predicted.
899 | #
900 | # However, `orig_text` may contain extra characters that we don't want in
901 | # our prediction.
902 | #
903 | # For example, let's say:
904 | # pred_text = steve smith
905 | # orig_text = Steve Smith's
906 | #
907 | # We don't want to return `orig_text` because it contains the extra "'s".
908 | #
909 | # We don't want to return `pred_text` because it's already been normalized
910 | # (the SQuAD eval script also does punctuation stripping/lower casing but
911 | # our tokenizer does additional normalization like stripping accent
912 | # characters).
913 | #
914 | # What we really want to return is "Steve Smith".
915 | #
916 | # Therefore, we have to apply a semi-complicated alignment heuristic between
917 | # `pred_text` and `orig_text` to get a character-to-character alignment. This
918 | # can fail in certain cases in which case we just return `orig_text`.
919 |
920 | def _strip_spaces(text):
921 | ns_chars = []
922 | ns_to_s_map = collections.OrderedDict()
923 | for (i, c) in enumerate(text):
924 | if c == " ":
925 | continue
926 | ns_to_s_map[len(ns_chars)] = i
927 | ns_chars.append(c)
928 | ns_text = "".join(ns_chars)
929 | return (ns_text, ns_to_s_map)
930 |
931 | # We first tokenize `orig_text`, strip whitespace from the result
932 | # and `pred_text`, and check if they are the same length. If they are
933 | # NOT the same length, the heuristic has failed. If they are the same
934 | # length, we assume the characters are one-to-one aligned.
935 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
936 |
937 | tok_text = " ".join(tokenizer.tokenize(orig_text))
938 |
939 | start_position = tok_text.find(pred_text)
940 | if start_position == -1:
941 | if verbose_logging:
942 | logger.info(
943 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
944 | return orig_text
945 | end_position = start_position + len(pred_text) - 1
946 |
947 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
948 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
949 |
950 | if len(orig_ns_text) != len(tok_ns_text):
951 | if verbose_logging:
952 | logger.info(
953 | "Length not equal after stripping spaces: '%s' vs '%s'",
954 | orig_ns_text, tok_ns_text)
955 | return orig_text
956 |
957 | # We then project the characters in `pred_text` back to `orig_text` using
958 | # the character-to-character alignment.
959 | tok_s_to_ns_map = {}
960 | for (i, tok_index) in tok_ns_to_s_map.items():
961 | tok_s_to_ns_map[tok_index] = i
962 |
963 | orig_start_position = None
964 | if start_position in tok_s_to_ns_map:
965 | ns_start_position = tok_s_to_ns_map[start_position]
966 | if ns_start_position in orig_ns_to_s_map:
967 | orig_start_position = orig_ns_to_s_map[ns_start_position]
968 |
969 | if orig_start_position is None:
970 | if verbose_logging:
971 | logger.info("Couldn't map start position")
972 | return orig_text
973 |
974 | orig_end_position = None
975 | if end_position in tok_s_to_ns_map:
976 | ns_end_position = tok_s_to_ns_map[end_position]
977 | if ns_end_position in orig_ns_to_s_map:
978 | orig_end_position = orig_ns_to_s_map[ns_end_position]
979 |
980 | if orig_end_position is None:
981 | if verbose_logging:
982 | logger.info("Couldn't map end position")
983 | return orig_text
984 |
985 | output_text = orig_text[orig_start_position:(orig_end_position + 1)]
986 | return output_text
987 |
988 |
989 | def _get_best_indexes(logits, n_best_size):
990 | """Get the n-best logits from a list."""
991 | index_and_score = sorted(
992 | enumerate(logits), key=lambda x: x[1], reverse=True)
993 |
994 | best_indexes = []
995 | for i in range(len(index_and_score)):
996 | if i >= n_best_size:
997 | break
998 | best_indexes.append(index_and_score[i][0])
999 | return best_indexes
1000 |
1001 |
1002 | def _compute_softmax(scores):
1003 | """Compute softmax probability over raw logits."""
1004 | if not scores:
1005 | return []
1006 |
1007 | max_score = None
1008 | for score in scores:
1009 | if max_score is None or score > max_score:
1010 | max_score = score
1011 |
1012 | exp_scores = []
1013 | total_sum = 0.0
1014 | for score in scores:
1015 | x = math.exp(score - max_score)
1016 | exp_scores.append(x)
1017 | total_sum += x
1018 |
1019 | probs = []
1020 | for score in exp_scores:
1021 | probs.append(score / total_sum)
1022 | return probs
--------------------------------------------------------------------------------