├── .gitignore
├── requirements.txt
├── start.sh
├── input.json
├── README.md
├── data-chinese.json
├── interactive.py
├── utils_squad_evaluate.py
├── bert_qa.py
└── utils_squad.py

/.gitignore:
--------------------------------------------------------------------------------
lib/
bin/
include/
.vscode
__pycache__
cache*
.idea/*
output/*
runs/*

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
pytorch-transformers==1.2.0
tensorboardX==1.8
tensorflow==1.14.0
torch==1.0.1.post2
tqdm==4.31.1

--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
CURRENT_DIR=`pwd`
export DATA_DIR=$CURRENT_DIR/data

python bert_qa.py \
    --model_type bert \
    --model_name_or_path bert-base-chinese \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file $DATA_DIR/cmrc2018_train.json \
    --predict_file $DATA_DIR/cmrc2018_dev.json \
    --per_gpu_train_batch_size 32 \
    --learning_rate 3e-5 \
    --num_train_epochs 3.0 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir output
--------------------------------------------------------------------------------
/input.json:
--------------------------------------------------------------------------------
{
    "version": "v1.0",
    "data": [
        {
            "paragraphs": [
                {
                    "id": "",
                    "context": "fd",
                    "qas": [
                        {
                            "question": "df",
                            "id": "",
                            "answers": []
                        },
                        {
                            "question": "df",
                            "id": "",
                            "answers": []
                        },
                        {
                            "question": "df",
                            "id": "",
                            "answers": []
                        }
                    ]
                }
            ],
            "id": "",
            "title": ""
        }
    ]
}
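In start.sh, `--max_seq_length 384` and `--doc_stride 128` control how a context that is too long for one BERT input gets split into overlapping windows. utils_squad.py is not reproduced here, so the sketch below assumes it uses the windowing loop from Google's reference BERT SQuAD code. `doc_spans` is a hypothetical helper, and plain token counts stand in for real WordPiece tokens.

```python
from typing import List, Tuple


def doc_spans(num_doc_tokens: int, num_query_tokens: int,
              max_seq_length: int = 384,
              doc_stride: int = 128) -> List[Tuple[int, int]]:
    """Return (start, length) windows over the context tokens.

    Hypothetical helper. It assumes the layout
    [CLS] question [SEP] context [SEP], so 3 positions are reserved
    for special tokens.
    """
    max_tokens_for_doc = max_seq_length - num_query_tokens - 3
    spans = []
    start = 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == num_doc_tokens:
            break  # this window already reaches the end of the context
        start += min(length, doc_stride)  # the next window overlaps this one
    return spans


if __name__ == "__main__":
    # A 700-token context paired with a 20-token question.
    for start, length in doc_spans(700, 20):
        print(start, start + length)
```

With a 20-token question, each window holds at most 384 - 20 - 3 = 361 context tokens. Consecutive windows start `doc_stride` tokens apart, so every context token falls into at least one window.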
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Reading Comprehension Based on the Pretrained BERT Model

Here we build a machine reading comprehension system using BERT, a pretrained model from Google and one of the latest advances in deep learning for NLP.

The Stanford Question Answering Dataset (SQuAD) is one of the first large reading comprehension datasets in English. From the model's point of view, each input is a Context / Question pair, and each output is an Answer: a pair of integers indexing the start and the end of the answer text inside the Context.
[The 2nd Evaluation Workshop on Chinese Machine Reading Comprehension (2018)](https://github.com/ymcui/cmrc2018) released part of its data in a SQuAD-like format, which we use in this example.

The model is built on top of [pytorch-transformers](https://github.com/huggingface/pytorch-transformers), which makes it easy to apply pretrained models such as BERT, GPT and GPT-2 to downstream tasks. The repository includes various utilities and training scripts for multiple NLP tasks, including question answering. Two related posts about QA with BERT:

> - [Understanding text with BERT](https://blog.scaleway.com/2019/understanding-text-with-bert/)

> - [Extending Google-BERT as Question and Answering model and Chatbot](https://medium.com/datadriveninvestor/extending-google-bert-as-question-and-answering-model-and-chatbot-e3e7b47b721a)
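The answer is a (start, end) pair of token indices. To produce it, the model emits a start score and an end score for every token, and a span is chosen from those two lists.

Below is a minimal brute-force sketch of that choice, assuming plain Python lists of logits. `best_span` is a hypothetical helper, and the n-best search in utils_squad.py (not shown here) may differ. The length cap mirrors interactive.py's `--max_answer_length` (default 30), which exists because the start and end predictions are not conditioned on one another.

```python
from typing import List, Tuple


def best_span(start_logits: List[float], end_logits: List[float],
              max_answer_length: int = 30) -> Tuple[int, int]:
    """Return the (start, end) token indices with the highest start + end score.

    Only spans with start <= end and end - start + 1 <= max_answer_length
    are considered.
    """
    best = (0, 0)
    best_score = float("-inf")
    for s, s_score in enumerate(start_logits):
        # End positions allowed for this start: s .. s + max_answer_length - 1
        last = min(len(end_logits), s + max_answer_length)
        for e in range(s, last):
            score = s_score + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best


if __name__ == "__main__":
    # Toy scores for a 6-token context.
    start = [0.1, 0.2, 3.0, 0.5, 0.1, 0.0]
    end = [0.0, 0.1, 0.4, 2.5, 0.3, 2.9]
    print(best_span(start, end, max_answer_length=2))  # -> (2, 3)
```

With `max_answer_length=2`, the 4-token span (2, 5) is ruled out even though its summed score is higher.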

## Getting Started

#### 1. Prepare the data, set up a Python virtual environment, and install the packages listed in requirements.txt

#### 2. Run the command below to fine-tune the model

```bash
python bert_qa.py \
    --model_type bert \
    --model_name_or_path bert-base-chinese \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file ~/cmrc2018_train.json \
    --predict_file ~/cmrc2018_trial.json \
    --per_gpu_train_batch_size 12 \
    --learning_rate 3e-5 \
    --num_train_epochs 2.0 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir ~/chinese-qa-with-bert/output \
    --save_steps 200
```

#### 3. Load the fine-tuned parameters and run this command to query the model interactively

```bash
python interactive.py \
    --output_dir ~/chinese-qa-with-bert/output \
    --model_type bert \
    --predict_file ~/chinese-qa-with-bert/input.json \
    --state_dict ~/chinese-qa-with-bert/output/pytorch_model.bin \
    --model_name_or_path bert-base-chinese
```

At the `Please Enter:` prompt, type a single-line JSON string containing the context and a list of questions, for example:

```json
{"context": "海梅·雷耶斯()是DC漫画公司的一个虚拟人物。该人物首先出现于《无限危机 #3》(2006年二月),是第三代蓝甲虫,由作家基斯·吉芬和约翰·罗杰斯创作,屈伊·哈姆纳作画。海梅与他的父母妹妹生活在得克萨斯州的艾尔帕索。他的父亲拥有一间汽车修理厂。海梅建议自己帮助父亲在汽车修理厂中干活,然而他的父亲迄今为止未答应他,觉得海梅应该花更多的功夫在学习上并享受他自己的童年生活。海梅对他的家庭和朋友们有强烈的责任感,可是他经常抱怨做一个仅解决琐事的人。第二代蓝甲虫(泰德·科德)死前派遣圣甲虫去给惊奇队长送信,圣甲虫也因此留在了惊奇队长的永恒之岩里。之后惊奇队长被杀,蓝甲虫降落到了得克萨斯州的艾尔帕索,被少年海梅捡到,后来蓝甲虫在海梅睡觉时融合进他的脊椎,海梅从此成为第三代蓝甲虫,此时的蓝甲虫具备了随意变武器的能力。蓝甲虫原本是宇宙侵略组织Reach的战争工具,具有自己的思想。它曾被作为礼物赐予某星球的勇士,实际上是暗中控制他们。“卡基达(Kaji Dha)”是它的启动口令,第一代蓝甲虫加勒特就在它的影响下攻击过科德。蓝甲虫在无限危机后受到外界强大能量的影响沉睡了一年,苏醒后又想控制海梅,但被海梅的意念克制了,加上它本身存在程序故障,久而久之他与海梅成了好友。","qas": [{"question": "该漫画是由谁创作的?"}]}
```

### Sample results

```json
{"AVERAGE": "27.276", "F1": "54.552", "EM": "0.000", "TOTAL": 1002, "SKIP": 0, "FILE": "~/chinese-qa-with-bert/output/predictions_.json"}
```

```text
Please Enter:{"context": "江苏路街道是中国上海市长宁区下辖的一个街道办事处,位于长宁区最东部,东到镇宁路接邻静安区的静安寺街道,北到武定西路接邻静安区的曹家渡街道,南到华山路接邻徐汇区的湖南路街道,西部界限为长宁路、安西路、武夷路等。面积1.52平方公里,户籍人口5.26万人,下辖13个居委会。长宁区区政府设在该街道辖区内的愚园路(近安西路)。江苏路街道的主要街道江苏路、愚园路、长宁路、武夷路、武定西路、延安西路,均为上海公共租界越界筑路,是花园洋房集中的街区,是愚园路历史文化风貌区的主体部分。较著名的近代建筑有中西女中(今上海市第三女子中学)、王伯群及汪精卫公馆(今长宁区少年宫)、西园公寓等。江苏路、长宁路、延安西路(高架)等经过拓宽,已经形成上海市的交通干道。上海市轨道交通二号线经过该街道辖区,设有江苏路站。江苏路街道下辖歧山居委会、江苏居委会、万村居委会、南汪居委会、东浜居委会、愚三居委会、曹家堰居委会、西浜居委会、福世居委会、长新居委会、华山居委会、利西居委会、北汪居委会。","qas": [{"question": "江苏路街道在上海市的什么地方?"}]}
Evaluating: 100%|████████████████████████████████████████| 1/1 [00:00<00:00, 2.61it/s]

江 苏 路 街 道 是 中 国 上 海 市 长 宁
```

This project is for learning purposes only; more optimization remains to be done.
--------------------------------------------------------------------------------
/data-chinese.json:
--------------------------------------------------------------------------------
1 | { 2 | "version": "v1.0", 3 | "data": [ 4 | { 5 | "paragraphs": [ 6 | { 7 | "id": "TRAIN_54", 8 | "context": "安雅·罗素法(,),来自俄罗斯圣彼得堡的模特儿。她是《全美超级模特儿新秀大赛》第十季的亚军。2008年,安雅宣布改回出生时的名字:安雅·罗素法(Anya Rozova),在此之前是使用安雅·冈()。安雅于俄罗斯出生,后来被一个居住在美国夏威夷群岛欧胡岛檀香山的家庭领养。安雅十七岁时曾参与香奈儿、路易·威登及芬迪(Fendi)等品牌的非正式时装秀。2007年,她于瓦伊帕胡高级中学毕业。毕业后,她当了一名售货员。她曾为Russell Tanoue拍摄照片,Russell Tanoue称赞她是「有前途的新面孔」。安雅在半准决赛面试时说她对模特儿行业充满热诚,所以参加全美超级模特儿新秀大赛。她于比赛中表现出色,曾五次首名入围,平均入围顺序更拿下历届以来最优异的成绩(2.64),另外胜出三次小挑战,分别获得与评判尼祖·百克拍照、为柠檬味道的七喜拍摄广告的机会及十万美元、和盖马蒂洛(Gai Mattiolo)设计的晚装。在最后两强中,安雅与另一名参赛者惠妮·汤姆森为范思哲走秀,但评判认为她在台上不够惠妮突出,所以选了惠妮当冠军,安雅屈居亚军(但就整体表现来说,部份网友认为安雅才是第十季名副其实的冠军。)安雅在比赛拿五次第一,也胜出多次小挑战。安雅赛后再次与Russell Tanoue合作,为2008年4月30日出版的MidWeek杂志拍摄封面及内页照。其后她参加了V杂志与Supreme模特儿公司合办的模特儿选拔赛2008。她其后更与Elite签约。最近她与香港的模特儿公司 Style International Management 签约,并在香港发展其模特儿事业。她曾在很多香港的时装杂志中任模特儿,《Jet》、《东方日报》、《Elle》等。", 9 | "qas": [ 10 | { 11 | "question": "安雅·罗素法参加了什么比赛获得了亚军?", 12 | "id": "TRAIN_54_QUERY_0", 13 | "answers": [ 14 | { 15 | "text": "《全美超级模特儿新秀大赛》第十季", 16 |
"answer_start": 26 17 | } 18 | ] 19 | }, 20 | { 21 | "question": "Russell Tanoue对安雅·罗素法的评价是什么?", 22 | "id": "TRAIN_54_QUERY_1", 23 | "answers": [ 24 | { 25 | "text": "有前途的新面孔", 26 | "answer_start": 247 27 | } 28 | ] 29 | }, 30 | { 31 | "question": "安雅·罗素法合作过的香港杂志有哪些?", 32 | "id": "TRAIN_54_QUERY_2", 33 | "answers": [ 34 | { 35 | "text": "《Jet》、《东方日报》、《Elle》等", 36 | "answer_start": 706 37 | } 38 | ] 39 | }, 40 | { 41 | "question": "毕业后的安雅·罗素法职业是什么?", 42 | "id": "TRAIN_54_QUERY_3", 43 | "answers": [ 44 | { 45 | "text": "售货员", 46 | "answer_start": 202 47 | } 48 | ] 49 | } 50 | ] 51 | } 52 | ], 53 | "id": "TRAIN_54", 54 | "title": "安雅·罗素法" 55 | }, 56 | { 57 | "paragraphs": [ 58 | { 59 | "id": "TRAIN_756", 60 | "context": "为日本漫画足球小将翼的一个角色,自小父母离异,与父亲一起四处为家,每个地方也是待一会便离开,但他仍然能够保持优秀的学业成绩。在第一次南葛市生活时,与同样就读于南葛小学的大空翼为黄金拍档,曾效力球队包括南葛小学、南葛高中、日本少年队、日本青年军、日本奥运队。效力日本青年军期间,因救同母异父的妹妹导致被车撞至断脚,在决赛周只在决赛的下半场十五分钟开始上场,成为日本队夺得世青冠军的其中一名功臣。基本资料绰号:球场上的艺术家出身地:日本南葛市诞生日:5月5日星座:金牛座球衣号码:11担任位置:中场、攻击中场、右中场擅长脚:右脚所属队伍:盘田山叶故事发展岬太郎在小学期间不断转换学校,在南葛小学就读时在全国大赛中夺得冠军;国中三年随父亲孤单地在法国留学;回国后三年的高中生涯一直输给日本王牌射手日向小次郎率领的东邦学院。在【Golden 23】年代,大空翼、日向小次郎等名将均转战海外,他与松山光、三杉淳组成了「3M」组合(松山光Hikaru Matsuyama、岬太郎Taro Misaki、三杉淳Jyun Misugi)。必杀技1. 回力刀射门2. S. S. S. 射门3. 
双人射门(与大空翼合作)", 61 | "qas": [ 62 | { 63 | "question": "岬太郎在第一次南葛市生活时的搭档是谁?", 64 | "id": "TRAIN_756_QUERY_0", 65 | "answers": [ 66 | { 67 | "text": "大空翼", 68 | "answer_start": 84 69 | } 70 | ] 71 | }, 72 | { 73 | "question": "日本队夺得世青冠军,岬太郎发挥了什么作用?", 74 | "id": "TRAIN_756_QUERY_1", 75 | "answers": [ 76 | { 77 | "text": "在决赛周只在决赛的下半场十五分钟开始上场,成为日本队夺得世青冠军的其中一名功臣。", 78 | "answer_start": 156 79 | } 80 | ] 81 | }, 82 | { 83 | "question": "岬太郎与谁一起组成了「3M」组合?", 84 | "id": "TRAIN_756_QUERY_2", 85 | "answers": [ 86 | { 87 | "text": "他与松山光、三杉淳组成了「3M」组合(松山光Hikaru Matsuyama、岬太郎Taro Misaki、三杉淳Jyun Misugi)。", 88 | "answer_start": 391 89 | } 90 | ] 91 | } 92 | ] 93 | } 94 | ], 95 | "id": "TRAIN_756", 96 | "title": "岬太郎" 97 | } 98 | ] 99 | } -------------------------------------------------------------------------------- /interactive.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import argparse 4 | from pytorch_transformers import (BertConfig, BertForQuestionAnswering, 5 | BertTokenizer) 6 | from bert_qa import evaluate 7 | import os 8 | 9 | parser = argparse.ArgumentParser() 10 | ## Required parameters 11 | parser.add_argument( 12 | "--train_file", 13 | default=None, 14 | type=str, 15 | required=False, 16 | help="SQuAD json for training. E.g., train-v1.1.json") 17 | parser.add_argument( 18 | "--predict_file", 19 | default=None, 20 | type=str, 21 | required=True, 22 | help="SQuAD json for predictions. 
E.g., dev-v1.1.json or test-v1.1.json") 23 | parser.add_argument( 24 | "--model_type", 25 | default=None, 26 | type=str, 27 | required=True, 28 | help="Model type, e.g. bert.") 29 | parser.add_argument( 30 | "--model_name_or_path", 31 | default=None, 32 | type=str, 33 | required=True, 34 | help="Path to a pre-trained model or a shortcut name, e.g. bert-base-chinese.") 35 | parser.add_argument( 36 | "--output_dir", 37 | default=None, 38 | type=str, 39 | required=True, 40 | help= 41 | "The output directory where the model checkpoints and predictions will be written." 42 | ) 43 | 44 | ## Other parameters 45 | parser.add_argument( 46 | "--config_name", 47 | default="", 48 | type=str, 49 | help="Pretrained config name or path if not the same as model_name") 50 | parser.add_argument( 51 | "--tokenizer_name", 52 | default="", 53 | type=str, 54 | help="Pretrained tokenizer name or path if not the same as model_name") 55 | parser.add_argument( 56 | "--cache_dir", 57 | default="", 58 | type=str, 59 | help="Where to store the pre-trained models downloaded from S3" 60 | ) 61 | 62 | parser.add_argument( 63 | '--version_2_with_negative', 64 | action='store_true', 65 | help='If true, the SQuAD examples contain some that do not have an answer.' 66 | ) 67 | parser.add_argument( 68 | '--null_score_diff_threshold', 69 | type=float, 70 | default=0.0, 71 | help= 72 | "If null_score - best_non_null is greater than the threshold, predict null." 73 | ) 74 | 75 | parser.add_argument( 76 | "--max_seq_length", 77 | default=384, 78 | type=int, 79 | help= 80 | "The maximum total input sequence length after WordPiece tokenization. Sequences " 81 | "longer than this will be truncated, and sequences shorter than this will be padded." 82 | ) 83 | parser.add_argument( 84 | "--doc_stride", 85 | default=128, 86 | type=int, 87 | help= 88 | "When splitting up a long document into chunks, how much stride to take between chunks."
89 | ) 90 | parser.add_argument( 91 | "--max_query_length", 92 | default=64, 93 | type=int, 94 | help= 95 | "The maximum number of tokens for the question. Questions longer than this will " 96 | "be truncated to this length.") 97 | parser.add_argument( 98 | "--do_train", action='store_true', help="Whether to run training.") 99 | parser.add_argument( 100 | "--do_eval", 101 | action='store_true', 102 | help="Whether to run eval on the dev set.") 103 | parser.add_argument( 104 | "--evaluate_during_training", 105 | action='store_true', 106 | help="Run evaluation during training at each logging step.") 107 | parser.add_argument( 108 | "--do_lower_case", 109 | action='store_true', 110 | help="Set this flag if you are using an uncased model.") 111 | 112 | parser.add_argument( 113 | "--per_gpu_train_batch_size", 114 | default=8, 115 | type=int, 116 | help="Batch size per GPU/CPU for training.") 117 | parser.add_argument( 118 | "--per_gpu_eval_batch_size", 119 | default=8, 120 | type=int, 121 | help="Batch size per GPU/CPU for evaluation.") 122 | parser.add_argument( 123 | "--learning_rate", 124 | default=5e-5, 125 | type=float, 126 | help="The initial learning rate for Adam.") 127 | parser.add_argument( 128 | '--gradient_accumulation_steps', 129 | type=int, 130 | default=1, 131 | help= 132 | "Number of update steps to accumulate before performing a backward/update pass."
133 | ) 134 | parser.add_argument( 135 | "--weight_decay", 136 | default=0.0, 137 | type=float, 138 | help="Weight decay to apply, if any.") 139 | parser.add_argument( 140 | "--adam_epsilon", 141 | default=1e-8, 142 | type=float, 143 | help="Epsilon for Adam optimizer.") 144 | parser.add_argument( 145 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 146 | parser.add_argument( 147 | "--num_train_epochs", 148 | default=3.0, 149 | type=float, 150 | help="Total number of training epochs to perform.") 151 | parser.add_argument( 152 | "--max_steps", 153 | default=-1, 154 | type=int, 155 | help= 156 | "If > 0: set total number of training steps to perform. Overrides num_train_epochs." 157 | ) 158 | parser.add_argument( 159 | "--warmup_steps", 160 | default=0, 161 | type=int, 162 | help="Linear warmup over warmup_steps.") 163 | parser.add_argument( 164 | "--n_best_size", 165 | default=20, 166 | type=int, 167 | help= 168 | "The total number of n-best predictions to generate in the nbest_predictions.json output file." 169 | ) 170 | parser.add_argument( 171 | "--max_answer_length", 172 | default=30, 173 | type=int, 174 | help= 175 | "The maximum length of an answer that can be generated. This is needed because the start " 176 | "and end predictions are not conditioned on one another.") 177 | parser.add_argument( 178 | "--verbose_logging", 179 | action='store_true', 180 | help= 181 | "If true, all of the warnings related to data processing will be printed. " 182 | "A number of warnings are expected for a normal SQuAD evaluation.") 183 | 184 | parser.add_argument( 185 | '--logging_steps', type=int, default=50, help="Log every X update steps.") 186 | parser.add_argument( 187 | '--save_steps', 188 | type=int, 189 | default=50, 190 | help="Save checkpoint every X update steps.") 191 | parser.add_argument( 192 | "--eval_all_checkpoints", 193 | action='store_true', 194 | help= 195 | "Evaluate all checkpoints starting with the same prefix as model_name and ending with step number" 196 | ) 197 | parser.add_argument( 198 | "--no_cuda", 199 | action='store_true', 200 | help="Avoid using CUDA even when it is available") 201 | parser.add_argument( 202 | '--overwrite_output_dir', 203 | action='store_true', 204 | help="Overwrite the content of the output directory") 205 | parser.add_argument( 206 | '--overwrite_cache', 207 | action='store_true', 208 | help="Overwrite the cached training and evaluation sets") 209 | parser.add_argument( 210 | '--seed', type=int, default=42, help="Random seed for initialization") 211 | 212 | parser.add_argument( 213 | "--local_rank", 214 | type=int, 215 | default=-1, 216 | help="local_rank for distributed training on GPUs") 217 | parser.add_argument( 218 | '--fp16', 219 | action='store_true', 220 | help= 221 | "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit" 222 | ) 223 | parser.add_argument( 224 | '--fp16_opt_level', 225 | type=str, 226 | default='O1', 227 | help= 228 | "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
229 | "See details at https://nvidia.github.io/apex/amp.html") 230 | parser.add_argument( 231 | "--state_dict", 232 | default=None, 233 | type=str, 234 | required=True, 235 | help="Path to the fine-tuned model weights, e.g. output/pytorch_model.bin") 236 | 237 | args = parser.parse_args() 238 | args.n_gpu = torch.cuda.device_count() 239 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 240 | device = torch.device( 241 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 242 | args.device = device 243 | tokenizer = BertTokenizer.from_pretrained( 244 | 'bert-base-chinese', do_lower_case=False) 245 | config = BertConfig.from_pretrained('bert-base-chinese') 246 | model = BertForQuestionAnswering(config) 247 | model_state_dict = args.state_dict 248 | model.load_state_dict(torch.load(model_state_dict, map_location=args.device)) 249 | model.to(args.device) 250 | model.eval() 251 | input_file = args.predict_file 252 | 253 | 254 | def handle_file(input_file, context, question): 255 | with open(input_file, "r") as reader: 256 | orig_data = json.load(reader) 257 | orig_data["data"][0]['paragraphs'][0]['context'] = context 258 | # Rebuild qas so any number of questions fits and each gets a unique id 259 | orig_data["data"][0]['paragraphs'][0]['qas'] = [ 260 | {"question": q, "id": str(i), "answers": []} for i, q in enumerate(question)] 261 | with open(input_file, "w") as writer: 262 | writer.write(json.dumps(orig_data, indent=4) + "\n") 263 | 264 | 265 | def run(): 266 | while True: 267 | raw_text = input("Please Enter:") 268 | while not raw_text: 269 | print('Input should not be empty!') 270 | raw_text = input("Please Enter:") 271 | context = '' 272 | question = [] 273 | try: 274 | raw_json = json.loads(raw_text) 275 | context = raw_json['context'] 276 | if not context: 277 | continue 278 | raw_qas = raw_json['qas'] 279 | if not raw_qas: 280 | continue 281 | for i in range(len(raw_qas)): 282 | question.append(raw_qas[i]['question']) 283 | except Exception as identifier: 284 | print(identifier) 285 | continue 286 | handle_file(input_file, context, question) 287 | evaluate(args, model, tokenizer) 288 | 289 | predict_file = os.path.join(args.output_dir, "predictions_.json") 290 | with open(predict_file, "r") as reader: 291 | orig_data = json.load(reader) 292 | for qid in sorted(orig_data, key=int): 293 | print(orig_data[qid])  # one predicted answer per question, in input order 294 | handle_file(input_file, "", [])  # clean input file 295 | 296 | 297 | if __name__ == "__main__": 298 | run() 299 | -------------------------------------------------------------------------------- /utils_squad_evaluate.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for SQuAD version 2.0. 2 | Modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question IDs to the model's predicted probability 6 | that a question is unanswerable. 7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | 18 | class EVAL_OPTS(): 19 | def __init__(self, 20 | data_file, 21 | pred_file, 22 | out_file="", 23 | na_prob_file="na_prob.json", 24 | na_prob_thresh=1.0, 25 | out_image_dir=None, 26 | verbose=False): 27 | self.data_file = data_file 28 | self.pred_file = pred_file 29 | self.out_file = out_file 30 | self.na_prob_file = na_prob_file 31 | self.na_prob_thresh = na_prob_thresh 32 | self.out_image_dir = out_image_dir 33 | self.verbose = verbose 34 | 35 | 36 | OPTS = None 37 | 38 | 39 | def parse_args(): 40 | parser = argparse.ArgumentParser( 41 | 'Official evaluation script for SQuAD version 2.0.') 42 | parser.add_argument( 43 | 'data_file', metavar='data.json', help='Input data JSON file.') 44 | parser.add_argument( 45 | 'pred_file', metavar='pred.json', help='Model predictions.') 46 | parser.add_argument( 47 | '--out-file', 48 | '-o', 49 | metavar='eval.json', 50 |
help='Write accuracy metrics to file (default is stdout).') 51 | parser.add_argument( 52 | '--na-prob-file', 53 | '-n', 54 | metavar='na_prob.json', 55 | help='Model estimates of probability of no answer.') 56 | parser.add_argument( 57 | '--na-prob-thresh', 58 | '-t', 59 | type=float, 60 | default=1.0, 61 | help='Predict "" if no-answer probability exceeds this (default = 1.0).' 62 | ) 63 | parser.add_argument( 64 | '--out-image-dir', 65 | '-p', 66 | metavar='out_images', 67 | default=None, 68 | help='Save precision-recall curves to directory.') 69 | parser.add_argument('--verbose', '-v', action='store_true') 70 | if len(sys.argv) == 1: 71 | parser.print_help() 72 | sys.exit(1) 73 | return parser.parse_args() 74 | 75 | 76 | def make_qid_to_has_ans(dataset): 77 | qid_to_has_ans = {} 78 | for article in dataset: 79 | for p in article['paragraphs']: 80 | for qa in p['qas']: 81 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 82 | return qid_to_has_ans 83 | 84 | 85 | def normalize_answer(s): 86 | """Lower text and remove punctuation, articles and extra whitespace.""" 87 | 88 | def remove_articles(text): 89 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 90 | return re.sub(regex, ' ', text) 91 | 92 | def white_space_fix(text): 93 | return ' '.join(text.split()) 94 | 95 | def remove_punc(text): 96 | exclude = set(string.punctuation) 97 | return ''.join(ch for ch in text if ch not in exclude) 98 | 99 | def lower(text): 100 | return text.lower() 101 | 102 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 103 | 104 | 105 | def get_tokens(s): 106 | if not s: return [] 107 | return normalize_answer(s).split() 108 | 109 | 110 | def compute_exact(a_gold, a_pred): 111 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 112 | 113 | 114 | def compute_f1(a_gold, a_pred): 115 | gold_toks = get_tokens(a_gold) 116 | pred_toks = get_tokens(a_pred) 117 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 118 | num_same = 
sum(common.values()) 119 | if len(gold_toks) == 0 or len(pred_toks) == 0: 120 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 121 | return int(gold_toks == pred_toks) 122 | if num_same == 0: 123 | return 0 124 | precision = 1.0 * num_same / len(pred_toks) 125 | recall = 1.0 * num_same / len(gold_toks) 126 | f1 = (2 * precision * recall) / (precision + recall) 127 | return f1 128 | 129 | 130 | def get_raw_scores(dataset, preds): 131 | exact_scores = {} 132 | f1_scores = {} 133 | for article in dataset: 134 | for p in article['paragraphs']: 135 | for qa in p['qas']: 136 | qid = qa['id'] 137 | gold_answers = [ 138 | a['text'] for a in qa['answers'] 139 | if normalize_answer(a['text']) 140 | ] 141 | if not gold_answers: 142 | # For unanswerable questions, only correct answer is empty string 143 | gold_answers = [''] 144 | if qid not in preds: 145 | print('Missing prediction for %s' % qid) 146 | continue 147 | a_pred = preds[qid] 148 | # Take max over all gold answers 149 | exact_scores[qid] = max( 150 | compute_exact(a, a_pred) for a in gold_answers) 151 | f1_scores[qid] = max( 152 | compute_f1(a, a_pred) for a in gold_answers) 153 | return exact_scores, f1_scores 154 | 155 | 156 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 157 | new_scores = {} 158 | for qid, s in scores.items(): 159 | pred_na = na_probs[qid] > na_prob_thresh 160 | if pred_na: 161 | new_scores[qid] = float(not qid_to_has_ans[qid]) 162 | else: 163 | new_scores[qid] = s 164 | return new_scores 165 | 166 | 167 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 168 | if not qid_list: 169 | total = len(exact_scores) 170 | return collections.OrderedDict([ 171 | ('exact', 100.0 * sum(exact_scores.values()) / total), 172 | ('f1', 100.0 * sum(f1_scores.values()) / total), 173 | ('total', total), 174 | ]) 175 | else: 176 | total = len(qid_list) 177 | return collections.OrderedDict([ 178 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / 
total), 179 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 180 | ('total', total), 181 | ]) 182 | 183 | 184 | def merge_eval(main_eval, new_eval, prefix): 185 | for k in new_eval: 186 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 187 | 188 | 189 | def plot_pr_curve(precisions, recalls, out_image, title): 190 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 191 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 192 | plt.xlabel('Recall') 193 | plt.ylabel('Precision') 194 | plt.xlim([0.0, 1.05]) 195 | plt.ylim([0.0, 1.05]) 196 | plt.title(title) 197 | plt.savefig(out_image) 198 | plt.clf() 199 | 200 | 201 | def make_precision_recall_eval(scores, 202 | na_probs, 203 | num_true_pos, 204 | qid_to_has_ans, 205 | out_image=None, 206 | title=None): 207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 208 | true_pos = 0.0 209 | cur_p = 1.0 210 | cur_r = 0.0 211 | precisions = [1.0] 212 | recalls = [0.0] 213 | avg_prec = 0.0 214 | for i, qid in enumerate(qid_list): 215 | if qid_to_has_ans[qid]: 216 | true_pos += scores[qid] 217 | cur_p = true_pos / float(i + 1) 218 | cur_r = true_pos / float(num_true_pos) 219 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 220 | 1]]: 221 | # i.e., if we can put a threshold after this point 222 | avg_prec += cur_p * (cur_r - recalls[-1]) 223 | precisions.append(cur_p) 224 | recalls.append(cur_r) 225 | if out_image: 226 | plot_pr_curve(precisions, recalls, out_image, title) 227 | return {'ap': 100.0 * avg_prec} 228 | 229 | 230 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 231 | qid_to_has_ans, out_image_dir): 232 | if out_image_dir and not os.path.exists(out_image_dir): 233 | os.makedirs(out_image_dir) 234 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 235 | if num_true_pos == 0: 236 | return 237 | pr_exact = make_precision_recall_eval( 238 | exact_raw, 239 | na_probs, 240 | num_true_pos, 241 | 
qid_to_has_ans, 242 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 243 | title='Precision-Recall curve for Exact Match score') 244 | pr_f1 = make_precision_recall_eval( 245 | f1_raw, 246 | na_probs, 247 | num_true_pos, 248 | qid_to_has_ans, 249 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 250 | title='Precision-Recall curve for F1 score') 251 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 252 | pr_oracle = make_precision_recall_eval( 253 | oracle_scores, 254 | na_probs, 255 | num_true_pos, 256 | qid_to_has_ans, 257 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 258 | title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)' 259 | ) 260 | merge_eval(main_eval, pr_exact, 'pr_exact') 261 | merge_eval(main_eval, pr_f1, 'pr_f1') 262 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 263 | 264 | 265 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 266 | if not qid_list: 267 | return 268 | x = [na_probs[k] for k in qid_list] 269 | weights = np.ones_like(x) / float(len(x)) 270 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 271 | plt.xlabel('Model probability of no-answer') 272 | plt.ylabel('Proportion of dataset') 273 | plt.title('Histogram of no-answer probability: %s' % name) 274 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 275 | plt.clf() 276 | 277 | 278 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 279 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 280 | cur_score = num_no_ans 281 | best_score = cur_score 282 | best_thresh = 0.0 283 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 284 | for i, qid in enumerate(qid_list): 285 | if qid not in scores: continue 286 | if qid_to_has_ans[qid]: 287 | diff = scores[qid] 288 | else: 289 | if preds[qid]: 290 | diff = -1 291 | else: 292 | diff = 0 293 | cur_score += diff 294 | if cur_score > best_score: 295 | best_score = cur_score 296 | best_thresh = na_probs[qid] 297 | 
return 100.0 * best_score / len(scores), best_thresh 298 | 299 | 300 | def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): 301 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 302 | cur_score = num_no_ans 303 | best_score = cur_score 304 | best_thresh = 0.0 305 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 306 | for i, qid in enumerate(qid_list): 307 | if qid not in scores: continue 308 | if qid_to_has_ans[qid]: 309 | diff = scores[qid] 310 | else: 311 | if preds[qid]: 312 | diff = -1 313 | else: 314 | diff = 0 315 | cur_score += diff 316 | if cur_score > best_score: 317 | best_score = cur_score 318 | best_thresh = na_probs[qid] 319 | 320 | has_ans_score, has_ans_cnt = 0, 0 321 | for qid in qid_list: 322 | if not qid_to_has_ans[qid]: continue 323 | has_ans_cnt += 1 324 | 325 | if qid not in scores: continue 326 | has_ans_score += scores[qid] 327 | 328 | return 100.0 * best_score / len( 329 | scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt 330 | 331 | 332 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, 333 | qid_to_has_ans): 334 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, 335 | qid_to_has_ans) 336 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, 337 | qid_to_has_ans) 338 | main_eval['best_exact'] = best_exact 339 | main_eval['best_exact_thresh'] = exact_thresh 340 | main_eval['best_f1'] = best_f1 341 | main_eval['best_f1_thresh'] = f1_thresh 342 | 343 | 344 | def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, 345 | qid_to_has_ans): 346 | best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2( 347 | preds, exact_raw, na_probs, qid_to_has_ans) 348 | best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2( 349 | preds, f1_raw, na_probs, qid_to_has_ans) 350 | main_eval['best_exact'] = best_exact 351 | main_eval['best_exact_thresh'] = exact_thresh 352 | main_eval['best_f1'] = best_f1 353 | main_eval['best_f1_thresh'] = 
f1_thresh 354 | main_eval['has_ans_exact'] = has_ans_exact 355 | main_eval['has_ans_f1'] = has_ans_f1 356 | 357 | 358 | def main(OPTS): 359 | with open(OPTS.data_file) as f: 360 | dataset_json = json.load(f) 361 | dataset = dataset_json['data'] 362 | with open(OPTS.pred_file) as f: 363 | preds = json.load(f) 364 | if OPTS.na_prob_file: 365 | with open(OPTS.na_prob_file) as f: 366 | na_probs = json.load(f) 367 | else: 368 | na_probs = {k: 0.0 for k in preds} 369 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 370 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 371 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 372 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 373 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 374 | OPTS.na_prob_thresh) 375 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 376 | OPTS.na_prob_thresh) 377 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 378 | if has_ans_qids: 379 | has_ans_eval = make_eval_dict( 380 | exact_thresh, f1_thresh, qid_list=has_ans_qids) 381 | merge_eval(out_eval, has_ans_eval, 'HasAns') 382 | if no_ans_qids: 383 | no_ans_eval = make_eval_dict( 384 | exact_thresh, f1_thresh, qid_list=no_ans_qids) 385 | merge_eval(out_eval, no_ans_eval, 'NoAns') 386 | if OPTS.na_prob_file: 387 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, 388 | qid_to_has_ans) 389 | if OPTS.na_prob_file and OPTS.out_image_dir: 390 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 391 | qid_to_has_ans, OPTS.out_image_dir) 392 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 393 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 394 | if OPTS.out_file: 395 | with open(OPTS.out_file, 'w') as f: 396 | json.dump(out_eval, f) 397 | else: 398 | print(json.dumps(out_eval, indent=2)) 399 | return out_eval 400 | 401 | 402 | if __name__ == '__main__': 403 | OPTS = 
parse_args() 404 | if OPTS.out_image_dir: 405 | import matplotlib 406 | matplotlib.use('Agg') 407 | import matplotlib.pyplot as plt 408 | main(OPTS) 409 | -------------------------------------------------------------------------------- /bert_qa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet).""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import logging 22 | import os 23 | import random 24 | import glob 25 | 26 | import numpy as np 27 | import torch 28 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 29 | TensorDataset) 30 | from torch.utils.data.distributed import DistributedSampler 31 | from tqdm import tqdm, trange 32 | from tensorboardX import SummaryWriter 33 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig, 34 | BertForQuestionAnswering, BertTokenizer, 35 | XLMConfig, XLMForQuestionAnswering, 36 | XLMTokenizer, XLNetConfig, 37 | XLNetForQuestionAnswering, 38 | XLNetTokenizer) 39 | 40 | from pytorch_transformers import AdamW, WarmupLinearSchedule 41 | 42 | from utils_squad import (read_squad_examples, convert_examples_to_features, 43 | RawResult, write_predictions, write_predictions_extended) 44 | 45 | # The following import is the official SQuAD evaluation script (2.0).
46 | # You can remove it from the dependencies if you are using this script outside of the library 47 | # We've added it here for automated tests (see examples/test_examples.py file) 48 | from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad 49 | 50 | logger = logging.getLogger(__name__) 51 | 52 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \ 53 | for conf in (BertConfig, XLNetConfig, XLMConfig)), ()) 54 | 55 | MODEL_CLASSES = { 56 | 'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer), 57 | 'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), 58 | 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), 59 | } 60 | 61 | def set_seed(args): 62 | random.seed(args.seed) 63 | np.random.seed(args.seed) 64 | torch.manual_seed(args.seed) 65 | if args.n_gpu > 0: 66 | torch.cuda.manual_seed_all(args.seed) 67 | 68 | def to_list(tensor): 69 | return tensor.detach().cpu().tolist() 70 | 71 | def train(args, train_dataset, model, tokenizer): 72 | """ Train the model """ 73 | if args.local_rank in [-1, 0]: 74 | tb_writer = SummaryWriter() 75 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 76 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 77 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 78 | 79 | if args.max_steps > 0: 80 | t_total = args.max_steps 81 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 82 | else: 83 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 84 | 85 | # Prepare optimizer and schedule (linear warmup and decay) 86 | no_decay = ['bias', 'LayerNorm.weight'] 87 | optimizer_grouped_parameters = [ 88 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 89 | {'params': [p for n, 
p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 90 | ] 91 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 92 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 93 | if args.fp16: 94 | try: 95 | from apex import amp 96 | except ImportError: 97 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 98 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 99 | 100 | # multi-gpu training (should be after apex fp16 initialization) 101 | if args.n_gpu > 1: 102 | model = torch.nn.DataParallel(model) 103 | 104 | # Distributed training (should be after apex fp16 initialization) 105 | if args.local_rank != -1: 106 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 107 | output_device=args.local_rank, 108 | find_unused_parameters=True) 109 | 110 | # Train! 111 | logger.info("***** Running training *****") 112 | logger.info(" Num examples = %d", len(train_dataset)) 113 | logger.info(" Num Epochs = %d", args.num_train_epochs) 114 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 115 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 116 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 117 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 118 | logger.info(" Total optimization steps = %d", t_total) 119 | 120 | global_step = 0 121 | tr_loss, logging_loss = 0.0, 0.0 122 | model.zero_grad() 123 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) 124 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 125 | for _ in train_iterator: 126 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 127 | for step, batch in enumerate(epoch_iterator): 128 | model.train() 129 | batch = tuple(t.to(args.device) for t in batch) 130 | inputs = {'input_ids': batch[0], 131 | 'attention_mask': batch[1], 132 | 'token_type_ids': None if args.model_type == 'xlm' else batch[2], 133 | 'start_positions': batch[3], 134 | 'end_positions': batch[4]} 135 | if args.model_type in ['xlnet', 'xlm']: 136 | inputs.update({'cls_index': batch[5], 137 | 'p_mask': batch[6]}) 138 | outputs = model(**inputs) 139 | loss = outputs[0] # model outputs are always tuples in pytorch-transformers (see doc) 140 | 141 | if args.n_gpu > 1: 142 | loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training 143 | if args.gradient_accumulation_steps > 1: 144 | loss = loss / args.gradient_accumulation_steps 145 | 146 | if args.fp16: 147 | with amp.scale_loss(loss, optimizer) as scaled_loss: 148 | scaled_loss.backward() 149 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 150 | else: 151 | loss.backward() 152 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 153 | 154 | tr_loss += loss.item() 155 | if (step + 1) % args.gradient_accumulation_steps == 0: 156 | optimizer.step() 157 |
scheduler.step() # Update learning rate schedule 158 | model.zero_grad() 159 | global_step += 1 160 | 161 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 162 | # Log metrics 163 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 164 | results = evaluate(args, model, tokenizer) 165 | for key, value in results.items(): 166 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step) 167 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 168 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) 169 | logging_loss = tr_loss 170 | 171 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 172 | # Save model checkpoint 173 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 174 | if not os.path.exists(output_dir): 175 | os.makedirs(output_dir) 176 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 177 | model_to_save.save_pretrained(output_dir) 178 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 179 | logger.info("Saving model checkpoint to %s", output_dir) 180 | 181 | if args.max_steps > 0 and global_step > args.max_steps: 182 | epoch_iterator.close() 183 | break 184 | if args.max_steps > 0 and global_step > args.max_steps: 185 | train_iterator.close() 186 | break 187 | 188 | if args.local_rank in [-1, 0]: 189 | tb_writer.close() 190 | 191 | return global_step, tr_loss / global_step 192 | 193 | 194 | def evaluate(args, model, tokenizer, prefix=""): 195 | dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True) 196 | 197 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 198 | os.makedirs(args.output_dir) 199 | 200 | args.eval_batch_size = 
args.per_gpu_eval_batch_size * max(1, args.n_gpu) 201 | # Note that DistributedSampler samples randomly 202 | eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset) 203 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 204 | 205 | # Eval! 206 | logger.info("***** Running evaluation {} *****".format(prefix)) 207 | logger.info(" Num examples = %d", len(dataset)) 208 | logger.info(" Batch size = %d", args.eval_batch_size) 209 | all_results = [] 210 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 211 | model.eval() 212 | batch = tuple(t.to(args.device) for t in batch) 213 | with torch.no_grad(): 214 | inputs = {'input_ids': batch[0], 215 | 'attention_mask': batch[1], 216 | 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM doesn't use segment_ids 217 | } 218 | example_indices = batch[3] 219 | if args.model_type in ['xlnet', 'xlm']: 220 | inputs.update({'cls_index': batch[4], 221 | 'p_mask': batch[5]}) 222 | outputs = model(**inputs) 223 | 224 | for i, example_index in enumerate(example_indices): 225 | eval_feature = features[example_index.item()] 226 | unique_id = int(eval_feature.unique_id) 227 | if args.model_type in ['xlnet', 'xlm']: 228 | from utils_squad import RawResultExtended # XLNet uses a more complex post-processing procedure; RawResultExtended is not imported at module level, so import it here 229 | result = RawResultExtended(unique_id = unique_id, 230 | start_top_log_probs = to_list(outputs[0][i]), 231 | start_top_index = to_list(outputs[1][i]), 232 | end_top_log_probs = to_list(outputs[2][i]), 233 | end_top_index = to_list(outputs[3][i]), 234 | cls_logits = to_list(outputs[4][i])) 235 | else: 236 | result = RawResult(unique_id = unique_id, 237 | start_logits = to_list(outputs[0][i]), 238 | end_logits = to_list(outputs[1][i])) 239 | all_results.append(result) 240 | 241 | # Compute predictions 242 | output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) 243 | output_nbest_file = os.path.join(args.output_dir,
"nbest_predictions_{}.json".format(prefix)) 244 | if args.version_2_with_negative: 245 | output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix)) 246 | else: 247 | output_null_log_odds_file = None 248 | 249 | if args.model_type in ['xlnet', 'xlm']: 250 | # XLNet uses a more complex post-processing procedure 251 | write_predictions_extended(examples, features, all_results, args.n_best_size, 252 | args.max_answer_length, output_prediction_file, 253 | output_nbest_file, output_null_log_odds_file, args.predict_file, 254 | model.config.start_n_top, model.config.end_n_top, 255 | args.version_2_with_negative, tokenizer, args.verbose_logging) 256 | else: 257 | write_predictions(examples, features, all_results, args.n_best_size, 258 | args.max_answer_length, args.do_lower_case, output_prediction_file, 259 | output_nbest_file, output_null_log_odds_file, args.verbose_logging, 260 | args.version_2_with_negative, args.null_score_diff_threshold) 261 | 262 | # Evaluate with the official SQuAD script 263 | evaluate_options = EVAL_OPTS(data_file=args.predict_file, 264 | pred_file=output_prediction_file, 265 | na_prob_file=output_null_log_odds_file) 266 | results = evaluate_on_squad(evaluate_options) 267 | return results 268 | 269 | 270 | def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False): 271 | if args.local_rank not in [-1, 0] and not evaluate: 272 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 273 | 274 | # Load data features from cache or dataset file 275 | input_file = args.predict_file if evaluate else args.train_file 276 | cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format( 277 | 'dev' if evaluate else 'train', 278 | list(filter(None, args.model_name_or_path.split('/'))).pop(), 279 | str(args.max_seq_length))) 280 | if os.path.exists(cached_features_file) and not 
args.overwrite_cache and not output_examples: 281 | logger.info("Loading features from cached file %s", cached_features_file) 282 | features = torch.load(cached_features_file) 283 | else: 284 | logger.info("Creating features from dataset file at %s", input_file) 285 | examples = read_squad_examples(input_file=input_file, 286 | is_training=not evaluate, 287 | version_2_with_negative=args.version_2_with_negative) 288 | features = convert_examples_to_features(examples=examples, 289 | tokenizer=tokenizer, 290 | max_seq_length=args.max_seq_length, 291 | doc_stride=args.doc_stride, 292 | max_query_length=args.max_query_length, 293 | is_training=not evaluate) 294 | if args.local_rank in [-1, 0]: 295 | logger.info("Saving features into cached file %s", cached_features_file) 296 | torch.save(features, cached_features_file) 297 | 298 | if args.local_rank == 0 and not evaluate: 299 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 300 | 301 | # Convert to Tensors and build dataset 302 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 303 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 304 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 305 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) 306 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) 307 | if evaluate: 308 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 309 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 310 | all_example_index, all_cls_index, all_p_mask) 311 | else: 312 | all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) 313 | all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) 314 | dataset = TensorDataset(all_input_ids, 
all_input_mask, all_segment_ids, 315 | all_start_positions, all_end_positions, 316 | all_cls_index, all_p_mask) 317 | 318 | if output_examples: 319 | return dataset, examples, features 320 | return dataset 321 | 322 | 323 | def main(): 324 | parser = argparse.ArgumentParser() 325 | 326 | ## Required parameters 327 | parser.add_argument("--train_file", default=None, type=str, required=True, 328 | help="SQuAD json for training. E.g., train-v1.1.json") 329 | parser.add_argument("--predict_file", default=None, type=str, required=True, 330 | help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 331 | parser.add_argument("--model_type", default=None, type=str, required=True, 332 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 333 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 334 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) 335 | parser.add_argument("--output_dir", default=None, type=str, required=True, 336 | help="The output directory where the model checkpoints and predictions will be written.") 337 | 338 | ## Other parameters 339 | parser.add_argument("--config_name", default="", type=str, 340 | help="Pretrained config name or path if not the same as model_name") 341 | parser.add_argument("--tokenizer_name", default="", type=str, 342 | help="Pretrained tokenizer name or path if not the same as model_name") 343 | parser.add_argument("--cache_dir", default="", type=str, 344 | help="Where do you want to store the pre-trained models downloaded from s3") 345 | 346 | parser.add_argument('--version_2_with_negative', action='store_true', 347 | help='If true, the SQuAD examples contain some that do not have an answer.') 348 | parser.add_argument('--null_score_diff_threshold', type=float, default=0.0, 349 | help="If null_score - best_non_null is greater than the threshold predict null.") 350 | 351 | 
parser.add_argument("--max_seq_length", default=384, type=int, 352 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 353 | "longer than this will be truncated, and sequences shorter than this will be padded.") 354 | parser.add_argument("--doc_stride", default=128, type=int, 355 | help="When splitting up a long document into chunks, how much stride to take between chunks.") 356 | parser.add_argument("--max_query_length", default=64, type=int, 357 | help="The maximum number of tokens for the question. Questions longer than this will " 358 | "be truncated to this length.") 359 | parser.add_argument("--do_train", action='store_true', 360 | help="Whether to run training.") 361 | parser.add_argument("--do_eval", action='store_true', 362 | help="Whether to run eval on the dev set.") 363 | parser.add_argument("--evaluate_during_training", action='store_true', 364 | help="Run evaluation during training at each logging step.") 365 | parser.add_argument("--do_lower_case", action='store_true', 366 | help="Set this flag if you are using an uncased model.") 367 | 368 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 369 | help="Batch size per GPU/CPU for training.") 370 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 371 | help="Batch size per GPU/CPU for evaluation.") 372 | parser.add_argument("--learning_rate", default=5e-5, type=float, 373 | help="The initial learning rate for Adam.") 374 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 375 | help="Number of steps to accumulate gradients over before performing an optimizer update.") 376 | parser.add_argument("--weight_decay", default=0.0, type=float, 377 | help="Weight decay to apply, if any.") 378 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 379 | help="Epsilon for Adam optimizer.") 380 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 381 | help="Maximum gradient norm for gradient clipping.") 382 |
parser.add_argument("--num_train_epochs", default=3.0, type=float, 383 | help="Total number of training epochs to perform.") 384 | parser.add_argument("--max_steps", default=-1, type=int, 385 | help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.") 386 | parser.add_argument("--warmup_steps", default=0, type=int, 387 | help="Linear warmup over warmup_steps.") 388 | parser.add_argument("--n_best_size", default=20, type=int, 389 | help="The total number of n-best predictions to generate in the nbest_predictions.json output file.") 390 | parser.add_argument("--max_answer_length", default=30, type=int, 391 | help="The maximum length of an answer that can be generated. This is needed because the start " 392 | "and end predictions are not conditioned on one another.") 393 | parser.add_argument("--verbose_logging", action='store_true', 394 | help="If true, all of the warnings related to data processing will be printed. " 395 | "A number of warnings are expected for a normal SQuAD evaluation.") 396 | 397 | parser.add_argument('--logging_steps', type=int, default=50, 398 | help="Log every X update steps.") 399 | parser.add_argument('--save_steps', type=int, default=50, 400 | help="Save checkpoint every X update steps.") 401 | parser.add_argument("--eval_all_checkpoints", action='store_true', 402 | help="Evaluate all checkpoints saved under output_dir (sub-directories whose names end with the step number)") 403 | parser.add_argument("--no_cuda", action='store_true', 404 | help="Do not use CUDA even when it is available") 405 | parser.add_argument('--overwrite_output_dir', action='store_true', 406 | help="Overwrite the content of the output directory") 407 | parser.add_argument('--overwrite_cache', action='store_true', 408 | help="Overwrite the cached training and evaluation sets") 409 | parser.add_argument('--seed', type=int, default=42, 410 | help="Random seed for initialization") 411 | 412 | parser.add_argument("--local_rank", type=int,
default=-1, 413 | help="local_rank for distributed training on GPUs") 414 | parser.add_argument('--fp16', action='store_true', 415 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 416 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 417 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. " 418 | "See details at https://nvidia.github.io/apex/amp.html") 419 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for remote debugging.") 420 | parser.add_argument('--server_port', type=str, default='', help="Can be used for remote debugging.") 421 | args = parser.parse_args() 422 | 423 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: 424 | raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overwrite it.".format(args.output_dir)) 425 | 426 | # Set up remote debugging if needed 427 | if args.server_ip and args.server_port: 428 | # Remote debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 429 | import ptvsd 430 | print("Waiting for debugger attach") 431 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 432 | ptvsd.wait_for_attach() 433 | 434 | # Set up CUDA, GPU & distributed training 435 | if args.local_rank == -1 or args.no_cuda: 436 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 437 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() 438 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 439 | torch.cuda.set_device(args.local_rank) 440 | device = torch.device("cuda", args.local_rank) 441 | torch.distributed.init_process_group(backend='nccl') 442 | args.n_gpu = 1 443 | args.device = device 444 | 445 | # Set up logging 446 | logging.basicConfig(format = '%(asctime)s - %(levelname)s -
%(name)s - %(message)s', 447 | datefmt = '%m/%d/%Y %H:%M:%S', 448 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 449 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bit training: %s", 450 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 451 | 452 | # Set seed 453 | set_seed(args) 454 | 455 | # Load pretrained model and tokenizer 456 | if args.local_rank not in [-1, 0]: 457 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 458 | 459 | args.model_type = args.model_type.lower() 460 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 461 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) 462 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case) 463 | 464 | model = model_class.from_pretrained(args.model_name_or_path, config=config) 465 | if args.local_rank == 0: 466 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 467 | 468 | model.to(args.device) 469 | 470 | logger.info("Training/evaluation parameters %s", args) 471 | 472 | # Training 473 | if args.do_train: 474 | train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) 475 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 476 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 477 | 478 | 479 | # Save the trained model and the tokenizer 480 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 481 | # Create output directory if needed 482 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 483 | os.makedirs(args.output_dir) 484 | 485 | logger.info("Saving model checkpoint to %s",
args.output_dir) 486 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 487 | model_to_save.save_pretrained(args.output_dir) 488 | tokenizer.save_pretrained(args.output_dir) 489 | 490 | # Good practice: save your training arguments together with the trained model 491 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) 492 | 493 | # Load a trained model and vocabulary that you have fine-tuned 494 | model = model_class.from_pretrained(args.output_dir) 495 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 496 | model.to(args.device) 497 | 498 | 499 | # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory 500 | results = {} 501 | if args.do_eval and args.local_rank in [-1, 0]: 502 | checkpoints = [args.output_dir] 503 | if args.eval_all_checkpoints: 504 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 505 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs 506 | 507 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 508 | 509 | for checkpoint in checkpoints: 510 | # Reload the model 511 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" 512 | model = model_class.from_pretrained(checkpoint) 513 | model.to(args.device) 514 | 515 | # Evaluate 516 | result = evaluate(args, model, tokenizer, prefix=global_step) 517 | 518 | result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items()) 519 | results.update(result) 520 | 521 | logger.info("Results: {}".format(results)) 522 | 523 | return results 524 | 525 | 526 | if __name__ == "__main__": 527 | main() -------------------------------------------------------------------------------- /utils_squad.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Load SQuAD dataset. """ 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import json 21 | import logging 22 | import math 23 | import collections 24 | from io import open 25 | 26 | from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize 27 | 28 | # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method) 29 | from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class SquadExample(object): 35 | """ 36 | A single training/test example for the Squad dataset. 37 | For examples without an answer, the start and end position are -1. 
38 | """ 39 | 40 | def __init__(self, 41 | qas_id, 42 | question_text, 43 | doc_tokens, 44 | orig_answer_text=None, 45 | start_position=None, 46 | end_position=None, 47 | is_impossible=None): 48 | self.qas_id = qas_id 49 | self.question_text = question_text 50 | self.doc_tokens = doc_tokens 51 | self.orig_answer_text = orig_answer_text 52 | self.start_position = start_position 53 | self.end_position = end_position 54 | self.is_impossible = is_impossible 55 | 56 | def __str__(self): 57 | return self.__repr__() 58 | 59 | def __repr__(self): 60 | s = "" 61 | s += "qas_id: %s" % (self.qas_id) 62 | s += ", question_text: %s" % (self.question_text) 63 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 64 | if self.start_position: 65 | s += ", start_position: %d" % (self.start_position) 66 | if self.end_position: 67 | s += ", end_position: %d" % (self.end_position) 68 | if self.is_impossible: 69 | s += ", is_impossible: %r" % (self.is_impossible) 70 | return s 71 | 72 | 73 | class InputFeatures(object): 74 | """A single set of features of data.""" 75 | 76 | def __init__(self, 77 | unique_id, 78 | example_index, 79 | doc_span_index, 80 | tokens, 81 | token_to_orig_map, 82 | token_is_max_context, 83 | input_ids, 84 | input_mask, 85 | segment_ids, 86 | cls_index, 87 | p_mask, 88 | paragraph_len, 89 | start_position=None, 90 | end_position=None, 91 | is_impossible=None): 92 | self.unique_id = unique_id 93 | self.example_index = example_index 94 | self.doc_span_index = doc_span_index 95 | self.tokens = tokens 96 | self.token_to_orig_map = token_to_orig_map 97 | self.token_is_max_context = token_is_max_context 98 | self.input_ids = input_ids 99 | self.input_mask = input_mask 100 | self.segment_ids = segment_ids 101 | self.cls_index = cls_index 102 | self.p_mask = p_mask 103 | self.paragraph_len = paragraph_len 104 | self.start_position = start_position 105 | self.end_position = end_position 106 | self.is_impossible = is_impossible 107 | 108 | 109 | def 
read_squad_examples(input_file, is_training, version_2_with_negative): 110 | """Read a SQuAD json file into a list of SquadExample.""" 111 | with open(input_file, "r", encoding='utf-8') as reader: 112 | input_data = json.load(reader)["data"] 113 | 114 | def is_whitespace(c): 115 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 116 | return True 117 | return False 118 | 119 | def is_english_or_number(c): 120 | return (ord(c) > 64 and ord(c) < 91) or (ord(c) < 123 and ord(c) > 96) 121 | 122 | examples = [] 123 | for entry in input_data: 124 | for paragraph in entry["paragraphs"]: 125 | paragraph_text = paragraph["context"] 126 | doc_tokens = [] 127 | char_to_word_offset = [] 128 | prev_is_whitespace = True 129 | for c in paragraph_text: 130 | if not is_whitespace(c): # whitespace is dropped; every other character becomes one token 131 | doc_tokens.append(c) 132 | # Record an offset for EVERY character (whitespace maps to the preceding token) so answer_start indexes stay aligned 133 | char_to_word_offset.append(len(doc_tokens) - 1) 134 | 135 | for qa in paragraph["qas"]: 136 | qas_id = qa["id"] 137 | question_text = qa["question"] 138 | start_position = None 139 | end_position = None 140 | orig_answer_text = None 141 | is_impossible = False 142 | if is_training: 143 | if (len(qa["answers"]) != 1) and (not is_impossible): 144 | raise ValueError( 145 | "For training, each question should have exactly 1 answer." 146 | ) 147 | answer = qa["answers"][0] 148 | orig_answer_text = answer["text"] 149 | answer_offset = answer["answer_start"] 150 | answer_length = len(orig_answer_text) 151 | if answer_offset > len(char_to_word_offset) - 1: 152 | logger.warning("Invalid example: answer offset '%s' vs. context length '%s'", 153 | answer_offset, len(char_to_word_offset)) 154 | continue 155 | start_position = char_to_word_offset[answer_offset] 156 | end_position = answer_offset + answer_length - 1 157 | if end_position > len(char_to_word_offset) - 1: 158 | logger.warning("Invalid example: answer end '%s' vs. context length
'%s'", end_position, len(char_to_word_offset)) 159 | continue 160 | end_position = char_to_word_offset[answer_offset + 161 | answer_length - 1] 162 | # Only add answers where the text can be exactly recovered from the 163 | # document. If this CAN'T happen it's likely due to weird Unicode 164 | # stuff so we will just skip the example. 165 | # 166 | # Note that this means for training mode, every example is NOT 167 | # guaranteed to be preserved. 168 | actual_text = "".join( 169 | doc_tokens[start_position:(end_position + 1)]) 170 | cleaned_answer_text = "".join( 171 | whitespace_tokenize(orig_answer_text)) 172 | if actual_text.find(cleaned_answer_text) == -1: 173 | logger.warning("样本错误: '%s' vs. '%s'", actual_text, 174 | cleaned_answer_text) 175 | continue 176 | 177 | example = SquadExample( 178 | qas_id=qas_id, 179 | question_text=question_text, 180 | doc_tokens=doc_tokens, 181 | orig_answer_text=orig_answer_text, 182 | start_position=start_position, 183 | end_position=end_position, 184 | is_impossible=is_impossible) 185 | examples.append(example) 186 | return examples 187 | 188 | 189 | def convert_examples_to_features(examples, 190 | tokenizer, 191 | max_seq_length, 192 | doc_stride, 193 | max_query_length, 194 | is_training, 195 | cls_token_at_end=False, 196 | cls_token='[CLS]', 197 | sep_token='[SEP]', 198 | pad_token=0, 199 | sequence_a_segment_id=0, 200 | sequence_b_segment_id=1, 201 | cls_token_segment_id=0, 202 | pad_token_segment_id=0, 203 | mask_padding_with_zero=True): 204 | """Loads a data file into a list of `InputBatch`s.""" 205 | 206 | unique_id = 1000000000 207 | # cnt_pos, cnt_neg = 0, 0 208 | # max_N, max_M = 1024, 1024 209 | # f = np.zeros((max_N, max_M), dtype=np.float32) 210 | 211 | features = [] 212 | for (example_index, example) in enumerate(examples): 213 | 214 | # if example_index % 100 == 0: 215 | # logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg) 216 | 217 | query_tokens = 
tokenizer.tokenize(example.question_text) 218 | 219 | if len(query_tokens) > max_query_length: 220 | query_tokens = query_tokens[0:max_query_length] 221 | 222 | tok_to_orig_index = [] 223 | orig_to_tok_index = [] 224 | all_doc_tokens = [] 225 | for (i, token) in enumerate(example.doc_tokens): 226 | orig_to_tok_index.append(len(all_doc_tokens)) 227 | sub_tokens = tokenizer.tokenize(token) 228 | for sub_token in sub_tokens: 229 | tok_to_orig_index.append(i) 230 | all_doc_tokens.append(sub_token) 231 | 232 | tok_start_position = None 233 | tok_end_position = None 234 | if is_training and example.is_impossible: 235 | tok_start_position = -1 236 | tok_end_position = -1 237 | if is_training and not example.is_impossible: 238 | tok_start_position = orig_to_tok_index[example.start_position] 239 | if example.end_position < len(example.doc_tokens) - 1: 240 | tok_end_position = orig_to_tok_index[example.end_position + 241 | 1] - 1 242 | else: 243 | tok_end_position = len(all_doc_tokens) - 1 244 | (tok_start_position, tok_end_position) = _improve_answer_span( 245 | all_doc_tokens, tok_start_position, tok_end_position, 246 | tokenizer, example.orig_answer_text) 247 | 248 | # The -3 accounts for [CLS], [SEP] and [SEP] 249 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 250 | 251 | # We can have documents that are longer than the maximum sequence length. 252 | # To deal with this we do a sliding window approach, where we take chunks 253 | # of the up to our max length with a stride of `doc_stride`. 
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []

            # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
            # The original TF implementation also keeps the classification token (set to 0) (not sure why...)
            p_mask = []

            # CLS token at the beginning
            if not cls_token_at_end:
                tokens.append(cls_token)
                segment_ids.append(cls_token_segment_id)
                p_mask.append(0)
                cls_index = 0

            # Query
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(sequence_a_segment_id)
                p_mask.append(1)

            # SEP token
            tokens.append(sep_token)
            segment_ids.append(sequence_a_segment_id)
            p_mask.append(1)

            # Paragraph
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(
                    doc_spans, doc_span_index, split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(sequence_b_segment_id)
                p_mask.append(0)
            paragraph_len = doc_span.length

            # SEP token
            tokens.append(sep_token)
            segment_ids.append(sequence_b_segment_id)
            p_mask.append(1)

            # CLS token at the end
            if cls_token_at_end:
                tokens.append(cls_token)
                segment_ids.append(cls_token_segment_id)
                p_mask.append(0)
                cls_index = len(tokens) - 1  # Index of classification token

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(pad_token)
                input_mask.append(0 if mask_padding_with_zero else 1)
                segment_ids.append(pad_token_segment_id)
                p_mask.append(1)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            span_is_impossible = example.is_impossible
            start_position = None
            end_position = None
            if is_training and not span_is_impossible:
                # For training, if our document chunk does not contain the annotated
                # answer, we keep the chunk but mark it impossible and point the
                # target at the CLS token, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start
                        and tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                    span_is_impossible = True
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and span_is_impossible:
                start_position = cls_index
                end_position = cls_index

            if example_index < 20:
                logger.info("*** Example ***")
                logger.info("unique_id: %s" % (unique_id))
                logger.info("example_index: %s" % (example_index))
                logger.info("doc_span_index: %s" % (doc_span_index))
                logger.info("tokens: %s" % " ".join(tokens))
                logger.info("token_to_orig_map: %s" % " ".join(
                    ["%d:%d" % (x, y)
                     for (x, y) in token_to_orig_map.items()]))
                logger.info("token_is_max_context: %s" % " ".join([
                    "%d:%s" % (x, y)
                    for (x, y) in token_is_max_context.items()
                ]))
                logger.info(
                    "input_ids: %s" % " ".join([str(x) for x in input_ids]))
                logger.info(
                    "input_mask: %s" % " ".join([str(x) for x in input_mask]))
                logger.info(
                    "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
                if is_training and span_is_impossible:
                    logger.info("impossible example")
                if is_training and not span_is_impossible:
                    answer_text = " ".join(
                        tokens[start_position:(end_position + 1)])
                    logger.info("start_position: %d" % (start_position))
                    logger.info("end_position: %d" % (end_position))
                    logger.info("answer: %s" % (answer_text))

            features.append(
                InputFeatures(
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    cls_index=cls_index,
                    p_mask=p_mask,
                    paragraph_len=paragraph_len,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=span_is_impossible))
            unique_id += 1

    return features


def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                         orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""

    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    #   Question: What year was John Smith born?
    #   Context: The leader was John Smith (1895-1943).
    #   Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    #   Question: What country is the top exporter of electronics?
    #   Context: The Japanese electronics industry is the largest in the world.
    #   Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)


def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple documents. E.g.
    #  Doc: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context,
                    num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


RawResult = collections.namedtuple("RawResult",
                                   ["unique_id", "start_logits", "end_logits"])


def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length, do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file,
                      verbose_logging, version_2_with_negative,
                      null_score_diff_threshold):
    """Write final predictions to the json file and log-odds of null if needed."""
    logger.info("Writing predictions to: %s" % (output_prediction_file))
    logger.info("Writing nbest to: %s" % (output_nbest_file))

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction", [
            "feature_index", "start_index", "end_index", "start_logit",
            "end_logit"
        ])

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if version_2_with_negative:
                feature_null_score = result.start_logits[0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(
                            start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit))
        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"])

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index:(
                    pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start:(
                    orig_doc_end + 1)]
                tok_text = " ".join(tok_tokens)

                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(" ##", "")
                tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                final_text = get_final_text(tok_text, orig_text, do_lower_case,
                                            verbose_logging)
                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit))
        # if we didn't include the empty option in the n-best, include it
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text="",
                        start_logit=null_start_logit,
                        end_logit=null_end_logit))

            # In very rare edge cases we could only have single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(
                    0,
                    _NbestPrediction(
                        text="empty", start_logit=0.0, end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = score_null - best_non_null_entry.start_logit - (
                best_non_null_entry.end_logit)
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
        all_nbest_json[example.qas_id] = nbest_json

    with open(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

    with open(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

    if version_2_with_negative:
        with open(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    return all_predictions


def write_predictions_extended(
        all_examples, all_features, all_results, n_best_size,
        max_answer_length, output_prediction_file, output_nbest_file,
        output_null_log_odds_file, orig_data_file, start_n_top, end_n_top,
        version_2_with_negative, tokenizer, verbose_logging):
    """ XLNet write prediction logic (more complex than Bert's).
        Write final predictions to the json file and log-odds of null if needed.
        Requires utils_squad_evaluate.py
    """
    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction", [
            "feature_index", "start_index", "end_index", "start_log_prob",
            "end_log_prob"
        ])

    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_log_prob", "end_log_prob"])

    logger.info("Writing predictions to: %s", output_prediction_file)
    # logger.info("Writing nbest to: %s" % (output_nbest_file))

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive

        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]

            cur_null_score = result.cls_logits

            # if we could have irrelevant answers, get the min score of irrelevant
            score_null = min(score_null, cur_null_score)

            for i in range(start_n_top):
                for j in range(end_n_top):
                    start_log_prob = result.start_top_log_probs[i]
                    start_index = result.start_top_index[i]

                    j_index = i * end_n_top + j

                    end_log_prob = result.end_top_log_probs[j_index]
                    end_index = result.end_top_index[j_index]

                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= feature.paragraph_len - 1:
                        continue
                    if end_index >= feature.paragraph_len - 1:
                        continue

                    if not feature.token_is_max_context.get(
                            start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue

                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_log_prob=start_log_prob,
                            end_log_prob=end_log_prob))

        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_log_prob + x.end_log_prob),
            reverse=True)

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]

            # XLNet un-tokenizer
            # Let's keep it simple for now and see if we need all this later.
            #
            # tok_start_to_orig_index = feature.tok_start_to_orig_index
            # tok_end_to_orig_index = feature.tok_end_to_orig_index
            # start_orig_pos = tok_start_to_orig_index[pred.start_index]
            # end_orig_pos = tok_end_to_orig_index[pred.end_index]
            # paragraph_text = example.paragraph_text
            # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()

            # Previously used Bert untokenizer
            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
            orig_doc_start = feature.token_to_orig_map[pred.start_index]
            orig_doc_end = feature.token_to_orig_map[pred.end_index]
            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
            tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

            # Clean whitespace
            tok_text = tok_text.strip()
            tok_text = " ".join(tok_text.split())
            orig_text = " ".join(orig_tokens)

            final_text = get_final_text(
                tok_text, orig_text, tokenizer.do_lower_case, verbose_logging)

            if final_text in seen_predictions:
                continue

            seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_log_prob=pred.start_log_prob,
                    end_log_prob=pred.end_log_prob))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(
                    text="", start_log_prob=-1e6, end_log_prob=-1e6))

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_log_prob + entry.end_log_prob)
            if not best_non_null_entry:
                best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_log_prob"] = entry.start_log_prob
            output["end_log_prob"] = entry.end_log_prob
            nbest_json.append(output)

        assert len(nbest_json) >= 1
        assert best_non_null_entry is not None

        score_diff = score_null
        scores_diff_json[example.qas_id] = score_diff
        # note(zhiliny): always predict best_non_null_entry
        # and the evaluation script will search for the best threshold
        all_predictions[example.qas_id] = best_non_null_entry.text

        all_nbest_json[example.qas_id] = nbest_json

    with open(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

    with open(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

    if version_2_with_negative:
        with open(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    with open(orig_data_file, "r", encoding='utf-8') as reader:
        orig_data = json.load(reader)["data"]

    qid_to_has_ans = make_qid_to_has_ans(orig_data)
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
    exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
    out_eval = {}

    find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw,
                            scores_diff_json, qid_to_has_ans)

    return out_eval


def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic between
    # `pred_text` and `orig_text` to get a character-to-character alignment. This
    # can fail in certain cases in which case we just return `orig_text`.

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = BasicTokenizer(do_lower_case=do_lower_case)

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose_logging:
            logger.info(
                "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if verbose_logging:
            logger.info(
                "Length not equal after stripping spaces: '%s' vs '%s'",
                orig_ns_text, tok_ns_text)
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in tok_ns_to_s_map.items():
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        if verbose_logging:
            logger.info("Couldn't map start position")
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        if verbose_logging:
            logger.info("Couldn't map end position")
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text


def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(
        enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes


def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs
--------------------------------------------------------------------------------
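In `read_squad_examples`, every non-whitespace character of the context becomes its own token (which suits Chinese data such as CMRC 2018), and `char_to_word_offset` records which token each character offset falls into, so a character-based `answer_start` can be turned into token indices. The sketch below isolates that mapping. `char_tokenize` and `answer_token_span` are hypothetical helper names, and mapping each whitespace character to the preceding token is an assumption of this sketch (it mirrors the upstream SQuAD reader) that keeps offsets aligned when a context contains spaces.

```python
WHITESPACE = {" ", "\t", "\r", "\n", "\u202f"}


def char_tokenize(paragraph_text):
    """One token per non-whitespace character, plus a char -> token map."""
    doc_tokens = []
    char_to_word_offset = []
    for c in paragraph_text:
        if c in WHITESPACE:
            # Assumption: whitespace produces no token but still occupies a
            # character offset, so it maps to the preceding token.
            char_to_word_offset.append(max(len(doc_tokens) - 1, 0))
            continue
        doc_tokens.append(c)
        char_to_word_offset.append(len(doc_tokens) - 1)
    return doc_tokens, char_to_word_offset


def answer_token_span(paragraph_text, answer_start, answer_text):
    """Map a character-offset answer to inclusive (start, end) token indices."""
    doc_tokens, offsets = char_tokenize(paragraph_text)
    answer_end = answer_start + len(answer_text) - 1
    if answer_start >= len(offsets) or answer_end >= len(offsets):
        return None  # the answer runs past the end of the context
    start, end = offsets[answer_start], offsets[answer_end]
    # Same sanity check as the reader: the recovered tokens must contain
    # the whitespace-stripped answer, otherwise the sample is skipped.
    if "".join(answer_text.split()) not in "".join(doc_tokens[start:end + 1]):
        return None
    return start, end
```

For the context "北京 是 中国的首都", the answer "中国" starts at character 5 but at token 3, because the two spaces are not tokens.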
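`convert_examples_to_features` handles contexts longer than `max_seq_length` with a sliding window: each window holds at most `max_tokens_for_doc = max_seq_length - len(query_tokens) - 3` sub-tokens, and the next window starts `doc_stride` tokens later. The loop can be pulled out into a small function to check the windows it produces; `make_doc_spans` is a hypothetical name, and the loop body follows the one in the file.

```python
import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])


def make_doc_spans(num_tokens, max_tokens_for_doc, doc_stride):
    """Cut num_tokens sub-tokens into overlapping windows."""
    doc_spans = []
    start_offset = 0
    while start_offset < num_tokens:
        length = min(num_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(DocSpan(start=start_offset, length=length))
        if start_offset + length == num_tokens:
            break  # this window already reaches the end of the document
        start_offset += min(length, doc_stride)
    return doc_spans
```

With `max_seq_length=384` and `doc_stride=128` (the values in start.sh) and a 20-token question, a 1000-token context yields six windows starting at 0, 128, ..., 640; adjacent windows overlap, which is why `_check_is_max_context` is needed later.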
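For each window, `convert_examples_to_features` lays out `[CLS] query [SEP] window [SEP]`, fills `segment_ids` and `p_mask`, pads to `max_seq_length`, and shifts the answer's sub-token positions by `len(query_tokens) + 2`; if the answer is not fully inside the window, the target points at `[CLS]`. The sketch below reproduces that layout for the default case (`cls_token_at_end=False`, BERT segment ids). `build_window_features` is a hypothetical helper, and it pads with the string `"[PAD]"` instead of a token id purely for readability.

```python
def build_window_features(query_tokens, all_doc_tokens, span_start,
                          span_length, tok_start, tok_end, max_seq_length):
    """Lay out [CLS] query [SEP] window [SEP] and locate the answer in it."""
    window = all_doc_tokens[span_start:span_start + span_length]
    tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + window + ["[SEP]"]
    # Segment 0: CLS, query, first SEP. Segment 1: window and final SEP.
    segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(window) + 1)
    # p_mask is 1 where a token can NOT be part of the answer.
    p_mask = [0] + [1] * (len(query_tokens) + 1) + [0] * len(window) + [1]
    input_mask = [1] * len(tokens)

    padding = max_seq_length - len(tokens)
    tokens += ["[PAD]"] * padding
    segment_ids += [0] * padding
    p_mask += [1] * padding
    input_mask += [0] * padding

    span_end = span_start + len(window) - 1
    if span_start <= tok_start and tok_end <= span_end:
        doc_offset = len(query_tokens) + 2  # skip [CLS] + query + [SEP]
        start = tok_start - span_start + doc_offset
        end = tok_end - span_start + doc_offset
    else:
        start = end = 0  # answer not in this window: point at [CLS]
    return tokens, segment_ids, input_mask, p_mask, start, end
```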
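`_improve_answer_span` searches every sub-span of the initial token span, longest first from each start, for one whose tokens join to exactly the tokenized answer. To run the same search without the BERT tokenizer, the sketch below swaps in a hypothetical `ToyTokenizer` that only splits off punctuation; it is a stand-in, not WordPiece.

```python
import re


class ToyTokenizer:
    """Hypothetical stand-in for WordPiece: words and single punctuation marks."""

    def tokenize(self, text):
        return re.findall(r"\w+|[^\w\s]", text)


def improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                        orig_answer_text):
    """Same search loop as _improve_answer_span in the file."""
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            if " ".join(doc_tokens[new_start:new_end + 1]) == tok_answer_text:
                return new_start, new_end
    # No tighter match (e.g. "Japan" inside "Japanese"): keep the input span.
    return input_start, input_end
```

On the docstring's example, the whitespace word "(1895-1943)." covers sub-tokens 5 to 10, and the search narrows the answer "1895" to sub-token 6.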
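`_check_is_max_context` scores each window containing a token by `min(left context, right context) + 0.01 * window length` and keeps the highest. The sketch returns the winning window index rather than a boolean, so the "bought" example from the comment can be checked directly; `max_context_span` is a hypothetical name.

```python
import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])


def max_context_span(doc_spans, position):
    """Index of the window giving `position` the most balanced context."""
    best_score, best_index = None, None
    for index, span in enumerate(doc_spans):
        end = span.start + span.length - 1
        if not span.start <= position <= end:
            continue
        left = position - span.start
        right = end - position
        # The small length bonus breaks ties in favour of longer windows.
        score = min(left, right) + 0.01 * span.length
        if best_score is None or score > best_score:
            best_score, best_index = score, index
    return best_index
```

For "the man went to the store and bought a gallon of milk" with spans A (0, 5), B (3, 5), and C (6, 5), "bought" is token 7: span B gives it 4 left and 0 right context, span C gives 1 left and 3 right, so C (index 2) wins.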
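For SQuAD 2.0-style data, `write_predictions` keeps the top `n_best_size` start and end indexes, drops invalid pairs, and then predicts "" when `score_null - best_non_null_score` exceeds `null_score_diff_threshold`, where the null score is the start plus end logit at position 0 (`[CLS]`). The sketch condenses this to a single window; `choose_answer` and `valid_positions` are hypothetical, with `valid_positions` standing in for the `token_to_orig_map` and `token_is_max_context` checks. In the file, the null score is also the minimum over all windows of the example.

```python
def top_indexes(logits, n_best_size):
    """Indexes of the n_best_size largest logits, best first."""
    return sorted(range(len(logits)), key=lambda i: logits[i],
                  reverse=True)[:n_best_size]


def choose_answer(start_logits, end_logits, valid_positions,
                  null_score_diff_threshold=0.0, n_best_size=20,
                  max_answer_length=30):
    """Return (start, end) of the best span, or None for "no answer"."""
    best = None
    for s in top_indexes(start_logits, n_best_size):
        for e in top_indexes(end_logits, n_best_size):
            # Filters as in write_predictions: both ends inside the context,
            # end not before start, span not longer than max_answer_length.
            if s not in valid_positions or e not in valid_positions:
                continue
            if e < s or e - s + 1 > max_answer_length:
                continue
            score = start_logits[s] + end_logits[e]
            if best is None or score > best[0]:
                best = (score, s, e)
    if best is None:
        return None
    score_null = start_logits[0] + end_logits[0]  # both ends on [CLS]
    if score_null - best[0] > null_score_diff_threshold:
        return None
    return best[1], best[2]
```

Raising the threshold makes the model less willing to answer "no answer"; the repository's XLNet path instead always predicts the best span and lets the evaluation script search for the threshold.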
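`get_final_text` re-tokenizes `orig_text`, finds `pred_text` in it, and uses space-free character maps to copy the matching slice out of the original casing and punctuation. The sketch keeps that alignment logic but replaces `BasicTokenizer` with a hypothetical `basic_tokenize` (lower-casing plus punctuation splitting), so it only approximates the real normalization.

```python
import collections
import re


def strip_spaces(text):
    """Drop spaces; map each non-space index back to its index in `text`."""
    ns_chars, ns_to_s = [], collections.OrderedDict()
    for i, c in enumerate(text):
        if c == " ":
            continue
        ns_to_s[len(ns_chars)] = i
        ns_chars.append(c)
    return "".join(ns_chars), ns_to_s


def basic_tokenize(text, do_lower_case=True):
    """Hypothetical stand-in for BasicTokenizer."""
    if do_lower_case:
        text = text.lower()
    return re.findall(r"\w+|[^\w\s]", text)


def project_prediction(pred_text, orig_text, do_lower_case=True):
    tok_text = " ".join(basic_tokenize(orig_text, do_lower_case))
    start = tok_text.find(pred_text)
    if start == -1:
        return orig_text
    end = start + len(pred_text) - 1
    orig_ns, orig_ns_to_s = strip_spaces(orig_text)
    tok_ns, tok_ns_to_s = strip_spaces(tok_text)
    if len(orig_ns) != len(tok_ns):
        return orig_text  # heuristic failed: character counts differ
    tok_s_to_ns = {s: ns for ns, s in tok_ns_to_s.items()}
    if start not in tok_s_to_ns or end not in tok_s_to_ns:
        return orig_text
    orig_start = orig_ns_to_s[tok_s_to_ns[start]]
    orig_end = orig_ns_to_s[tok_s_to_ns[end]]
    return orig_text[orig_start:orig_end + 1]
```

The comment's example, `pred_text = "steve smith"` against `orig_text = "Steve Smith's"`, yields "Steve Smith". Because this repository joins Chinese characters with spaces before calling `get_final_text`, the returned span keeps those spaces (e.g. "中 国").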
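`_compute_softmax` subtracts the largest score before exponentiating, which leaves the probabilities unchanged but keeps `math.exp` from overflowing on large logits; `_get_best_indexes` is a sort by logit followed by a cut at `n_best_size`. A compact equivalent of both, with hypothetical names:

```python
import math


def compute_softmax(scores):
    """Softmax with max-subtraction so math.exp cannot overflow."""
    if not scores:
        return []
    max_score = max(scores)
    exp_scores = [math.exp(s - max_score) for s in scores]
    total = sum(exp_scores)
    return [x / total for x in exp_scores]


def get_best_indexes(logits, n_best_size):
    """Indexes of the n_best_size largest logits, highest first."""
    ranked = sorted(enumerate(logits), key=lambda pair: pair[1], reverse=True)
    return [index for index, _ in ranked[:n_best_size]]
```

A naive `math.exp(1000.0)` raises `OverflowError`, while `compute_softmax([1000.0, 1000.0])` returns `[0.5, 0.5]`.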
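In `write_predictions_extended`, the end candidates are read through `j_index = i * end_n_top + j`, which implies a flat, row-major layout: the `j`-th best end for the `i`-th best start sits at position `i * end_n_top + j`. That layout is inferred from the indexing alone (the arrays come from the XLNet model output, which is not in this file), and `iter_top_pairs` is a hypothetical helper.

```python
def iter_top_pairs(start_top_index, start_top_log_probs,
                   end_top_index, end_top_log_probs, start_n_top, end_n_top):
    """Yield (start, end, log-prob score) for every start/end candidate pair."""
    for i in range(start_n_top):
        for j in range(end_n_top):
            k = i * end_n_top + j  # flat position of the j-th end for start i
            yield (start_top_index[i], end_top_index[k],
                   start_top_log_probs[i] + end_top_log_probs[k])
```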