├── docs ├── xlnet.coqa.png ├── xlnet.quac.png ├── squad.example.png ├── xlnet.squad.v1.png ├── xlnet.squad.v2.png ├── _config.yml └── index.md ├── .gitmodules ├── tool ├── convert_squad.py ├── convert_quac.py ├── convert_coqa.py ├── eval_quac.py ├── eval_coqa.py └── eval_squad.py ├── .gitignore ├── codalab ├── run_coqa.codalab.sh └── run_quac.codalab.sh ├── run_squad.sh ├── run_coqa.sh ├── run_quac.sh ├── README.md ├── LICENSE └── run_squad.py /docs/xlnet.coqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/xlnet.coqa.png -------------------------------------------------------------------------------- /docs/xlnet.quac.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/xlnet.quac.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "xlnet"] 2 | path = xlnet 3 | url = https://github.com/zihangdai/xlnet.git 4 | -------------------------------------------------------------------------------- /docs/squad.example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/squad.example.png -------------------------------------------------------------------------------- /docs/xlnet.squad.v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/xlnet.squad.v1.png -------------------------------------------------------------------------------- /docs/xlnet.squad.v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/xlnet.squad.v2.png -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | title: Machine Reading Comprehension (MRC) 3 | description: This is a Machine Reading Comprehension (MRC) experiment project, which includes implementations and experiments for various MRC tasks (e.g. SQuAD, CoQA, QuAC, etc.) 
-------------------------------------------------------------------------------- /tool/convert_squad.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | def add_arguments(parser): 5 | parser.add_argument("--input_file", help="path to input file", required=True) 6 | parser.add_argument("--span_file", help="path to answer span file", required=True) 7 | parser.add_argument("--prob_file", help="path to no-answer probability file", required=True) 8 | parser.add_argument("--prob_thres", help="threshold of no-answer probability", required=False, default=1.0, type=float) 9 | 10 | def convert_squad(input_file, 11 | span_file, 12 | prob_file, 13 | prob_thres): 14 | with open(input_file, "r") as file: 15 | input_data = json.load(file) 16 | 17 | span_dict = {} 18 | prob_dict = {} 19 | for data in input_data: 20 | qas_id = data["qas_id"] 21 | predict_text = data["predict_text"] 22 | answer_prob = data["answer_prob"] if "answer_prob" in data else 0.0 23 | 24 | span_dict[qas_id] = predict_text if answer_prob < prob_thres else "" 25 | prob_dict[qas_id] = answer_prob 26 | 27 | with open(span_file, "w") as file: 28 | json.dump(span_dict, file, indent=4) 29 | 30 | with open(prob_file, "w") as file: 31 | json.dump(prob_dict, file, indent=4) 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | add_arguments(parser) 36 | args = parser.parse_args() 37 | convert_squad(args.input_file, args.span_file, args.prob_file, args.prob_thres) 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /tool/convert_quac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import numpy as np 5 | 6 | def add_arguments(parser): 7 | parser.add_argument("--input_file", help="path to input file", required=True) 8 | parser.add_argument("--output_file", help="path to output file", required=True) 9 | parser.add_argument("--answer_threshold", help="threshold of answer", required=False, default=0.1, type=float) 10 | 11 | def convert_quac(input_file, 12 | output_file, 13 | answer_threshold): 14 | with open(input_file, "r") as file: 15 | input_data = json.load(file) 16 | 17 | data_lookup = {} 18 | for data in input_data: 19 | qas_id = data["qas_id"] 20 | id_items = qas_id.split('#') 21 | id = id_items[0] 22 | turn_id = int(id_items[1]) 23 | 24 | no_answer = data["no_answer_score"] 25 | 26 | yes_no_list = ["y", "x", "n"] 27 | yes_no = yes_no_list[data["yes_no_id"]] 28 | 29 | follow_up_list = ["y", "m", "n"] 30 | follow_up = follow_up_list[data["follow_up_id"]] 31 | 32 | if no_answer >= answer_threshold: 33 | answer_text = "CANNOTANSWER" 34 | else: 35 | answer_text = data["predict_text"] 36 | 37 | if id not in data_lookup: 38 | data_lookup[id] = [] 39 | 40 | data_lookup[id].append({ 41 | "qas_id": qas_id, 42 | "turn_id": turn_id, 43 | "answer_text": answer_text, 44 | "yes_no": yes_no, 45 | "follow_up": follow_up 46 | }) 47 | 48 | with open(output_file, "w") as file: 49 | for id in data_lookup.keys(): 50 | data_list = sorted(data_lookup[id], key=lambda x: x["turn_id"]) 51 | 52 | output_data = json.dumps({ 53 | "best_span_str": [data["answer_text"] for data in data_list], 54 | "qid": [data["qas_id"] for data in data_list], 55 | "yesno": [data["yes_no"] for data in data_list], 56 | "followup": [data["follow_up"] for data in data_list] 57 | }) 58 | 59 | file.write("{0}\n".format(output_data)) 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | add_arguments(parser) 64 | args = parser.parse_args() 65 | convert_quac(args.input_file, args.output_file, args.answer_threshold) 66 | -------------------------------------------------------------------------------- /tool/convert_coqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | 
import numpy as np 5 | 6 | from eval_coqa import CoQAEvaluator 7 | 8 | def add_arguments(parser): 9 | parser.add_argument("--input_file", help="path to input file", required=True) 10 | parser.add_argument("--output_file", help="path to output file", required=True) 11 | parser.add_argument("--answer_threshold", help="threshold of answer", required=False, default=0.1, type=float) 12 | 13 | def convert_coqa(input_file, 14 | output_file, 15 | answer_threshold): 16 | with open(input_file, "r") as file: 17 | input_data = json.load(file) 18 | 19 | output_data = [] 20 | for data in input_data: 21 | id_items = data["qas_id"].split('_') 22 | id = id_items[0] 23 | turn_id = int(id_items[1]) 24 | 25 | score_list = [data["unk_score"], data["yes_score"], data["no_score"], data["num_score"], data["opt_score"]] 26 | answer_list = ["unknown", "yes", "no", "number", "option"] 27 | 28 | score_idx = np.argmax(score_list) 29 | if score_list[score_idx] >= answer_threshold: 30 | answer = answer_list[score_idx] 31 | if answer == "number": 32 | answer_list = ["none", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"] 33 | answer = answer_list[data["num_id"]-1] 34 | elif answer == "option": 35 | answer = data["predict_text"] 36 | norm_question_tokens = CoQAEvaluator.normalize_answer(data["question_text"]).split(" ") 37 | if "or" in norm_question_tokens: 38 | index = norm_question_tokens.index("or") 39 | if index-1 >= 0 and index+1 < len(norm_question_tokens): 40 | answer_list = [norm_question_tokens[index-1], norm_question_tokens[index+1]] 41 | answer = answer_list[data["opt_id"]-1] 42 | else: 43 | answer = data["predict_text"] 44 | 45 | output_data.append({ 46 | "id": id, 47 | "turn_id": turn_id, 48 | "answer": answer 49 | }) 50 | 51 | with open(output_file, "w") as file: 52 | json.dump(output_data, file, indent=4) 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | add_arguments(parser) 57 | args = parser.parse_args() 58 | convert_coqa(args.input_file, args.output_file, args.answer_threshold) 59 | -------------------------------------------------------------------------------- /codalab/run_coqa.codalab.sh: -------------------------------------------------------------------------------- 1 | start_time=`date +%s` 2 | 3 | alias python=python3 4 | 5 | git clone --recurse-submodules https://github.com/stevezheng23/mrc_tf.git 6 | 7 | cd mrc_tf 8 | 9 | mkdir model 10 | mkdir model/xlnet 11 | wget -P model/xlnet https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip 12 | unzip model/xlnet/cased_L-24_H-1024_A-16.zip -d model/xlnet/ 13 | mv model/xlnet/xlnet_cased_L-24_H-1024_A-16 model/xlnet/cased_L-24_H-1024_A-16 14 | rm model/xlnet/cased_L-24_H-1024_A-16.zip 15 | 16 | mkdir data 17 | mkdir data/coqa 18 | cp ../coqa-dev-v1.0.json data/coqa/dev-v1.0.json 19 | 20 | mkdir output 21 | mkdir output/coqa 22 | mkdir output/coqa/data 23 | wget -P output/coqa https://storage.googleapis.com/mrc_data/coqa/coqa_cased_L-24_H-1024_A-16.zip 24 | unzip output/coqa/coqa_cased_L-24_H-1024_A-16.zip -d output/coqa/ 25 | mv output/coqa/coqa_cased_L-24_H-1024_A-16 output/coqa/checkpoint 26 | rm output/coqa/coqa_cased_L-24_H-1024_A-16.zip 27 | 28 | CUDA_VISIBLE_DEVICES=0 python run_coqa.py \ 29 | --spiece_model_file=model/xlnet/cased_L-24_H-1024_A-16/spiece.model \ 30 | --model_config_path=model/xlnet/cased_L-24_H-1024_A-16/xlnet_config.json \ 31 | --init_checkpoint=model/xlnet/cased_L-24_H-1024_A-16/xlnet_model.ckpt \ 32 | --task_name='v1.0' \ 33 | 
--random_seed=1000 \ 34 | --predict_tag='v1.0' \ 35 | --lower_case=false \ 36 | --data_dir=data/coqa/ \ 37 | --output_dir=output/coqa/data \ 38 | --model_dir=output/coqa/checkpoint \ 39 | --export_dir=output/coqa/export \ 40 | --num_turn=-1 \ 41 | --max_seq_length=512 \ 42 | --max_query_length=128 \ 43 | --max_answer_length=16 \ 44 | --train_batch_size=48 \ 45 | --predict_batch_size=16 \ 46 | --num_hosts=1 \ 47 | --num_core_per_host=1 \ 48 | --learning_rate=2e-5 \ 49 | --train_steps=10000 \ 50 | --warmup_steps=1000 \ 51 | --save_steps=1000 \ 52 | --do_train=false \ 53 | --do_predict=true \ 54 | --do_export=false \ 55 | --overwrite_data=false 56 | 57 | python tool/convert_coqa.py \ 58 | --input_file=output/coqa/data/predict.v1.0.summary.json \ 59 | --output_file=output/coqa/data/predict.v1.0.span.json \ 60 | --answer_threshold=0.1 61 | 62 | python tool/eval_coqa.py \ 63 | --data-file=data/coqa/dev-v1.0.json \ 64 | --pred-file=output/coqa/data/predict.v1.0.span.json \ 65 | >> output/coqa/data/predict.v1.0.eval.json 66 | 67 | cp output/coqa/data/predict.v1.0.span.json ../coqa-dev-v1.0.span.json 68 | cp output/coqa/data/predict.v1.0.eval.json ../coqa-dev-v1.0.eval.json 69 | 70 | cd .. 71 | rm -r mrc_tf 72 | 73 | end_time=`date +%s` 74 | echo execution time was `expr $end_time - $start_time` s. 75 | -------------------------------------------------------------------------------- /codalab/run_quac.codalab.sh: -------------------------------------------------------------------------------- 1 | for arg in "$@" 2 | do 3 | case $arg in 4 | -i|--inputfile) 5 | INPUTFILE="$2" 6 | shift 7 | shift 8 | ;; 9 | -o|--outputfile) 10 | OUTPUTFILE="$2" 11 | shift 12 | shift 13 | ;; 14 | esac 15 | done 16 | 17 | echo "input file = ${INPUTFILE}" 18 | echo "output file = ${OUTPUTFILE}" 19 | 20 | start_time=`date +%s` 21 | 22 | alias python=python3 23 | 24 | git clone --recurse-submodules https://github.com/stevezheng23/mrc_tf.git 25 | 26 | cd mrc_tf 27 | 28 | mkdir model 29 | mkdir model/xlnet 30 | wget -P model/xlnet https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip 31 | unzip model/xlnet/cased_L-24_H-1024_A-16.zip -d model/xlnet/ 32 | mv model/xlnet/xlnet_cased_L-24_H-1024_A-16 model/xlnet/cased_L-24_H-1024_A-16 33 | rm model/xlnet/cased_L-24_H-1024_A-16.zip 34 | 35 | mkdir data 36 | mkdir data/quac 37 | cp ../${INPUTFILE} data/quac/dev-v0.2.json 38 | 39 | mkdir output 40 | mkdir output/quac 41 | mkdir output/quac/data 42 | wget -P output/quac https://storage.googleapis.com/mrc_data/quac/quac_cased_L-24_H-1024_A-16.zip 43 | unzip output/quac/quac_cased_L-24_H-1024_A-16.zip -d output/quac/ 44 | mv output/quac/quac_cased_L-24_H-1024_A-16 output/quac/checkpoint 45 | rm output/quac/quac_cased_L-24_H-1024_A-16.zip 46 | 47 | CUDA_VISIBLE_DEVICES=0 python run_quac.py \ 48 | --spiece_model_file=model/xlnet/cased_L-24_H-1024_A-16/spiece.model \ 49 | --model_config_path=model/xlnet/cased_L-24_H-1024_A-16/xlnet_config.json \ 50 | --init_checkpoint=model/xlnet/cased_L-24_H-1024_A-16/xlnet_model.ckpt \ 51 | --task_name='v0.2' \ 52 | --random_seed=1000 \ 53 | --predict_tag='v0.2' \ 54 | --lower_case=false \ 55 | --data_dir=data/quac/ \ 56 | --output_dir=output/quac/data \ 57 | --model_dir=output/quac/checkpoint \ 58 | --export_dir=output/quac/export \ 59 | --num_turn=-1 \ 60 | --max_seq_length=512 \ 61 | --max_query_length=128 \ 62 | --max_answer_length=32 \ 63 | --train_batch_size=48 \ 64 | --predict_batch_size=12 \ 65 | --num_hosts=1 \ 66 | --num_core_per_host=1 \ 67 | --learning_rate=2e-5 \ 
68 | --train_steps=8000 \ 69 | --warmup_steps=1000 \ 70 | --save_steps=1000 \ 71 | --do_train=false \ 72 | --do_predict=true \ 73 | --do_export=false \ 74 | --overwrite_data=false 75 | 76 | python tool/convert_quac.py \ 77 | --input_file=output/quac/data/predict.v0.2.summary.json \ 78 | --output_file=output/quac/data/predict.v0.2.span.json \ 79 | --answer_threshold=0.1 80 | 81 | python tool/eval_quac.py \ 82 | --val_file=data/quac/dev-v0.2.json \ 83 | --model_output=output/quac/data/predict.v0.2.span.json \ 84 | --o=output/quac/data/predict.v0.2.eval.json 85 | 86 | cp output/quac/data/predict.v0.2.span.json ../${OUTPUTFILE} 87 | 88 | cd .. 89 | rm -r mrc_tf 90 | 91 | end_time=`date +%s` 92 | echo execution time was `expr $end_time - $start_time` s. 93 | -------------------------------------------------------------------------------- /run_squad.sh: -------------------------------------------------------------------------------- 1 | for i in "$@" 2 | do 3 | case $i in 4 | -g=*|--gpudevice=*) 5 | GPUDEVICE="${i#*=}" 6 | shift 7 | ;; 8 | -n=*|--numgpus=*) 9 | NUMGPUS="${i#*=}" 10 | shift 11 | ;; 12 | -t=*|--taskname=*) 13 | TASKNAME="${i#*=}" 14 | shift 15 | ;; 16 | -r=*|--randomseed=*) 17 | RANDOMSEED="${i#*=}" 18 | shift 19 | ;; 20 | -p=*|--predicttag=*) 21 | PREDICTTAG="${i#*=}" 22 | shift 23 | ;; 24 | -m=*|--modeldir=*) 25 | MODELDIR="${i#*=}" 26 | shift 27 | ;; 28 | -d=*|--datadir=*) 29 | DATADIR="${i#*=}" 30 | shift 31 | ;; 32 | -o=*|--outputdir=*) 33 | OUTPUTDIR="${i#*=}" 34 | shift 35 | ;; 36 | --seqlen=*) 37 | SEQLEN="${i#*=}" 38 | shift 39 | ;; 40 | --querylen=*) 41 | QUERYLEN="${i#*=}" 42 | shift 43 | ;; 44 | --answerlen=*) 45 | ANSWERLEN="${i#*=}" 46 | shift 47 | ;; 48 | --batchsize=*) 49 | BATCHSIZE="${i#*=}" 50 | shift 51 | ;; 52 | --learningrate=*) 53 | LEARNINGRATE="${i#*=}" 54 | shift 55 | ;; 56 | --trainsteps=*) 57 | TRAINSTEPS="${i#*=}" 58 | shift 59 | ;; 60 | --warmupsteps=*) 61 | WARMUPSTEPS="${i#*=}" 62 | shift 63 | ;; 64 | --savesteps=*) 65 | SAVESTEPS="${i#*=}" 66 | shift 67 | ;; 68 | esac 69 | done 70 | 71 | echo "gpu device = ${GPUDEVICE}" 72 | echo "num gpus = ${NUMGPUS}" 73 | echo "task name = ${TASKNAME}" 74 | echo "random seed = ${RANDOMSEED}" 75 | echo "predict tag = ${PREDICTTAG}" 76 | echo "model dir = ${MODELDIR}" 77 | echo "data dir = ${DATADIR}" 78 | echo "output dir = ${OUTPUTDIR}" 79 | echo "seq len = ${SEQLEN}" 80 | echo "query len = ${QUERYLEN}" 81 | echo "answer len = ${ANSWERLEN}" 82 | echo "batch size = ${BATCHSIZE}" 83 | echo "learning rate = ${LEARNINGRATE}" 84 | echo "train steps = ${TRAINSTEPS}" 85 | echo "warmup steps = ${WARMUPSTEPS}" 86 | echo "save steps = ${SAVESTEPS}" 87 | 88 | alias python=python3 89 | mkdir ${OUTPUTDIR} 90 | 91 | start_time=`date +%s` 92 | 93 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_squad.py \ 94 | --spiece_model_file=${MODELDIR}/spiece.model \ 95 | --model_config_path=${MODELDIR}/xlnet_config.json \ 96 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 97 | --task_name=${TASKNAME} \ 98 | --random_seed=${RANDOMSEED} \ 99 | --predict_tag=${PREDICTTAG} \ 100 | --lower_case=false \ 101 | --data_dir=${DATADIR}/ \ 102 | --output_dir=${OUTPUTDIR}/data \ 103 | --model_dir=${OUTPUTDIR}/checkpoint \ 104 | --export_dir=${OUTPUTDIR}/export \ 105 | --max_seq_length=${SEQLEN} \ 106 | --max_query_length=${QUERYLEN} \ 107 | --max_answer_length=${ANSWERLEN} \ 108 | --train_batch_size=${BATCHSIZE} \ 109 | --predict_batch_size=${BATCHSIZE} \ 110 | --num_hosts=1 \ 111 | --num_core_per_host=${NUMGPUS} \ 112 | 
--learning_rate=${LEARNINGRATE} \ 113 | --train_steps=${TRAINSTEPS} \ 114 | --warmup_steps=${WARMUPSTEPS} \ 115 | --save_steps=${SAVESTEPS} \ 116 | --do_train=true \ 117 | --do_predict=false \ 118 | --do_export=false \ 119 | --overwrite_data=false 120 | 121 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_squad.py \ 122 | --spiece_model_file=${MODELDIR}/spiece.model \ 123 | --model_config_path=${MODELDIR}/xlnet_config.json \ 124 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 125 | --task_name=${TASKNAME} \ 126 | --random_seed=${RANDOMSEED} \ 127 | --predict_tag=${PREDICTTAG} \ 128 | --lower_case=false \ 129 | --data_dir=${DATADIR}/ \ 130 | --output_dir=${OUTPUTDIR}/data \ 131 | --model_dir=${OUTPUTDIR}/checkpoint \ 132 | --export_dir=${OUTPUTDIR}/export \ 133 | --max_seq_length=${SEQLEN} \ 134 | --max_query_length=${QUERYLEN} \ 135 | --max_answer_length=${ANSWERLEN} \ 136 | --train_batch_size=${BATCHSIZE} \ 137 | --predict_batch_size=${BATCHSIZE} \ 138 | --num_hosts=1 \ 139 | --num_core_per_host=1 \ 140 | --learning_rate=${LEARNINGRATE} \ 141 | --train_steps=${TRAINSTEPS} \ 142 | --warmup_steps=${WARMUPSTEPS} \ 143 | --save_steps=${SAVESTEPS} \ 144 | --do_train=false \ 145 | --do_predict=true \ 146 | --do_export=false \ 147 | --overwrite_data=false 148 | 149 | python tool/convert_squad.py \ 150 | --input_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.summary.json \ 151 | --span_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 152 | --prob_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.prob.json 153 | 154 | python tool/eval_squad.py \ 155 | ${DATADIR}/dev-${TASKNAME}.json \ 156 | ${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 157 | --out-file ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json \ 158 | --na-prob-file ${OUTPUTDIR}/data/predict.${PREDICTTAG}.prob.json \ 159 | --na-prob-thresh 1.0 \ 160 | --out-image-dir ${OUTPUTDIR}/data/ 161 | 162 | end_time=`date +%s` 163 | echo execution time was `expr $end_time - $start_time` s. 164 | 165 | read -n 1 -s -r -p "Press any key to continue..." 
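A hypothetical invocation of run_squad.sh, matching the flags parsed above; the GPU ids, paths, and hyperparameter values below are placeholders, not prescribed settings:

bash run_squad.sh -g=0,1,2,3 -n=4 -t=v2.0 -r=100 -p=v2.0 -m=model/cased_L-24_H-1024_A-16 -d=data/squad/v2.0 -o=output/squad/v2.0 --seqlen=512 --querylen=64 --answerlen=64 --batchsize=12 --learningrate=3e-5 --trainsteps=8000 --warmupsteps=1000 --savesteps=1000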
-------------------------------------------------------------------------------- /run_coqa.sh: -------------------------------------------------------------------------------- 1 | for i in "$@" 2 | do 3 | case $i in 4 | -g=*|--gpudevice=*) 5 | GPUDEVICE="${i#*=}" 6 | shift 7 | ;; 8 | -n=*|--numgpus=*) 9 | NUMGPUS="${i#*=}" 10 | shift 11 | ;; 12 | -t=*|--taskname=*) 13 | TASKNAME="${i#*=}" 14 | shift 15 | ;; 16 | -r=*|--randomseed=*) 17 | RANDOMSEED="${i#*=}" 18 | shift 19 | ;; 20 | -p=*|--predicttag=*) 21 | PREDICTTAG="${i#*=}" 22 | shift 23 | ;; 24 | -m=*|--modeldir=*) 25 | MODELDIR="${i#*=}" 26 | shift 27 | ;; 28 | -d=*|--datadir=*) 29 | DATADIR="${i#*=}" 30 | shift 31 | ;; 32 | -o=*|--outputdir=*) 33 | OUTPUTDIR="${i#*=}" 34 | shift 35 | ;; 36 | --numturn=*) 37 | NUMTURN="${i#*=}" 38 | shift 39 | ;; 40 | --seqlen=*) 41 | SEQLEN="${i#*=}" 42 | shift 43 | ;; 44 | --querylen=*) 45 | QUERYLEN="${i#*=}" 46 | shift 47 | ;; 48 | --answerlen=*) 49 | ANSWERLEN="${i#*=}" 50 | shift 51 | ;; 52 | --batchsize=*) 53 | BATCHSIZE="${i#*=}" 54 | shift 55 | ;; 56 | --learningrate=*) 57 | LEARNINGRATE="${i#*=}" 58 | shift 59 | ;; 60 | --trainsteps=*) 61 | TRAINSTEPS="${i#*=}" 62 | shift 63 | ;; 64 | --warmupsteps=*) 65 | WARMUPSTEPS="${i#*=}" 66 | shift 67 | ;; 68 | --savesteps=*) 69 | SAVESTEPS="${i#*=}" 70 | shift 71 | ;; 72 | --answerthreshold=*) 73 | ANSWERTHRESHOLD="${i#*=}" 74 | shift 75 | ;; 76 | esac 77 | done 78 | 79 | echo "gpu device = ${GPUDEVICE}" 80 | echo "num gpus = ${NUMGPUS}" 81 | echo "task name = ${TASKNAME}" 82 | echo "random seed = ${RANDOMSEED}" 83 | echo "predict tag = ${PREDICTTAG}" 84 | echo "model dir = ${MODELDIR}" 85 | echo "data dir = ${DATADIR}" 86 | echo "output dir = ${OUTPUTDIR}" 87 | echo "num turn = ${NUMTURN}" 88 | echo "seq len = ${SEQLEN}" 89 | echo "query len = ${QUERYLEN}" 90 | echo "answer len = ${ANSWERLEN}" 91 | echo "batch size = ${BATCHSIZE}" 92 | echo "learning rate = ${LEARNINGRATE}" 93 | echo "train steps = ${TRAINSTEPS}" 94 | echo "warmup steps = ${WARMUPSTEPS}" 95 | echo "save steps = ${SAVESTEPS}" 96 | echo "answer threshold = ${ANSWERTHRESHOLD}" 97 | 98 | alias python=python3 99 | mkdir ${OUTPUTDIR} 100 | 101 | start_time=`date +%s` 102 | 103 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_coqa.py \ 104 | --spiece_model_file=${MODELDIR}/spiece.model \ 105 | --model_config_path=${MODELDIR}/xlnet_config.json \ 106 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 107 | --task_name=${TASKNAME} \ 108 | --random_seed=${RANDOMSEED} \ 109 | --predict_tag=${PREDICTTAG} \ 110 | --lower_case=false \ 111 | --data_dir=${DATADIR}/ \ 112 | --output_dir=${OUTPUTDIR}/data \ 113 | --model_dir=${OUTPUTDIR}/checkpoint \ 114 | --export_dir=${OUTPUTDIR}/export \ 115 | --num_turn=${NUMTURN} \ 116 | --max_seq_length=${SEQLEN} \ 117 | --max_query_length=${QUERYLEN} \ 118 | --max_answer_length=${ANSWERLEN} \ 119 | --train_batch_size=${BATCHSIZE} \ 120 | --predict_batch_size=${BATCHSIZE} \ 121 | --num_hosts=1 \ 122 | --num_core_per_host=${NUMGPUS} \ 123 | --learning_rate=${LEARNINGRATE} \ 124 | --train_steps=${TRAINSTEPS} \ 125 | --warmup_steps=${WARMUPSTEPS} \ 126 | --save_steps=${SAVESTEPS} \ 127 | --do_train=true \ 128 | --do_predict=false \ 129 | --do_export=false \ 130 | --overwrite_data=false 131 | 132 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_coqa.py \ 133 | --spiece_model_file=${MODELDIR}/spiece.model \ 134 | --model_config_path=${MODELDIR}/xlnet_config.json \ 135 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 136 | --task_name=${TASKNAME} \ 137 | 
--random_seed=${RANDOMSEED} \ 138 | --predict_tag=${PREDICTTAG} \ 139 | --lower_case=false \ 140 | --data_dir=${DATADIR}/ \ 141 | --output_dir=${OUTPUTDIR}/data \ 142 | --model_dir=${OUTPUTDIR}/checkpoint \ 143 | --export_dir=${OUTPUTDIR}/export \ 144 | --num_turn=${NUMTURN} \ 145 | --max_seq_length=${SEQLEN} \ 146 | --max_query_length=${QUERYLEN} \ 147 | --max_answer_length=${ANSWERLEN} \ 148 | --train_batch_size=${BATCHSIZE} \ 149 | --predict_batch_size=${BATCHSIZE} \ 150 | --num_hosts=1 \ 151 | --num_core_per_host=1 \ 152 | --learning_rate=${LEARNINGRATE} \ 153 | --train_steps=${TRAINSTEPS} \ 154 | --warmup_steps=${WARMUPSTEPS} \ 155 | --save_steps=${SAVESTEPS} \ 156 | --do_train=false \ 157 | --do_predict=true \ 158 | --do_export=false \ 159 | --overwrite_data=false 160 | 161 | python tool/convert_coqa.py \ 162 | --input_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.summary.json \ 163 | --output_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 164 | --answer_threshold=${ANSWERTHRESHOLD} 165 | 166 | rm ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json 167 | 168 | python tool/eval_coqa.py \ 169 | --data-file=${DATADIR}/dev-${TASKNAME}.json \ 170 | --pred-file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 171 | >> ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json 172 | 173 | end_time=`date +%s` 174 | echo execution time was `expr $end_time - $start_time` s. 175 | 176 | read -n 1 -s -r -p "Press any key to continue..." -------------------------------------------------------------------------------- /run_quac.sh: -------------------------------------------------------------------------------- 1 | for i in "$@" 2 | do 3 | case $i in 4 | -g=*|--gpudevice=*) 5 | GPUDEVICE="${i#*=}" 6 | shift 7 | ;; 8 | -n=*|--numgpus=*) 9 | NUMGPUS="${i#*=}" 10 | shift 11 | ;; 12 | -t=*|--taskname=*) 13 | TASKNAME="${i#*=}" 14 | shift 15 | ;; 16 | -r=*|--randomseed=*) 17 | RANDOMSEED="${i#*=}" 18 | shift 19 | ;; 20 | -p=*|--predicttag=*) 21 | PREDICTTAG="${i#*=}" 22 | shift 23 | ;; 24 | -m=*|--modeldir=*) 25 | MODELDIR="${i#*=}" 26 | shift 27 | ;; 28 | -d=*|--datadir=*) 29 | DATADIR="${i#*=}" 30 | shift 31 | ;; 32 | -o=*|--outputdir=*) 33 | OUTPUTDIR="${i#*=}" 34 | shift 35 | ;; 36 | --numturn=*) 37 | NUMTURN="${i#*=}" 38 | shift 39 | ;; 40 | --seqlen=*) 41 | SEQLEN="${i#*=}" 42 | shift 43 | ;; 44 | --querylen=*) 45 | QUERYLEN="${i#*=}" 46 | shift 47 | ;; 48 | --answerlen=*) 49 | ANSWERLEN="${i#*=}" 50 | shift 51 | ;; 52 | --batchsize=*) 53 | BATCHSIZE="${i#*=}" 54 | shift 55 | ;; 56 | --learningrate=*) 57 | LEARNINGRATE="${i#*=}" 58 | shift 59 | ;; 60 | --trainsteps=*) 61 | TRAINSTEPS="${i#*=}" 62 | shift 63 | ;; 64 | --warmupsteps=*) 65 | WARMUPSTEPS="${i#*=}" 66 | shift 67 | ;; 68 | --savesteps=*) 69 | SAVESTEPS="${i#*=}" 70 | shift 71 | ;; 72 | --answerthreshold=*) 73 | ANSWERTHRESHOLD="${i#*=}" 74 | shift 75 | ;; 76 | esac 77 | done 78 | 79 | echo "gpu device = ${GPUDEVICE}" 80 | echo "num gpus = ${NUMGPUS}" 81 | echo "task name = ${TASKNAME}" 82 | echo "random seed = ${RANDOMSEED}" 83 | echo "predict tag = ${PREDICTTAG}" 84 | echo "model dir = ${MODELDIR}" 85 | echo "data dir = ${DATADIR}" 86 | echo "output dir = ${OUTPUTDIR}" 87 | echo "num turn = ${NUMTURN}" 88 | echo "seq len = ${SEQLEN}" 89 | echo "query len = ${QUERYLEN}" 90 | echo "answer len = ${ANSWERLEN}" 91 | echo "batch size = ${BATCHSIZE}" 92 | echo "learning rate = ${LEARNINGRATE}" 93 | echo "train steps = ${TRAINSTEPS}" 94 | echo "warmup steps = ${WARMUPSTEPS}" 95 | echo "save steps = ${SAVESTEPS}" 96 | echo "answer threshold = 
${ANSWERTHRESHOLD}" 97 | 98 | alias python=python3 99 | mkdir ${OUTPUTDIR} 100 | 101 | start_time=`date +%s` 102 | 103 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_quac.py \ 104 | --spiece_model_file=${MODELDIR}/spiece.model \ 105 | --model_config_path=${MODELDIR}/xlnet_config.json \ 106 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 107 | --task_name=${TASKNAME} \ 108 | --random_seed=${RANDOMSEED} \ 109 | --predict_tag=${PREDICTTAG} \ 110 | --lower_case=false \ 111 | --data_dir=${DATADIR}/ \ 112 | --output_dir=${OUTPUTDIR}/data \ 113 | --model_dir=${OUTPUTDIR}/checkpoint \ 114 | --export_dir=${OUTPUTDIR}/export \ 115 | --num_turn=${NUMTURN} \ 116 | --max_seq_length=${SEQLEN} \ 117 | --max_query_length=${QUERYLEN} \ 118 | --max_answer_length=${ANSWERLEN} \ 119 | --train_batch_size=${BATCHSIZE} \ 120 | --predict_batch_size=${BATCHSIZE} \ 121 | --num_hosts=1 \ 122 | --num_core_per_host=${NUMGPUS} \ 123 | --learning_rate=${LEARNINGRATE} \ 124 | --train_steps=${TRAINSTEPS} \ 125 | --warmup_steps=${WARMUPSTEPS} \ 126 | --save_steps=${SAVESTEPS} \ 127 | --do_train=true \ 128 | --do_predict=false \ 129 | --do_export=false \ 130 | --overwrite_data=false 131 | 132 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_quac.py \ 133 | --spiece_model_file=${MODELDIR}/spiece.model \ 134 | --model_config_path=${MODELDIR}/xlnet_config.json \ 135 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 136 | --task_name=${TASKNAME} \ 137 | --random_seed=${RANDOMSEED} \ 138 | --predict_tag=${PREDICTTAG} \ 139 | --lower_case=false \ 140 | --data_dir=${DATADIR}/ \ 141 | --output_dir=${OUTPUTDIR}/data \ 142 | --model_dir=${OUTPUTDIR}/checkpoint \ 143 | --export_dir=${OUTPUTDIR}/export \ 144 | --num_turn=${NUMTURN} \ 145 | --max_seq_length=${SEQLEN} \ 146 | --max_query_length=${QUERYLEN} \ 147 | --max_answer_length=${ANSWERLEN} \ 148 | --train_batch_size=${BATCHSIZE} \ 149 | --predict_batch_size=${BATCHSIZE} \ 150 | --num_hosts=1 \ 151 | --num_core_per_host=1 \ 152 | --learning_rate=${LEARNINGRATE} \ 153 | --train_steps=${TRAINSTEPS} \ 154 | --warmup_steps=${WARMUPSTEPS} \ 155 | --save_steps=${SAVESTEPS} \ 156 | --do_train=false \ 157 | --do_predict=true \ 158 | --do_export=false \ 159 | --overwrite_data=false 160 | 161 | python tool/convert_quac.py \ 162 | --input_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.summary.json \ 163 | --output_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 164 | --answer_threshold=${ANSWERTHRESHOLD} 165 | 166 | rm ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json 167 | 168 | python tool/eval_quac.py \ 169 | --val_file=${DATADIR}/dev-${TASKNAME}.json \ 170 | --model_output=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 171 | --o ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json 172 | 173 | end_time=`date +%s` 174 | echo execution time was `expr $end_time - $start_time` s. 175 | 176 | read -n 1 -s -r -p "Press any key to continue..." -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | Machine reading comprehension (MRC), a task that asks a machine to read a given context and then answer questions based on its understanding, is considered one of the key problems in artificial intelligence and has attracted significant interest from both academia and industry. Over the past few years, great progress has been made in this field, thanks to various end-to-end trained neural models and high-quality datasets with large numbers of examples.
3 | 4 | ![squad_example]({{ site.url }}/mrc_tf/squad.example.png){:width="800px"} 5 | 6 | *Figure 1: MRC example from SQuAD 2.0 dev set* 7 | 8 | ## DataSet 9 | * [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) is a reading comprehension dataset, consisting of questions posed by crowd-workers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable. 10 | * [CoQA](https://stanfordnlp.github.io/coqa/) is a large-scale dataset for building Conversational Question Answering systems. The goal of the CoQA challenge is to measure the ability of machines to understand a text passage and answer a series of interconnected questions that appear in a conversation. CoQA is pronounced as 'coca'. 11 | * [QuAC](https://quac.ai/) is a dataset for modeling, understanding, and participating in information seeking dialog. QuAC introduces challenges not found in existing machine comprehension datasets: its questions are often more open-ended, unanswerable, or only meaningful within the dialog context. 12 | 13 | ## Experiment 14 | ### SQuAD v1.1 15 | 16 | ![xlnet_squad_v1]({{ site.url }}/mrc_tf/xlnet.squad.v1.png){:width="500px"} 17 | 18 | *Figure 2: Illustrations of fine-tuning XLNet on SQuAD v1.1 task* 19 | 20 | | Model | Train Data | # Epoch | # Train Steps | Batch Size | Max Length | Learning Rate | EM | F1 | 21 | |:-----------------:|:----------:|:-------:|:-------------:|:----------:|:----------:|:-------------:|:--------:|:--------:| 22 | | XLNet-base | SQuAD 2.0 | ~3 | 8,000 | 48 | 512 | 3e-5 | 85.90 | 92.17 | 23 | | XLNet-large | SQuAD 2.0 | ~3 | 8,000 | 48 | 512 | 3e-5 | 88.61 | 94.28 | 24 | 25 | *Table 1: The dev set performance of XLNet model finetuned on SQuAD v1.1 task* 26 | 27 | ### SQuAD v2.0 28 | 29 | ![xlnet_squad_v2]({{ site.url }}/mrc_tf/xlnet.squad.v2.png){:width="500px"} 30 | 31 | *Figure 3: Illustrations of fine-tuning XLNet on SQuAD v2.0 task* 32 | 33 | | Model | Train Data | # Epoch | # Train Steps | Batch Size | Max Length | Learning Rate | EM | F1 | 34 | |:-----------------:|:----------:|:-------:|:-------------:|:----------:|:----------:|:-------------:|:--------:|:--------:| 35 | | XLNet-base | SQuAD 2.0 | ~3 | 8,000 | 48 | 512 | 3e-5 | 80.23 | 82.90 | 36 | | XLNet-large | SQuAD 2.0 | ~3 | 8,000 | 48 | 512 | 3e-5 | 85.72 | 88.36 | 37 | 38 | *Table 2: The dev set performance of XLNet model finetuned on SQuAD v2.0 task* 39 | 40 | ### CoQA v1.0 41 | 42 | ![xlnet_coqa]({{ site.url }}/mrc_tf/xlnet.coqa.png){:width="500px"} 43 | 44 | *Figure 4: Illustrations of fine-tuning XLNet on CoQA v1.0 task* 45 | 46 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Max Query Len | Learning Rate | EM | F1 | 47 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-------------:|:--------:|:--------:| 48 | | XLNet-base | CoQA 1.0 | 6,000 | 48 | 512 | 128 | 3e-5 | 76.4 | 84.4 | 49 | | XLNet-large | CoQA 1.0 | 6,000 | 48 | 512 | 128 | 3e-5 | 81.8 | 89.4 | 50 | 51 | *Table 3: The dev set performance of XLNet model finetuned on CoQA v1.0 task* 52 | 53 | ### QuAC v0.2 54 | 55 | ![xlnet_quac]({{ site.url }}/mrc_tf/xlnet.quac.png){:width="500px"} 56 | 57 | *Figure 5: Illustrations of fine-tuning XLNet on QuAC v0.2 task* 58 | 59 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Max Query Len | Learning Rate | Overall F1 | HEQQ | HEQD | 60 |
|:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-------------:|:----------:|:------:|:------:| 61 | | XLNet-base | QuAC 0.2 | 8,000 | 48 | 512 | 128 | 2e-5 | 66.4 | 62.6 | 6.8 | 62 | | XLNet-large | QuAC 0.2 | 8,000 | 48 | 512 | 128 | 2e-5 | 71.5 | 68.0 | 11.1 | 63 | 64 | *Table 4: The dev set performance of XLNet model finetuned on QuAC v0.2 task* 65 | 66 | ## Reference 67 | * Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. [SQuAD: 100,000+ questions for machine comprehension of text](https://arxiv.org/abs/1606.05250) [2016] 68 | * Pranav Rajpurkar, Robin Jia, and Percy Liang. [Know what you don’t know: unanswerable questions for SQuAD](https://arxiv.org/abs/1806.03822) [2018] 69 | * Siva Reddy, Danqi Chen, Christopher D. Manning. [CoQA: A Conversational Question Answering Challenge](https://arxiv.org/abs/1808.07042) [2018] 70 | * Eunsol Choi, He He, Mohit Iyyer, Mark Yatskar, Wen-tau Yih, Yejin Choi, Percy Liang, Luke Zettlemoyer. [QuAC : Question Answering in Context](https://arxiv.org/abs/1808.07036) [2018] 71 | * Danqi Chen. [Neural reading comprehension and beyond](https://cs.stanford.edu/~danqi/papers/thesis.pdf) [2018] 72 | * Matthew E. Peters, Mark Neumann, Mohit Iyyer, Matthew Gardner, Christopher T Clark, Kenton Lee, and Luke S. Zettlemoyer. [Deep contextualized word representations](https://arxiv.org/abs/1802.05365) [2018] 73 | * Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. [Improving language understanding by generative pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) [2018] 74 | * Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei and Ilya Sutskever. [Language models are unsupervised multitask learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) [2019] 75 | * Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. [BERT: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/abs/1810.04805) [2018] 76 | * Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) [2019] 77 | * Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) [2019] 78 | * Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) [2019] 79 | * Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov and Quoc V. Le. [XLNet: Generalized autoregressive pretraining for language understanding](https://arxiv.org/abs/1906.08237) [2019] 80 | * Zihang Dai, Zhilin Yang, Yiming Yang, William W Cohen, Jaime Carbonell, Quoc V Le and Ruslan Salakhutdinov.
[Transformer-XL: Attentive language models beyond a fixed-length context](https://arxiv.org/abs/1901.02860) [2019] 81 | -------------------------------------------------------------------------------- /tool/eval_quac.py: -------------------------------------------------------------------------------- 1 | import json, string, re 2 | from collections import Counter, defaultdict 3 | from argparse import ArgumentParser 4 | 5 | 6 | def is_overlapping(x1, x2, y1, y2): 7 | return max(x1, y1) <= min(x2, y2) 8 | 9 | def normalize_answer(s): 10 | """Lower text and remove punctuation, articles and extra whitespace.""" 11 | def remove_articles(text): 12 | return re.sub(r'\b(a|an|the)\b', ' ', text) 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | def remove_punc(text): 16 | exclude = set(string.punctuation) 17 | return ''.join(ch for ch in text if ch not in exclude) 18 | def lower(text): 19 | return text.lower() 20 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 21 | 22 | def f1_score(prediction, ground_truth): 23 | prediction_tokens = normalize_answer(prediction).split() 24 | ground_truth_tokens = normalize_answer(ground_truth).split() 25 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 26 | num_same = sum(common.values()) 27 | if num_same == 0: 28 | return 0 29 | precision = 1.0 * num_same / len(prediction_tokens) 30 | recall = 1.0 * num_same / len(ground_truth_tokens) 31 | f1 = (2 * precision * recall) / (precision + recall) 32 | return f1 33 | 34 | def exact_match_score(prediction, ground_truth): 35 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 36 | 37 | def display_counter(title, c, c2=None): 38 | print(title) 39 | for key, _ in c.most_common(): 40 | if c2: 41 | print('%s: %d / %d, %.1f%%, F1: %.1f' % ( 42 | key, c[key], sum(c.values()), c[key] * 100. / sum(c.values()), sum(c2[key]) * 100. / len(c2[key]))) 43 | else: 44 | print('%s: %d / %d, %.1f%%' % (key, c[key], sum(c.values()), c[key] * 100. / sum(c.values()))) 45 | 46 | def leave_one_out_max(prediction, ground_truths, article): 47 | if len(ground_truths) == 1: 48 | return metric_max_over_ground_truths(prediction, ground_truths, article)[1] 49 | else: 50 | t_f1 = [] 51 | # leave out one ref every time 52 | for i in range(len(ground_truths)): 53 | idxes = list(range(len(ground_truths))) 54 | idxes.pop(i) 55 | refs = [ground_truths[z] for z in idxes] 56 | t_f1.append(metric_max_over_ground_truths(prediction, refs, article)[1]) 57 | return 1.0 * sum(t_f1) / len(t_f1) 58 | 59 | 60 | def metric_max_over_ground_truths(prediction, ground_truths, article): 61 | scores_for_ground_truths = [] 62 | for ground_truth in ground_truths: 63 | score = compute_span_overlap(prediction, ground_truth, article) 64 | scores_for_ground_truths.append(score) 65 | return max(scores_for_ground_truths, key=lambda x: x[1]) 66 | 67 | 68 | def handle_cannot(refs): 69 | num_cannot = 0 70 | num_spans = 0 71 | for ref in refs: 72 | if ref == 'CANNOTANSWER': 73 | num_cannot += 1 74 | else: 75 | num_spans += 1 76 | if num_cannot >= num_spans: 77 | refs = ['CANNOTANSWER'] 78 | else: 79 | refs = [x for x in refs if x != 'CANNOTANSWER'] 80 | return refs 81 | 82 | 83 | def leave_one_out(refs): 84 | if len(refs) == 1: 85 | return 1. 
86 | splits = [] 87 | for r in refs: 88 | splits.append(r.split()) 89 | t_f1 = 0.0 90 | for i in range(len(refs)): 91 | m_f1 = 0 92 | for j in range(len(refs)): 93 | if i == j: 94 | continue 95 | f1_ij = f1_score(refs[i], refs[j]) 96 | if f1_ij > m_f1: 97 | m_f1 = f1_ij 98 | t_f1 += m_f1 99 | return t_f1 / len(refs) 100 | 101 | 102 | def compute_span_overlap(pred_span, gt_span, text): 103 | if gt_span == 'CANNOTANSWER': 104 | if pred_span == 'CANNOTANSWER': 105 | return 'Exact match', 1.0 106 | return 'No overlap', 0. 107 | fscore = f1_score(pred_span, gt_span) 108 | pred_start = text.find(pred_span) 109 | gt_start = text.find(gt_span) 110 | 111 | if pred_start == -1 or gt_start == -1: 112 | return 'Span indexing error', fscore 113 | 114 | pred_end = pred_start + len(pred_span) 115 | gt_end = gt_start + len(gt_span) 116 | 117 | fscore = f1_score(pred_span, gt_span) 118 | overlap = is_overlapping(pred_start, pred_end, gt_start, gt_end) 119 | 120 | if exact_match_score(pred_span, gt_span): 121 | return 'Exact match', fscore 122 | if overlap: 123 | return 'Partial overlap', fscore 124 | else: 125 | return 'No overlap', fscore 126 | 127 | 128 | def eval_fn(val_results, model_results, verbose): 129 | span_overlap_stats = Counter() 130 | sentence_overlap = 0. 131 | para_overlap = 0. 132 | total_qs = 0. 133 | f1_stats = defaultdict(list) 134 | unfiltered_f1s = [] 135 | human_f1 = [] 136 | HEQ = 0. 137 | DHEQ = 0. 138 | total_dials = 0. 139 | yes_nos = [] 140 | followups = [] 141 | unanswerables = [] 142 | for p in val_results: 143 | for par in p['paragraphs']: 144 | did = par['id'] 145 | qa_list = par['qas'] 146 | good_dial = 1. 147 | for qa in qa_list: 148 | q_idx = qa['id'] 149 | val_spans = [anss['text'] for anss in qa['answers']] 150 | val_spans = handle_cannot(val_spans) 151 | hf1 = leave_one_out(val_spans) 152 | 153 | if did not in model_results or q_idx not in model_results[did]: 154 | print(did, q_idx, 'no prediction for this dialogue id') 155 | good_dial = 0 156 | f1_stats['NO ANSWER'].append(0.0) 157 | yes_nos.append(False) 158 | followups.append(False) 159 | if val_spans == ['CANNOTANSWER']: 160 | unanswerables.append(0.0) 161 | total_qs += 1 162 | unfiltered_f1s.append(0.0) 163 | if hf1 >= args.min_f1: 164 | human_f1.append(hf1) 165 | continue 166 | 167 | pred_span, pred_yesno, pred_followup = model_results[did][q_idx] 168 | 169 | max_overlap, _ = metric_max_over_ground_truths( \ 170 | pred_span, val_spans, par['context']) 171 | max_f1 = leave_one_out_max( \ 172 | pred_span, val_spans, par['context']) 173 | unfiltered_f1s.append(max_f1) 174 | 175 | # dont eval on low agreement instances 176 | if hf1 < args.min_f1: 177 | continue 178 | 179 | human_f1.append(hf1) 180 | yes_nos.append(pred_yesno == qa['yesno']) 181 | followups.append(pred_followup == qa['followup']) 182 | if val_spans == ['CANNOTANSWER']: 183 | unanswerables.append(max_f1) 184 | if verbose: 185 | print("-" * 20) 186 | print(pred_span) 187 | print(val_spans) 188 | print(max_f1) 189 | print("-" * 20) 190 | if max_f1 >= hf1: 191 | HEQ += 1. 192 | else: 193 | good_dial = 0. 194 | span_overlap_stats[max_overlap] += 1 195 | f1_stats[max_overlap].append(max_f1) 196 | total_qs += 1. 
197 | DHEQ += good_dial 198 | total_dials += 1 199 | DHEQ_score = 100.0 * DHEQ / total_dials 200 | HEQ_score = 100.0 * HEQ / total_qs 201 | all_f1s = sum(f1_stats.values(), []) 202 | overall_f1 = 100.0 * sum(all_f1s) / len(all_f1s) 203 | unfiltered_f1 = 100.0 * sum(unfiltered_f1s) / len(unfiltered_f1s) 204 | yesno_score = (100.0 * sum(yes_nos) / len(yes_nos)) 205 | followup_score = (100.0 * sum(followups) / len(followups)) 206 | unanswerable_score = (100.0 * sum(unanswerables) / len(unanswerables)) 207 | metric_json = {"unfiltered_f1": unfiltered_f1, "f1": overall_f1, "HEQ": HEQ_score, "DHEQ": DHEQ_score, "yes/no": yesno_score, "followup": followup_score, "unanswerable_acc": unanswerable_score} 208 | if verbose: 209 | print("=======================") 210 | display_counter('Overlap Stats', span_overlap_stats, f1_stats) 211 | print("=======================") 212 | print('Overall F1: %.1f' % overall_f1) 213 | print('Yes/No Accuracy : %.1f' % yesno_score) 214 | print('Followup Accuracy : %.1f' % followup_score) 215 | print('Unfiltered F1 ({0:d} questions): {1:.1f}'.format(len(unfiltered_f1s), unfiltered_f1)) 216 | print('Accuracy On Unanswerable Questions: {0:.1f} % ({1:d} questions)'.format(unanswerable_score, len(unanswerables))) 217 | print('Human F1: %.1f' % (100.0 * sum(human_f1) / len(human_f1))) 218 | print('Model F1 >= Human F1 (Questions): %d / %d, %.1f%%' % (HEQ, total_qs, 100.0 * HEQ / total_qs)) 219 | print('Model F1 >= Human F1 (Dialogs): %d / %d, %.1f%%' % (DHEQ, total_dials, 100.0 * DHEQ / total_dials)) 220 | print("=======================") 221 | return metric_json 222 | 223 | if __name__ == "__main__": 224 | parser = ArgumentParser() 225 | parser.add_argument('--val_file', type=str, required=True, help='file containing validation results') 226 | parser.add_argument('--model_output', type=str, required=True, help='Path to model output.') 227 | parser.add_argument('--o', type=str, required=False, help='Path to save score json') 228 | parser.add_argument('--min_f1', type=float, default=0.4, help='minimum human agreement F1 required to score a question') 229 | parser.add_argument('--verbose', action='store_true', help='print individual scores') 230 | args = parser.parse_args() 231 | val = json.load(open(args.val_file, 'r'))['data'] 232 | preds = defaultdict(dict) 233 | total = 0 234 | val_total = 0 235 | for line in open(args.model_output, 'r'): 236 | if line.strip(): 237 | pred_idx = json.loads(line.strip()) 238 | dia_id = pred_idx['qid'][0].split("_q#")[0] 239 | for qid, qspan, qyesno, qfollowup in zip(pred_idx['qid'], pred_idx['best_span_str'], pred_idx['yesno'], pred_idx['followup']): 240 | preds[dia_id][qid] = qspan, qyesno, qfollowup 241 | total += 1 242 | for p in val: 243 | for par in p['paragraphs']: 244 | did = par['id'] 245 | qa_list = par['qas'] 246 | val_total += len(qa_list) 247 | metric_json = eval_fn(val, preds, args.verbose) 248 | if args.o: 249 | with open(args.o, 'w') as fout: 250 | json.dump(metric_json, fout) 251 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine Reading Comprehension 2 | Machine reading comprehension (MRC), a task that asks a machine to read a given context and then answer questions based on its understanding, is considered one of the key problems in artificial intelligence and has attracted significant interest from both academia and industry.
Over the past few years, great progress has been made in this field, thanks to various end-to-end trained neural models and high-quality datasets with large numbers of examples. 3 |
![squad_example](docs/squad.example.png)

*Figure 1: MRC example from SQuAD 2.0 dev set*
5 | 6 | ## Setting 7 | * Python 3.6.7 8 | * TensorFlow 1.13.1 9 | * NumPy 1.13.3 10 | * SentencePiece 0.1.82 11 | 12 | ## DataSet 13 | * [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) is a reading comprehension dataset, consisting of questions posed by crowd-workers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable. 14 | * [CoQA](https://stanfordnlp.github.io/coqa/) is a large-scale dataset for building Conversational Question Answering systems. The goal of the CoQA challenge is to measure the ability of machines to understand a text passage and answer a series of interconnected questions that appear in a conversation. CoQA is pronounced as 'coca'. 15 | * [QuAC](https://quac.ai/) is a dataset for modeling, understanding, and participating in information seeking dialog. QuAC introduces challenges not found in existing machine comprehension datasets: its questions are often more open-ended, unanswerable, or only meaningful within the dialog context. 16 | 17 | ## Usage 18 | * Run SQuAD experiment 19 | ```bash 20 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_squad.py \ 21 | --spiece_model_file=model/cased_L-24_H-1024_A-16/spiece.model \ 22 | --model_config_path=model/cased_L-24_H-1024_A-16/xlnet_config.json \ 23 | --init_checkpoint=model/cased_L-24_H-1024_A-16/xlnet_model.ckpt \ 24 | --task_name=v2.0 \ 25 | --random_seed=100 \ 26 | --predict_tag=xxxxx \ 27 | --data_dir=data/squad/v2.0 \ 28 | --output_dir=output/squad/v2.0/data \ 29 | --model_dir=output/squad/v2.0/checkpoint \ 30 | --export_dir=output/squad/v2.0/export \ 31 | --max_seq_length=512 \ 32 | --train_batch_size=12 \ 33 | --predict_batch_size=12 \ 34 | --num_hosts=1 \ 35 | --num_core_per_host=4 \ 36 | --learning_rate=3e-5 \ 37 | --train_steps=8000 \ 38 | --warmup_steps=1000 \ 39 | --save_steps=1000 \ 40 | --do_train=true \ 41 | --do_predict=true \ 42 | --do_export=true \ 43 | --overwrite_data=false 44 | ``` 45 | * Run CoQA experiment 46 | ```bash 47 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_coqa.py \ 48 | --spiece_model_file=model/cased_L-24_H-1024_A-16/spiece.model \ 49 | --model_config_path=model/cased_L-24_H-1024_A-16/xlnet_config.json \ 50 | --init_checkpoint=model/cased_L-24_H-1024_A-16/xlnet_model.ckpt \ 51 | --task_name=v1.0 \ 52 | --random_seed=100 \ 53 | --predict_tag=xxxxx \ 54 | --data_dir=data/coqa/v1.0 \ 55 | --output_dir=output/coqa/v1.0/data \ 56 | --model_dir=output/coqa/v1.0/checkpoint \ 57 | --export_dir=output/coqa/v1.0/export \ 58 | --max_seq_length=512 \ 59 | --train_batch_size=12 \ 60 | --predict_batch_size=12 \ 61 | --num_hosts=1 \ 62 | --num_core_per_host=4 \ 63 | --learning_rate=3e-5 \ 64 | --train_steps=8000 \ 65 | --warmup_steps=1000 \ 66 | --save_steps=1000 \ 67 | --do_train=true \ 68 | --do_predict=true \ 69 | --do_export=true \ 70 | --overwrite_data=false 71 | ``` 72 | * Run QuAC experiment 73 | ```bash 74 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_quac.py \ 75 | --spiece_model_file=model/cased_L-24_H-1024_A-16/spiece.model \ 76 | --model_config_path=model/cased_L-24_H-1024_A-16/xlnet_config.json \ 77 | --init_checkpoint=model/cased_L-24_H-1024_A-16/xlnet_model.ckpt \ 78 | --task_name=v0.2 \ 79 | --random_seed=100 \ 80 | --predict_tag=xxxxx \ 81 | --data_dir=data/quac/v0.2 \ 82 | --output_dir=output/quac/v0.2/data \ 83 | --model_dir=output/quac/v0.2/checkpoint \ 84 | --export_dir=output/quac/v0.2/export \ 85 | --max_seq_length=512 \ 86 | --train_batch_size=12 \ 87 |
--predict_batch_size=12 \ 88 | --num_hosts=1 \ 89 | --num_core_per_host=4 \ 90 | --learning_rate=3e-5 \ 91 | --train_steps=8000 \ 92 | --warmup_steps=1000 \ 93 | --save_steps=1000 \ 94 | --do_train=true \ 95 | --do_predict=true \ 96 | --do_export=true \ 97 | --overwrite_data=false 98 | ``` 99 | 100 | ## Experiment 101 | ### SQuAD v1.1 102 |
![xlnet_squad_v1](docs/xlnet.squad.v1.png)

*Figure 2: Illustrations of fine-tuning XLNet on SQuAD v1.1 task*
104 | 105 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Learning Rate | EM | F1 | 106 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:--------:|:--------:| 107 | | XLNet-base | SQuAD 2.0 | 8,000 | 48 | 512 | 3e-5 | 85.90 | 92.17 | 108 | | XLNet-large | SQuAD 2.0 | 8,000 | 48 | 512 | 3e-5 | 88.61 | 94.28 | 109 | 110 |
*Table 1: The dev set performance of XLNet model finetuned on SQuAD v1.1 task*
111 | 112 | ### SQuAD v2.0 113 |
![xlnet_squad_v2](docs/xlnet.squad.v2.png)

*Figure 3: Illustrations of fine-tuning XLNet on SQuAD v2.0 task*
115 | 116 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Learning Rate | EM | F1 | 117 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:--------:|:--------:| 118 | | XLNet-base | SQuAD 2.0 | 8,000 | 48 | 512 | 3e-5 | 80.23 | 82.90 | 119 | | XLNet-large | SQuAD 2.0 | 8,000 | 48 | 512 | 3e-5 | 85.72 | 88.36 | 120 | 121 |
*Table 2: The dev set performance of XLNet model finetuned on SQuAD v2.0 task*
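The v2.0 scores above come from converting the raw prediction summary and then running the evaluator, mirroring the final steps of `run_squad.sh`; a minimal sketch, assuming the default output layout and a `v2.0` predict tag (paths are illustrative):

```bash
# Convert the prediction summary into span and no-answer probability files.
python tool/convert_squad.py \
  --input_file=output/squad/v2.0/data/predict.v2.0.summary.json \
  --span_file=output/squad/v2.0/data/predict.v2.0.span.json \
  --prob_file=output/squad/v2.0/data/predict.v2.0.prob.json

# Score the spans against the dev set, applying the no-answer threshold.
python tool/eval_squad.py \
  data/squad/v2.0/dev-v2.0.json \
  output/squad/v2.0/data/predict.v2.0.span.json \
  --na-prob-file output/squad/v2.0/data/predict.v2.0.prob.json \
  --na-prob-thresh 1.0
```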
122 | 123 | ### CoQA v1.0 124 |
![xlnet_coqa](docs/xlnet.coqa.png)

*Figure 4: Illustrations of fine-tuning XLNet on CoQA v1.0 task*
126 | 127 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Max Query Len | Learning Rate | EM | F1 | 128 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-------------:|:--------:|:--------:| 129 | | XLNet-base | CoQA 1.0 | 6,000 | 48 | 512 | 128 | 3e-5 | 76.4 | 84.4 | 130 | | XLNet-large | CoQA 1.0 | 6,000 | 48 | 512 | 128 | 3e-5 | 81.8 | 89.4 | 131 | 132 |
*Table 3: The dev set performance of XLNet model finetuned on CoQA v1.0 task*
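The CoQA pipeline is analogous, using the answer-type threshold from `tool/convert_coqa.py` (0.1 here, as in `run_coqa.sh`); the paths below are assumptions:

```bash
# Resolve answer types (unknown/yes/no/number/option) and extract spans.
python tool/convert_coqa.py \
  --input_file=output/coqa/v1.0/data/predict.v1.0.summary.json \
  --output_file=output/coqa/v1.0/data/predict.v1.0.span.json \
  --answer_threshold=0.1

# Run the official CoQA evaluator against the dev set.
python tool/eval_coqa.py \
  --data-file=data/coqa/v1.0/dev-v1.0.json \
  --pred-file=output/coqa/v1.0/data/predict.v1.0.span.json
```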
133 | 134 | ### QuAC v0.2 135 |
![xlnet_quac](docs/xlnet.quac.png)

*Figure 5: Illustrations of fine-tuning XLNet on QuAC v0.2 task*
137 | 138 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Max Query Len | Learning Rate | Overall F1 | HEQQ | HEQD | 139 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-------------:|:----------:|:------:|:------:| 140 | | XLNet-base | QuAC 0.2 | 8,000 | 48 | 512 | 128 | 2e-5 | 66.4 | 62.6 | 6.8 | 141 | | XLNet-large | QuAC 0.2 | 8,000 | 48 | 512 | 128 | 2e-5 | 71.5 | 68.0 | 11.1 | 142 | 143 |
*Table 4: The dev set performance of XLNet model finetuned on QuAC v0.2 task*
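For QuAC, `tool/convert_quac.py` rewrites the prediction summary into the line-delimited format expected by the official scorer before evaluation, as in `run_quac.sh`; a sketch with assumed paths:

```bash
# Convert predictions (answer spans plus yes/no and follow-up labels) per dialog.
python tool/convert_quac.py \
  --input_file=output/quac/v0.2/data/predict.v0.2.summary.json \
  --output_file=output/quac/v0.2/data/predict.v0.2.span.json \
  --answer_threshold=0.1

# Score with the QuAC evaluator, writing the metric json to --o.
python tool/eval_quac.py \
  --val_file=data/quac/v0.2/dev-v0.2.json \
  --model_output=output/quac/v0.2/data/predict.v0.2.span.json \
  --o output/quac/v0.2/data/predict.v0.2.eval.json
```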
144 | 145 | ## Reference 146 | * Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. [SQuAD: 100,000+ questions for machine comprehension of text](https://arxiv.org/abs/1606.05250) [2016] 147 | * Pranav Rajpurkar, Robin Jia, and Percy Liang. [Know what you don’t know: unanswerable questions for SQuAD](https://arxiv.org/abs/1806.03822) [2018] 148 | * Siva Reddy, Danqi Chen, Christopher D. Manning. [CoQA: A Conversational Question Answering Challenge](https://arxiv.org/abs/1808.07042) [2018] 149 | * Eunsol Choi, He He, Mohit Iyyer, Mark Yatskar, Wen-tau Yih, Yejin Choi, Percy Liang, Luke Zettlemoyer. [QuAC : Question Answering in Context](https://arxiv.org/abs/1808.07036) [2018] 150 | * Danqi Chen. [Neural reading comprehension and beyond](https://cs.stanford.edu/~danqi/papers/thesis.pdf) [2018] 151 | * Matthew E. Peters, Mark Neumann, Mohit Iyyer, Matthew Gardner, Christopher T Clark, Kenton Lee, and Luke S. Zettlemoyer. [Deep contextualized word representations](https://arxiv.org/abs/1802.05365) [2018] 152 | * Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. [Improving language understanding by generative pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) [2018] 153 | * Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei and Ilya Sutskever. [Language models are unsupervised multitask learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) [2019] 154 | * Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. [BERT: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/abs/1810.04805) [2018] 155 | * Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) [2019] 156 | * Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) [2019] 157 | * Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) [2019] 158 | * Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov and Quoc V. Le. [XLNet: Generalized autoregressive pretraining for language understanding](https://arxiv.org/abs/1906.08237) [2019] 159 | * Zihang Dai, Zhilin Yang, Yiming Yang, William W Cohen, Jaime Carbonell, Quoc V Le and Ruslan Salakhutdinov. [Transformer-XL: Attentive language models beyond a fixed-length context](https://arxiv.org/abs/1901.02860) [2019] 160 | -------------------------------------------------------------------------------- /tool/eval_coqa.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for CoQA. 2 | 3 | The code is based partially on SQuAD 2.0 evaluation script. 
4 | """ 5 | import argparse 6 | import json 7 | import re 8 | import string 9 | import sys 10 | 11 | from collections import Counter, OrderedDict 12 | 13 | OPTS = None 14 | 15 | out_domain = ["reddit", "science"] 16 | in_domain = ["mctest", "gutenberg", "race", "cnn", "wikipedia"] 17 | domain_mappings = {"mctest":"children_stories", "gutenberg":"literature", "race":"mid-high_school", "cnn":"news", "wikipedia":"wikipedia", "science":"science", "reddit":"reddit"} 18 | 19 | 20 | class CoQAEvaluator(): 21 | 22 | def __init__(self, gold_file): 23 | self.gold_data, self.id_to_source = CoQAEvaluator.gold_answers_to_dict(gold_file) 24 | 25 | @staticmethod 26 | def gold_answers_to_dict(gold_file): 27 | dataset = json.load(open(gold_file)) 28 | gold_dict = {} 29 | id_to_source = {} 30 | for story in dataset['data']: 31 | source = story['source'] 32 | story_id = story['id'] 33 | id_to_source[story_id] = source 34 | questions = story['questions'] 35 | multiple_answers = [story['answers']] 36 | multiple_answers += story['additional_answers'].values() 37 | for i, qa in enumerate(questions): 38 | qid = qa['turn_id'] 39 | if i + 1 != qid: 40 | sys.stderr.write("Turn id should match index {}: {}\n".format(i + 1, qa)) 41 | gold_answers = [] 42 | for answers in multiple_answers: 43 | answer = answers[i] 44 | if qid != answer['turn_id']: 45 | sys.stderr.write("Question turn id does match answer: {} {}\n".format(qa, answer)) 46 | gold_answers.append(answer['input_text']) 47 | key = (story_id, qid) 48 | if key in gold_dict: 49 | sys.stderr.write("Gold file has duplicate stories: {}".format(source)) 50 | gold_dict[key] = gold_answers 51 | return gold_dict, id_to_source 52 | 53 | @staticmethod 54 | def preds_to_dict(pred_file): 55 | preds = json.load(open(pred_file)) 56 | pred_dict = {} 57 | for pred in preds: 58 | pred_dict[(pred['id'], pred['turn_id'])] = pred['answer'] 59 | return pred_dict 60 | 61 | @staticmethod 62 | def normalize_answer(s): 63 | """Lower text and remove punctuation, storys and extra whitespace.""" 64 | 65 | def remove_articles(text): 66 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 67 | return re.sub(regex, ' ', text) 68 | 69 | def white_space_fix(text): 70 | return ' '.join(text.split()) 71 | 72 | def remove_punc(text): 73 | exclude = set(string.punctuation) 74 | return ''.join(ch for ch in text if ch not in exclude) 75 | 76 | def lower(text): 77 | return text.lower() 78 | 79 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 80 | 81 | @staticmethod 82 | def get_tokens(s): 83 | if not s: return [] 84 | return CoQAEvaluator.normalize_answer(s).split() 85 | 86 | @staticmethod 87 | def compute_exact(a_gold, a_pred): 88 | return int(CoQAEvaluator.normalize_answer(a_gold) == CoQAEvaluator.normalize_answer(a_pred)) 89 | 90 | @staticmethod 91 | def compute_f1(a_gold, a_pred): 92 | gold_toks = CoQAEvaluator.get_tokens(a_gold) 93 | pred_toks = CoQAEvaluator.get_tokens(a_pred) 94 | common = Counter(gold_toks) & Counter(pred_toks) 95 | num_same = sum(common.values()) 96 | if len(gold_toks) == 0 or len(pred_toks) == 0: 97 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 98 | return int(gold_toks == pred_toks) 99 | if num_same == 0: 100 | return 0 101 | precision = 1.0 * num_same / len(pred_toks) 102 | recall = 1.0 * num_same / len(gold_toks) 103 | f1 = (2 * precision * recall) / (precision + recall) 104 | return f1 105 | 106 | @staticmethod 107 | def _compute_turn_score(a_gold_list, a_pred): 108 | f1_sum = 0.0 109 | em_sum = 0.0 110 | if len(a_gold_list) > 1: 111 
| for i in range(len(a_gold_list)): 112 | # exclude the current answer 113 | gold_answers = a_gold_list[0:i] + a_gold_list[i + 1:] 114 | em_sum += max(CoQAEvaluator.compute_exact(a, a_pred) for a in gold_answers) 115 | f1_sum += max(CoQAEvaluator.compute_f1(a, a_pred) for a in gold_answers) 116 | else: 117 | em_sum += max(CoQAEvaluator.compute_exact(a, a_pred) for a in a_gold_list) 118 | f1_sum += max(CoQAEvaluator.compute_f1(a, a_pred) for a in a_gold_list) 119 | 120 | return {'em': em_sum / max(1, len(a_gold_list)), 'f1': f1_sum / max(1, len(a_gold_list))} 121 | 122 | def compute_turn_score(self, story_id, turn_id, a_pred): 123 | ''' This is the function what you are probably looking for. a_pred is the answer string your model predicted. ''' 124 | key = (story_id, turn_id) 125 | a_gold_list = self.gold_data[key] 126 | return CoQAEvaluator._compute_turn_score(a_gold_list, a_pred) 127 | 128 | def get_raw_scores(self, pred_data): 129 | ''''Returns a dict with score with each turn prediction''' 130 | exact_scores = {} 131 | f1_scores = {} 132 | for story_id, turn_id in self.gold_data: 133 | key = (story_id, turn_id) 134 | if key not in pred_data: 135 | sys.stderr.write('Missing prediction for {} and turn_id: {}\n'.format(story_id, turn_id)) 136 | continue 137 | a_pred = pred_data[key] 138 | scores = self.compute_turn_score(story_id, turn_id, a_pred) 139 | # Take max over all gold answers 140 | exact_scores[key] = scores['em'] 141 | f1_scores[key] = scores['f1'] 142 | return exact_scores, f1_scores 143 | 144 | def get_raw_scores_human(self): 145 | ''''Returns a dict with score for each turn''' 146 | exact_scores = {} 147 | f1_scores = {} 148 | for story_id, turn_id in self.gold_data: 149 | key = (story_id, turn_id) 150 | f1_sum = 0.0 151 | em_sum = 0.0 152 | if len(self.gold_data[key]) > 1: 153 | for i in range(len(self.gold_data[key])): 154 | # exclude the current answer 155 | gold_answers = self.gold_data[key][0:i] + self.gold_data[key][i + 1:] 156 | em_sum += max(CoQAEvaluator.compute_exact(a, self.gold_data[key][i]) for a in gold_answers) 157 | f1_sum += max(CoQAEvaluator.compute_f1(a, self.gold_data[key][i]) for a in gold_answers) 158 | else: 159 | exit("Gold answers should be multiple: {}={}".format(key, self.gold_data[key])) 160 | exact_scores[key] = em_sum / len(self.gold_data[key]) 161 | f1_scores[key] = f1_sum / len(self.gold_data[key]) 162 | return exact_scores, f1_scores 163 | 164 | def human_performance(self): 165 | exact_scores, f1_scores = self.get_raw_scores_human() 166 | return self.get_domain_scores(exact_scores, f1_scores) 167 | 168 | def model_performance(self, pred_data): 169 | exact_scores, f1_scores = self.get_raw_scores(pred_data) 170 | return self.get_domain_scores(exact_scores, f1_scores) 171 | 172 | def get_domain_scores(self, exact_scores, f1_scores): 173 | sources = {} 174 | for source in in_domain + out_domain: 175 | sources[source] = Counter() 176 | 177 | for story_id, turn_id in self.gold_data: 178 | key = (story_id, turn_id) 179 | source = self.id_to_source[story_id] 180 | sources[source]['em_total'] += exact_scores.get(key, 0) 181 | sources[source]['f1_total'] += f1_scores.get(key, 0) 182 | sources[source]['turn_count'] += 1 183 | 184 | scores = OrderedDict() 185 | in_domain_em_total = 0.0 186 | in_domain_f1_total = 0.0 187 | in_domain_turn_count = 0 188 | 189 | out_domain_em_total = 0.0 190 | out_domain_f1_total = 0.0 191 | out_domain_turn_count = 0 192 | 193 | for source in in_domain + out_domain: 194 | domain = domain_mappings[source] 195 | scores[domain] 
= {} 196 | scores[domain]['em'] = round(sources[source]['em_total'] / max(1, sources[source]['turn_count']) * 100, 1) 197 | scores[domain]['f1'] = round(sources[source]['f1_total'] / max(1, sources[source]['turn_count']) * 100, 1) 198 | scores[domain]['turns'] = sources[source]['turn_count'] 199 | if source in in_domain: 200 | in_domain_em_total += sources[source]['em_total'] 201 | in_domain_f1_total += sources[source]['f1_total'] 202 | in_domain_turn_count += sources[source]['turn_count'] 203 | elif source in out_domain: 204 | out_domain_em_total += sources[source]['em_total'] 205 | out_domain_f1_total += sources[source]['f1_total'] 206 | out_domain_turn_count += sources[source]['turn_count'] 207 | 208 | scores["in_domain"] = {'em': round(in_domain_em_total / max(1, in_domain_turn_count) * 100, 1), 209 | 'f1': round(in_domain_f1_total / max(1, in_domain_turn_count) * 100, 1), 210 | 'turns': in_domain_turn_count} 211 | scores["out_domain"] = {'em': round(out_domain_em_total / max(1, out_domain_turn_count) * 100, 1), 212 | 'f1': round(out_domain_f1_total / max(1, out_domain_turn_count) * 100, 1), 213 | 'turns': out_domain_turn_count} 214 | 215 | em_total = in_domain_em_total + out_domain_em_total 216 | f1_total = in_domain_f1_total + out_domain_f1_total 217 | turn_count = in_domain_turn_count + out_domain_turn_count 218 | scores["overall"] = {'em': round(em_total / max(1, turn_count) * 100, 1), 219 | 'f1': round(f1_total / max(1, turn_count) * 100, 1), 220 | 'turns': turn_count} 221 | 222 | return scores 223 | 224 | def parse_args(): 225 | parser = argparse.ArgumentParser('Official evaluation script for CoQA.') 226 | parser.add_argument('--data-file', dest="data_file", help='Input data JSON file.') 227 | parser.add_argument('--pred-file', dest="pred_file", help='Model predictions.') 228 | parser.add_argument('--out-file', '-o', metavar='eval.json', 229 | help='Write accuracy metrics to file (default is stdout).') 230 | parser.add_argument('--verbose', '-v', action='store_true') 231 | parser.add_argument('--human', dest="human", action='store_true') 232 | if len(sys.argv) == 1: 233 | parser.print_help() 234 | sys.exit(1) 235 | return parser.parse_args() 236 | 237 | def main(): 238 | evaluator = CoQAEvaluator(OPTS.data_file) 239 | 240 | if OPTS.human: 241 | print(json.dumps(evaluator.human_performance(), indent=2)) 242 | 243 | if OPTS.pred_file: 244 | with open(OPTS.pred_file) as f: 245 | pred_data = CoQAEvaluator.preds_to_dict(OPTS.pred_file) 246 | print(json.dumps(evaluator.model_performance(pred_data), indent=2)) 247 | 248 | if __name__ == '__main__': 249 | OPTS = parse_args() 250 | main() 251 | -------------------------------------------------------------------------------- /tool/eval_squad.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for SQuAD version 2.0. 2 | 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question ID's to the model's predicted probability 6 | that a question is unanswerable. 
7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | OPTS = None 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 21 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 22 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 23 | parser.add_argument('--out-file', '-o', metavar='eval.json', 24 | help='Write accuracy metrics to file (default is stdout).') 25 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 26 | help='Model estimates of probability of no answer.') 27 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 28 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 29 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 30 | help='Save precision-recall curves to directory.') 31 | parser.add_argument('--verbose', '-v', action='store_true') 32 | if len(sys.argv) == 1: 33 | parser.print_help() 34 | sys.exit(1) 35 | return parser.parse_args() 36 | 37 | def make_qid_to_has_ans(dataset): 38 | qid_to_has_ans = {} 39 | for article in dataset: 40 | for p in article['paragraphs']: 41 | for qa in p['qas']: 42 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 43 | return qid_to_has_ans 44 | 45 | def normalize_answer(s): 46 | """Lower text and remove punctuation, articles and extra whitespace.""" 47 | def remove_articles(text): 48 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 49 | return re.sub(regex, ' ', text) 50 | def white_space_fix(text): 51 | return ' '.join(text.split()) 52 | def remove_punc(text): 53 | exclude = set(string.punctuation) 54 | return ''.join(ch for ch in text if ch not in exclude) 55 | def lower(text): 56 | return text.lower() 57 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 58 | 59 | def get_tokens(s): 60 | if not s: return [] 61 | return normalize_answer(s).split() 62 | 63 | def compute_exact(a_gold, a_pred): 64 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 65 | 66 | def compute_f1(a_gold, a_pred): 67 | gold_toks = get_tokens(a_gold) 68 | pred_toks = get_tokens(a_pred) 69 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 70 | num_same = sum(common.values()) 71 | if len(gold_toks) == 0 or len(pred_toks) == 0: 72 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 73 | return int(gold_toks == pred_toks) 74 | if num_same == 0: 75 | return 0 76 | precision = 1.0 * num_same / len(pred_toks) 77 | recall = 1.0 * num_same / len(gold_toks) 78 | f1 = (2 * precision * recall) / (precision + recall) 79 | return f1 80 | 81 | def get_raw_scores(dataset, preds): 82 | exact_scores = {} 83 | f1_scores = {} 84 | for article in dataset: 85 | for p in article['paragraphs']: 86 | for qa in p['qas']: 87 | qid = qa['id'] 88 | gold_answers = [a['text'] for a in qa['answers'] 89 | if normalize_answer(a['text'])] 90 | if not gold_answers: 91 | # For unanswerable questions, only correct answer is empty string 92 | gold_answers = [''] 93 | if qid not in preds: 94 | print('Missing prediction for %s' % qid) 95 | continue 96 | a_pred = preds[qid] 97 | # Take max over all gold answers 98 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 99 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 100 | return exact_scores, f1_scores 101 | 102 | def 
apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 103 | new_scores = {} 104 | for qid, s in scores.items(): 105 | pred_na = na_probs[qid] > na_prob_thresh 106 | if pred_na: 107 | new_scores[qid] = float(not qid_to_has_ans[qid]) 108 | else: 109 | new_scores[qid] = s 110 | return new_scores 111 | 112 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 113 | if not qid_list: 114 | total = len(exact_scores) 115 | return collections.OrderedDict([ 116 | ('exact', 100.0 * sum(exact_scores.values()) / total), 117 | ('f1', 100.0 * sum(f1_scores.values()) / total), 118 | ('total', total), 119 | ]) 120 | else: 121 | total = len(qid_list) 122 | return collections.OrderedDict([ 123 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 124 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 125 | ('total', total), 126 | ]) 127 | 128 | def merge_eval(main_eval, new_eval, prefix): 129 | for k in new_eval: 130 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 131 | 132 | def plot_pr_curve(precisions, recalls, out_image, title): 133 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 134 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 135 | plt.xlabel('Recall') 136 | plt.ylabel('Precision') 137 | plt.xlim([0.0, 1.05]) 138 | plt.ylim([0.0, 1.05]) 139 | plt.title(title) 140 | plt.savefig(out_image) 141 | plt.clf() 142 | 143 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, 144 | out_image=None, title=None): 145 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 146 | true_pos = 0.0 147 | cur_p = 1.0 148 | cur_r = 0.0 149 | precisions = [1.0] 150 | recalls = [0.0] 151 | avg_prec = 0.0 152 | for i, qid in enumerate(qid_list): 153 | if qid_to_has_ans[qid]: 154 | true_pos += scores[qid] 155 | cur_p = true_pos / float(i+1) 156 | cur_r = true_pos / float(num_true_pos) 157 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]: 158 | # i.e., if we can put a threshold after this point 159 | avg_prec += cur_p * (cur_r - recalls[-1]) 160 | precisions.append(cur_p) 161 | recalls.append(cur_r) 162 | if out_image: 163 | plot_pr_curve(precisions, recalls, out_image, title) 164 | return {'ap': 100.0 * avg_prec} 165 | 166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 167 | qid_to_has_ans, out_image_dir): 168 | if out_image_dir and not os.path.exists(out_image_dir): 169 | os.makedirs(out_image_dir) 170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 171 | if num_true_pos == 0: 172 | return 173 | pr_exact = make_precision_recall_eval( 174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 176 | title='Precision-Recall curve for Exact Match score') 177 | pr_f1 = make_precision_recall_eval( 178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 180 | title='Precision-Recall curve for F1 score') 181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 182 | pr_oracle = make_precision_recall_eval( 183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)') 186 | merge_eval(main_eval, pr_exact, 'pr_exact') 187 | merge_eval(main_eval, pr_f1, 'pr_f1') 188 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 189 | 190 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 191 | if not qid_list: 192 | return 193 | x = [na_probs[k] for k in qid_list] 194 | weights = np.ones_like(x) / float(len(x)) 195 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 196 | plt.xlabel('Model probability of no-answer') 197 | plt.ylabel('Proportion of dataset') 198 | plt.title('Histogram of no-answer probability: %s' % name) 199 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 200 | plt.clf() 201 | 202 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 203 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 204 | cur_score = num_no_ans 205 | best_score = cur_score 206 | best_thresh = 0.0 207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 208 | for i, qid in enumerate(qid_list): 209 | if qid not in scores: continue 210 | if qid_to_has_ans[qid]: 211 | diff = scores[qid] 212 | else: 213 | if preds[qid]: 214 | diff = -1 215 | else: 216 | diff = 0 217 | cur_score += diff 218 | if cur_score > best_score: 219 | best_score = cur_score 220 | best_thresh = na_probs[qid] 221 | return 100.0 * best_score / len(scores), best_thresh 222 | 223 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 224 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 225 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 226 | main_eval['best_exact'] = best_exact 227 | main_eval['best_exact_thresh'] = exact_thresh 228 | main_eval['best_f1'] = best_f1 229 | main_eval['best_f1_thresh'] = f1_thresh 230 | 231 | def main(): 232 | with open(OPTS.data_file) as f: 233 | dataset_json = json.load(f) 234 | dataset = dataset_json['data'] 235 | with open(OPTS.pred_file) as f: 236 | preds = json.load(f) 237 | if OPTS.na_prob_file: 238 | with open(OPTS.na_prob_file) as f: 239 | na_probs = json.load(f) 240 | else: 241 | na_probs = {k: 0.0 for k in preds} 242 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 243 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 244 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 245 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 246 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 247 | OPTS.na_prob_thresh) 248 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 249 | OPTS.na_prob_thresh) 250 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 251 | if has_ans_qids: 252 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 253 | merge_eval(out_eval, has_ans_eval, 'HasAns') 254 | if no_ans_qids: 255 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 256 | merge_eval(out_eval, no_ans_eval, 'NoAns') 257 | if OPTS.na_prob_file: 258 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 259 | if OPTS.na_prob_file and OPTS.out_image_dir: 260 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 261 | qid_to_has_ans, OPTS.out_image_dir) 262 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 263 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 264 | if OPTS.out_file: 265 | with open(OPTS.out_file, 'w') as f: 266 | json.dump(out_eval, f) 267 | else: 268 | 
print(json.dumps(out_eval, indent=2)) 269 | 270 | if __name__ == '__main__': 271 | OPTS = parse_args() 272 | if OPTS.out_image_dir: 273 | import matplotlib 274 | matplotlib.use('Agg') 275 | import matplotlib.pyplot as plt 276 | main() 277 | 278 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /run_squad.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | sys.path.append('xlnet') # walkaround due to submodule absolute import... 
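# NOTE: appending 'xlnet' to sys.path assumes the script is launched from the repo root; a sketch of a more location-independent variant (not in the original source) would anchor the path to this file, e.g. sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xlnet')), with 'import os' moved above this line, since os is only imported further down in this file.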
7 | 8 | import collections 9 | import os 10 | import os.path 11 | import json 12 | import pickle 13 | import time 14 | 15 | import tensorflow as tf 16 | import numpy as np 17 | import sentencepiece as sp 18 | 19 | from xlnet import xlnet 20 | import function_builder 21 | import prepro_utils 22 | import model_utils 23 | 24 | MAX_FLOAT = 1e30 25 | MIN_FLOAT = -1e30 26 | 27 | flags = tf.flags 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("data_dir", None, "Data directory where raw data is located.") 31 | flags.DEFINE_string("output_dir", None, "Output directory where processed data is located.") 32 | flags.DEFINE_string("model_dir", None, "Model directory where checkpoints are located.") 33 | flags.DEFINE_string("export_dir", None, "Export directory where saved model is located.") 34 | 35 | flags.DEFINE_string("task_name", default=None, help="The name of the task to train.") 36 | flags.DEFINE_string("model_config_path", default=None, help="Config file of the pre-trained model.") 37 | flags.DEFINE_string("init_checkpoint", default=None, help="Initial checkpoint of the pre-trained model.") 38 | flags.DEFINE_string("spiece_model_file", default=None, help="Sentence Piece model path.") 39 | flags.DEFINE_bool("overwrite_data", default=False, help="If False, will use cached data if available.") 40 | flags.DEFINE_integer("random_seed", default=100, help="Random seed for weight initialization.") 41 | flags.DEFINE_string("predict_tag", None, "Predict tag for tracking prediction results.") 42 | 43 | flags.DEFINE_bool("do_train", default=False, help="Whether to run training.") 44 | flags.DEFINE_bool("do_predict", default=False, help="Whether to run prediction.") 45 | flags.DEFINE_bool("do_export", default=False, help="Whether to run exporting.") 46 | 47 | flags.DEFINE_enum("init", default="normal", enum_values=["normal", "uniform"], help="Initialization method.") 48 | flags.DEFINE_float("init_std", default=0.02, help="Initialization std when init is normal.") 49 | flags.DEFINE_float("init_range", default=0.1, help="Initialization std when init is uniform.") 50 | flags.DEFINE_bool("init_global_vars", default=False, help="If true, init all global vars. If false, init trainable vars only.") 51 | 52 | flags.DEFINE_bool("lower_case", default=False, help="Whether to enable lower casing.") 53 | flags.DEFINE_integer("doc_stride", default=128, help="Doc stride") 54 | flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length") 55 | flags.DEFINE_integer("max_query_length", default=64, help="Max query length") 56 | flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length") 57 | flags.DEFINE_integer("train_batch_size", default=48, help="Total batch size for training.") 58 | flags.DEFINE_integer("predict_batch_size", default=32, help="Total batch size for prediction.") 59 | 60 | flags.DEFINE_integer("train_steps", default=8000, help="Number of training steps") 61 | flags.DEFINE_integer("warmup_steps", default=0, help="Number of warmup steps") 62 | flags.DEFINE_integer("max_save", default=5, help="Max number of checkpoints to save. Use 0 to save all.") 63 | flags.DEFINE_integer("save_steps", default=1000, help="Save the model every save_steps.
If None, not to save any model.") 64 | flags.DEFINE_integer("shuffle_buffer", default=2048, help="Buffer size used for shuffle.") 65 | 66 | flags.DEFINE_integer("n_best_size", default=5, help="n best size for predictions") 67 | flags.DEFINE_integer("start_n_top", default=5, help="Beam size for span start.") 68 | flags.DEFINE_integer("end_n_top", default=5, help="Beam size for span end.") 69 | flags.DEFINE_string("target_eval_key", default="best_f1", help="Use has_ans_f1 for Model I.") 70 | 71 | flags.DEFINE_bool("use_bfloat16", default=False, help="Whether to use bfloat16.") 72 | flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.") 73 | flags.DEFINE_float("dropatt", default=0.1, help="Attention dropout rate.") 74 | flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length") 75 | flags.DEFINE_string("summary_type", default="last", help="Method used to summarize a sequence into a vector.") 76 | 77 | flags.DEFINE_float("learning_rate", default=3e-5, help="initial learning rate") 78 | flags.DEFINE_float("min_lr_ratio", default=0.0, help="min lr ratio for cos decay.") 79 | flags.DEFINE_float("lr_layer_decay_rate", default=0.75, help="lr[L] = learning_rate, lr[l-1] = lr[l] * lr_layer_decay_rate.") 80 | flags.DEFINE_float("clip", default=1.0, help="Gradient clipping") 81 | flags.DEFINE_float("weight_decay", default=0.00, help="Weight decay rate") 82 | flags.DEFINE_float("adam_epsilon", default=1e-6, help="Adam epsilon") 83 | flags.DEFINE_string("decay_method", default="poly", help="poly or cos") 84 | 85 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 86 | flags.DEFINE_integer("num_hosts", 1, "How many TPU hosts.") 87 | flags.DEFINE_integer("num_core_per_host", 1, "Total number of TPU cores to use.") 88 | flags.DEFINE_string("tpu_job_name", None, "TPU worker job name.") 89 | flags.DEFINE_string("tpu", None, "The Cloud TPU name to use for training.") 90 | flags.DEFINE_string("tpu_zone", None, "GCE zone where the Cloud TPU is located in.") 91 | flags.DEFINE_string("gcp_project", None, "Project name for the Cloud TPU-enabled project.") 92 | flags.DEFINE_string("master", None, "TensorFlow master URL") 93 | flags.DEFINE_integer("iterations", 1000, "number of iterations per TPU training loop.") 94 | 95 | class InputExample(object): 96 | """A single SQuAD example.""" 97 | def __init__(self, 98 | qas_id, 99 | question_text, 100 | paragraph_text, 101 | orig_answer_text=None, 102 | start_position=None, 103 | is_impossible=False): 104 | self.qas_id = qas_id 105 | self.question_text = question_text 106 | self.paragraph_text = paragraph_text 107 | self.orig_answer_text = orig_answer_text 108 | self.start_position = start_position 109 | self.is_impossible = is_impossible 110 | 111 | def __str__(self): 112 | return self.__repr__() 113 | 114 | def __repr__(self): 115 | s = "qas_id: %s" % (prepro_utils.printable_text(self.qas_id)) 116 | s += ", question_text: %s" % (prepro_utils.printable_text(self.question_text)) 117 | s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text)) 118 | if self.start_position: 119 | s += ", start_position: %d" % (self.start_position) 120 | s += ", is_impossible: %r" % (self.is_impossible) 121 | return s 122 | 123 | class InputFeatures(object): 124 | """A single SQuAD feature.""" 125 | def __init__(self, 126 | unique_id, 127 | qas_id, 128 | doc_idx, 129 | token2char_raw_start_index, 130 | token2char_raw_end_index, 131 | token2doc_index, 132 | input_ids, 133 | input_mask, 134 | p_mask, 135 | segment_ids, 136 | cls_index, 137 | 
para_length, 138 | start_position=None, 139 | end_position=None, 140 | is_impossible=None): 141 | self.unique_id = unique_id 142 | self.qas_id = qas_id 143 | self.doc_idx = doc_idx 144 | self.token2char_raw_start_index = token2char_raw_start_index 145 | self.token2char_raw_end_index = token2char_raw_end_index 146 | self.token2doc_index = token2doc_index 147 | self.input_ids = input_ids 148 | self.input_mask = input_mask 149 | self.p_mask = p_mask 150 | self.segment_ids = segment_ids 151 | self.cls_index = cls_index 152 | self.para_length = para_length 153 | self.start_position = start_position 154 | self.end_position = end_position 155 | self.is_impossible = is_impossible 156 | 157 | class OutputResult(object): 158 | """A single SQuAD result.""" 159 | def __init__(self, 160 | unique_id, 161 | answer_prob, 162 | start_prob, 163 | start_index, 164 | end_prob, 165 | end_index): 166 | self.unique_id = unique_id 167 | self.answer_prob = answer_prob 168 | self.start_prob = start_prob 169 | self.start_index = start_index 170 | self.end_prob = end_prob 171 | self.end_index = end_index 172 | 173 | class SquadPipeline(object): 174 | """Pipeline for SQuAD dataset.""" 175 | def __init__(self, 176 | data_dir, 177 | task_name): 178 | self.data_dir = data_dir 179 | self.task_name = task_name 180 | 181 | def get_train_examples(self): 182 | """Gets a collection of `InputExample`s for the train set.""" 183 | data_path = os.path.join(self.data_dir, "train-{0}.json".format(self.task_name)) 184 | data_list = self._read_json(data_path) 185 | example_list = self._get_example(data_list, True) 186 | return example_list 187 | 188 | def get_dev_examples(self): 189 | """Gets a collection of `InputExample`s for the dev set.""" 190 | data_path = os.path.join(self.data_dir, "dev-{0}.json".format(self.task_name)) 191 | data_list = self._read_json(data_path) 192 | example_list = self._get_example(data_list, False) 193 | return example_list 194 | 195 | def _read_json(self, 196 | data_path): 197 | if os.path.exists(data_path): 198 | with open(data_path, "r") as file: 199 | data_list = json.load(file)["data"] 200 | return data_list 201 | else: 202 | raise FileNotFoundError("data path not found: {0}".format(data_path)) 203 | 204 | def _get_example(self, 205 | data_list, 206 | is_training): 207 | examples = [] 208 | for entry in data_list: 209 | for paragraph in entry["paragraphs"]: 210 | paragraph_text = paragraph["context"] 211 | 212 | for qa in paragraph["qas"]: 213 | qas_id = qa["id"] 214 | question_text = qa["question"] 215 | start_position = None 216 | orig_answer_text = None 217 | is_impossible = False 218 | 219 | if is_training: 220 | is_impossible = qa["is_impossible"] 221 | if (len(qa["answers"]) != 1) and (not is_impossible): 222 | raise ValueError("For training, each question should have exactly 1 answer.") 223 | 224 | if not is_impossible: 225 | answer = qa["answers"][0] 226 | orig_answer_text = answer["text"] 227 | start_position = answer["answer_start"] 228 | else: 229 | start_position = -1 230 | orig_answer_text = "" 231 | 232 | example = InputExample( 233 | qas_id=qas_id, 234 | question_text=question_text, 235 | paragraph_text=paragraph_text, 236 | orig_answer_text=orig_answer_text, 237 | start_position=start_position, 238 | is_impossible=is_impossible) 239 | 240 | examples.append(example) 241 | 242 | return examples 243 | 244 | class XLNetTokenizer(object): 245 | """Default text tokenizer for XLNet""" 246 | def __init__(self, 247 | sp_model_file, 248 | lower_case=False): 249 | """Construct XLNet tokenizer""" 
250 | self.sp_processor = sp.SentencePieceProcessor() 251 | self.sp_processor.Load(sp_model_file) 252 | self.lower_case = lower_case 253 | 254 | def tokenize(self, 255 | text): 256 | """Tokenize text for XLNet""" 257 | processed_text = prepro_utils.preprocess_text(text, lower=self.lower_case) 258 | tokenized_pieces = prepro_utils.encode_pieces(self.sp_processor, processed_text, return_unicode=False) 259 | return tokenized_pieces 260 | 261 | def encode(self, 262 | text): 263 | """Encode text for XLNet""" 264 | processed_text = prepro_utils.preprocess_text(text, lower=self.lower_case) 265 | encoded_ids = prepro_utils.encode_ids(self.sp_processor, processed_text) 266 | return encoded_ids 267 | 268 | def token_to_id(self, 269 | token): 270 | """Convert token to id for XLNet""" 271 | return self.sp_processor.PieceToId(token) 272 | 273 | def id_to_token(self, 274 | id): 275 | """Convert id to token for XLNet""" 276 | return self.sp_processor.IdToPiece(id) 277 | 278 | def tokens_to_ids(self, 279 | tokens): 280 | """Convert tokens to ids for XLNet""" 281 | return [self.sp_processor.PieceToId(token) for token in tokens] 282 | 283 | def ids_to_tokens(self, 284 | ids): 285 | """Convert ids to tokens for XLNet""" 286 | return [self.sp_processor.IdToPiece(id) for id in ids] 287 | 288 | class XLNetExampleProcessor(object): 289 | """Default example processor for XLNet""" 290 | def __init__(self, 291 | max_seq_length, 292 | max_query_length, 293 | doc_stride, 294 | tokenizer): 295 | """Construct XLNet example processor""" 296 | self.special_vocab_list = ["<unk>", "<s>", "</s>", "<cls>", "<sep>", "<pad>", "<mask>", "<eod>", "<eop>"] 297 | self.special_vocab_map = {} 298 | for (i, special_vocab) in enumerate(self.special_vocab_list): 299 | self.special_vocab_map[special_vocab] = i 300 | 301 | self.segment_vocab_list = ["<p>", "<q>", "<cls>", "<sep>", "<pad>"]
302 | self.segment_vocab_map = {} 303 | for (i, segment_vocab) in enumerate(self.segment_vocab_list): 304 | self.segment_vocab_map[segment_vocab] = i 305 | 306 | self.max_seq_length = max_seq_length 307 | self.max_query_length = max_query_length 308 | self.doc_stride = doc_stride 309 | self.tokenizer = tokenizer 310 | self.unique_id = 1000000000 311 | 312 | def _generate_match_mapping(self, 313 | para_text, 314 | tokenized_para_text, 315 | N, 316 | M, 317 | max_N, 318 | max_M): 319 | """Generate match mapping for raw and tokenized paragraph""" 320 | def _lcs_match(para_text, 321 | tokenized_para_text, 322 | N, 323 | M, 324 | max_N, 325 | max_M, 326 | max_dist): 327 | """longest common sub-sequence 328 | 329 | f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j)) 330 | 331 | unlike standard LCS, this is specifically optimized for the setting 332 | because the mismatch between sentence pieces and original text will be small 333 | """ 334 | f = np.zeros((max_N, max_M), dtype=np.float32) 335 | g = {} 336 | 337 | for i in range(N): 338 | for j in range(i - max_dist, i + max_dist): 339 | if j >= M or j < 0: 340 | continue 341 | 342 | if i > 0: 343 | g[(i, j)] = 0 344 | f[i, j] = f[i - 1, j] 345 | 346 | if j > 0 and f[i, j - 1] > f[i, j]: 347 | g[(i, j)] = 1 348 | f[i, j] = f[i, j - 1] 349 | 350 | f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0 351 | 352 | raw_char = prepro_utils.preprocess_text(para_text[i], lower=self.tokenizer.lower_case, remove_space=False) 353 | tokenized_char = tokenized_para_text[j] 354 | if (raw_char == tokenized_char and f_prev + 1 > f[i, j]): 355 | g[(i, j)] = 2 356 | f[i, j] = f_prev + 1 357 | 358 | return f, g 359 | 360 | max_dist = abs(N - M) + 5 361 | for _ in range(2): 362 | lcs_matrix, match_mapping = _lcs_match(para_text, tokenized_para_text, N, M, max_N, max_M, max_dist) 363 | 364 | if lcs_matrix[N - 1, M - 1] > 0.8 * N: 365 | break 366 | 367 | max_dist *= 2 368 | 369 | mismatch = lcs_matrix[N - 1, M - 1] < 0.8 * N 370 | return match_mapping, mismatch 371 | 372 | def _convert_tokenized_index(self, 373 | index, 374 | pos, 375 | M=None, 376 | is_start=True): 377 | """Convert index for tokenized text""" 378 | if index[pos] is not None: 379 | return index[pos] 380 | 381 | N = len(index) 382 | rear = pos 383 | while rear < N - 1 and index[rear] is None: 384 | rear += 1 385 | 386 | front = pos 387 | while front > 0 and index[front] is None: 388 | front -= 1 389 | 390 | assert index[front] is not None or index[rear] is not None 391 | 392 | if index[front] is None: 393 | if index[rear] >= 1: 394 | if is_start: 395 | return 0 396 | else: 397 | return index[rear] - 1 398 | 399 | return index[rear] 400 | 401 | if index[rear] is None: 402 | if M is not None and index[front] < M - 1: 403 | if is_start: 404 | return index[front] + 1 405 | else: 406 | return M - 1 407 | 408 | return index[front] 409 | 410 | if is_start: 411 | if index[rear] > index[front] + 1: 412 | return index[front] + 1 413 | else: 414 | return index[rear] 415 | else: 416 | if index[rear] > index[front] + 1: 417 | return index[rear] - 1 418 | else: 419 | return index[front] 420 | 421 | def _find_max_context(self, 422 | doc_spans, 423 | token_idx): 424 | """Check if this is the 'max context' doc span for the token. 425 | 426 | Because of the sliding window approach taken to scoring documents, a single 427 | token can appear in multiple documents. E.g.
428 | Doc: the man went to the store and bought a gallon of milk 429 | Span A: the man went to the 430 | Span B: to the store and bought 431 | Span C: and bought a gallon of 432 | ... 433 | 434 | Now the word 'bought' will have two scores from spans B and C. We only 435 | want to consider the score with "maximum context", which we define as 436 | the *minimum* of its left and right context (the *sum* of left and 437 | right context will always be the same, of course). 438 | 439 | In the example the maximum context for 'bought' would be span C since 440 | it has 1 left context and 3 right context, while span B has 4 left context 441 | and 0 right context. 442 | """ 443 | best_doc_score = None 444 | best_doc_idx = None 445 | for (doc_idx, doc_span) in enumerate(doc_spans): 446 | doc_start = doc_span["start"] 447 | doc_length = doc_span["length"] 448 | doc_end = doc_start + doc_length - 1 449 | if token_idx < doc_start or token_idx > doc_end: 450 | continue 451 | 452 | left_context_length = token_idx - doc_start 453 | right_context_length = doc_end - token_idx 454 | doc_score = min(left_context_length, right_context_length) + 0.01 * doc_length 455 | if best_doc_score is None or doc_score > best_doc_score: 456 | best_doc_score = doc_score 457 | best_doc_idx = doc_idx 458 | 459 | return best_doc_idx 460 | 461 | def convert_squad_example(self, 462 | example, 463 | is_training=True, 464 | logging=False): 465 | """Converts a single `InputExample` into a single `InputFeatures`.""" 466 | query_tokens = self.tokenizer.tokenize(example.question_text) 467 | if len(query_tokens) > self.max_query_length: 468 | query_tokens = query_tokens[:self.max_query_length] 469 | 470 | para_text = example.paragraph_text 471 | para_tokens = self.tokenizer.tokenize(example.paragraph_text) 472 | 473 | char2token_index = [] 474 | token2char_start_index = [] 475 | token2char_end_index = [] 476 | char_idx = 0 477 | for i, token in enumerate(para_tokens): 478 | char_len = len(token) 479 | char2token_index.extend([i] * char_len) 480 | token2char_start_index.append(char_idx) 481 | char_idx += char_len 482 | token2char_end_index.append(char_idx - 1) 483 | 484 | tokenized_para_text = ''.join(para_tokens).replace(prepro_utils.SPIECE_UNDERLINE, ' ') 485 | 486 | N, M = len(para_text), len(tokenized_para_text) 487 | max_N, max_M = 1024, 1024 488 | if N > max_N or M > max_M: 489 | max_N = max(N, max_N) 490 | max_M = max(M, max_M) 491 | 492 | match_mapping, mismatch = self._generate_match_mapping(para_text, tokenized_para_text, N, M, max_N, max_M) 493 | 494 | raw2tokenized_char_index = [None] * N 495 | tokenized2raw_char_index = [None] * M 496 | i, j = N-1, M-1 497 | while i >= 0 and j >= 0: 498 | if (i, j) not in match_mapping: 499 | break 500 | 501 | if match_mapping[(i, j)] == 2: 502 | raw2tokenized_char_index[i] = j 503 | tokenized2raw_char_index[j] = i 504 | i, j = i - 1, j - 1 505 | elif match_mapping[(i, j)] == 1: 506 | j = j - 1 507 | else: 508 | i = i - 1 509 | 510 | if all(v is None for v in raw2tokenized_char_index) or mismatch: 511 | tf.logging.warning("raw and tokenized paragraph mismatch detected for example: %s" % example.qas_id) 512 | 513 | token2char_raw_start_index = [] 514 | token2char_raw_end_index = [] 515 | for idx in range(len(para_tokens)): 516 | start_pos = token2char_start_index[idx] 517 | end_pos = token2char_end_index[idx] 518 | raw_start_pos = self._convert_tokenized_index(tokenized2raw_char_index, start_pos, N, is_start=True) 519 | raw_end_pos = self._convert_tokenized_index(tokenized2raw_char_index, 
end_pos, N, is_start=False) 520 | token2char_raw_start_index.append(raw_start_pos) 521 | token2char_raw_end_index.append(raw_end_pos) 522 | 523 | if is_training: 524 | if not example.is_impossible: 525 | raw_start_char_pos = example.start_position 526 | raw_end_char_pos = raw_start_char_pos + len(example.orig_answer_text) - 1 527 | tokenized_start_char_pos = self._convert_tokenized_index(raw2tokenized_char_index, raw_start_char_pos, is_start=True) 528 | tokenized_end_char_pos = self._convert_tokenized_index(raw2tokenized_char_index, raw_end_char_pos, is_start=False) 529 | tokenized_start_token_pos = char2token_index[tokenized_start_char_pos] 530 | tokenized_end_token_pos = char2token_index[tokenized_end_char_pos] 531 | assert tokenized_start_token_pos <= tokenized_end_token_pos 532 | else: 533 | tokenized_start_token_pos = tokenized_end_token_pos = -1 534 | else: 535 | tokenized_start_token_pos = tokenized_end_token_pos = None 536 | 537 | # The -3 accounts for [CLS], [SEP] and [SEP] 538 | max_para_length = self.max_seq_length - len(query_tokens) - 3 539 | total_para_length = len(para_tokens) 540 | 541 | # We can have documents that are longer than the maximum sequence length. 542 | # To deal with this we do a sliding window approach, where we take chunks 543 | # of up to our max length with a stride of `doc_stride`. 544 | doc_spans = [] 545 | para_start = 0 546 | while para_start < total_para_length: 547 | para_length = total_para_length - para_start 548 | if para_length > max_para_length: 549 | para_length = max_para_length 550 | 551 | doc_spans.append({ 552 | "start": para_start, 553 | "length": para_length 554 | }) 555 | 556 | if para_start + para_length == total_para_length: 557 | break 558 | 559 | para_start += min(para_length, self.doc_stride) 560 | 561 | feature_list = [] 562 | for (doc_idx, doc_span) in enumerate(doc_spans): 563 | input_tokens = [] 564 | segment_ids = [] 565 | p_mask = [] 566 | doc_token2char_raw_start_index = [] 567 | doc_token2char_raw_end_index = [] 568 | doc_token2doc_index = {} 569 | 570 | for i in range(doc_span["length"]): 571 | token_idx = doc_span["start"] + i 572 | 573 | doc_token2char_raw_start_index.append(token2char_raw_start_index[token_idx]) 574 | doc_token2char_raw_end_index.append(token2char_raw_end_index[token_idx]) 575 | 576 | best_doc_idx = self._find_max_context(doc_spans, token_idx) 577 | doc_token2doc_index[len(input_tokens)] = (best_doc_idx == doc_idx) 578 | 579 | input_tokens.append(para_tokens[token_idx]) 580 | segment_ids.append(self.segment_vocab_map["<p>"]) 581 | p_mask.append(0) 582 | 583 | doc_para_length = len(input_tokens) 584 | 585 | input_tokens.append("<sep>") 586 | segment_ids.append(self.segment_vocab_map["<p>"])
587 | p_mask.append(1) 588 | 589 | # We put P before Q because during pretraining, B is always shorter than A 590 | for query_token in query_tokens: 591 | input_tokens.append(query_token) 592 | segment_ids.append(self.segment_vocab_map["<q>"]) 593 | p_mask.append(1) 594 | 595 | input_tokens.append("<sep>") 596 | segment_ids.append(self.segment_vocab_map["<q>"]) 597 | p_mask.append(1) 598 | 599 | cls_index = len(input_tokens) 600 | 601 | input_tokens.append("<cls>") 602 | segment_ids.append(self.segment_vocab_map["<cls>"]) 603 | p_mask.append(0) 604 | 605 | input_ids = self.tokenizer.tokens_to_ids(input_tokens) 606 | 607 | # The mask has 0 for real tokens and 1 for padding tokens. Only real tokens are attended to. 608 | input_mask = [0] * len(input_ids) 609 | 610 | # Zero-pad up to the sequence length. 611 | while len(input_ids) < self.max_seq_length: 612 | input_ids.append(self.special_vocab_map["<pad>"]) 613 | input_mask.append(1) 614 | segment_ids.append(self.segment_vocab_map["<pad>"]) 615 | p_mask.append(1) 616 | 617 | assert len(input_ids) == self.max_seq_length 618 | assert len(input_mask) == self.max_seq_length 619 | assert len(segment_ids) == self.max_seq_length 620 | assert len(p_mask) == self.max_seq_length 621 | 622 | start_position = None 623 | end_position = None 624 | is_impossible = example.is_impossible 625 | if is_training: 626 | if not is_impossible: 627 | # For training, if our document chunk does not contain an annotation, set default values. 628 | doc_start = doc_span["start"] 629 | doc_end = doc_start + doc_span["length"] - 1 630 | if tokenized_start_token_pos < doc_start or tokenized_end_token_pos > doc_end: 631 | start_position = 0 632 | end_position = 0 633 | is_impossible = True 634 | else: 635 | start_position = tokenized_start_token_pos - doc_start 636 | end_position = tokenized_end_token_pos - doc_start 637 | else: 638 | start_position = cls_index 639 | end_position = cls_index 640 | 641 | if logging: 642 | tf.logging.info("*** Example ***") 643 | tf.logging.info("unique_id: %s" % str(self.unique_id)) 644 | tf.logging.info("qas_id: %s" % example.qas_id) 645 | tf.logging.info("doc_idx: %s" % str(doc_idx)) 646 | tf.logging.info("doc_token2char_raw_start_index: %s" % " ".join([str(x) for x in doc_token2char_raw_start_index])) 647 | tf.logging.info("doc_token2char_raw_end_index: %s" % " ".join([str(x) for x in doc_token2char_raw_end_index])) 648 | tf.logging.info("doc_token2doc_index: %s" % " ".join(["%d:%s" % (x, y) for (x, y) in doc_token2doc_index.items()])) 649 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 650 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 651 | tf.logging.info("p_mask: %s" % " ".join([str(x) for x in p_mask])) 652 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 653 | if is_training: 654 | if not is_impossible: 655 | tf.logging.info("start_position: %s" % str(start_position)) 656 | tf.logging.info("end_position: %s" % str(end_position)) 657 | answer_tokens = input_tokens[start_position:end_position+1] 658 | answer_text = prepro_utils.printable_text("".join(answer_tokens).replace(prepro_utils.SPIECE_UNDERLINE, " ")) 659 | tf.logging.info("answer_text: %s" % answer_text) 660 | else: 661 | tf.logging.info("impossible example") 662 | 663 | feature = InputFeatures( 664 | unique_id=self.unique_id, 665 | qas_id=example.qas_id, 666 | doc_idx=doc_idx, 667 | token2char_raw_start_index=doc_token2char_raw_start_index, 668 | token2char_raw_end_index=doc_token2char_raw_end_index, 669 |
684 | 
685 |     def convert_examples_to_features(self,
686 |                                      examples,
687 |                                      is_training=True):
688 |         """Convert a set of `InputExample`s to a list of `InputFeatures`."""
689 |         features = []
690 |         for (idx, example) in enumerate(examples):
691 |             if idx % 1000 == 0:
692 |                 tf.logging.info("Writing example %d of %d" % (idx, len(examples)))
693 | 
694 |             feature_list = self.convert_squad_example(example, is_training, logging=(idx < 20))
695 |             features.extend(feature_list)
696 | 
697 |         return features
698 | 
699 |     def save_features_as_tfrecord(self,
700 |                                   features,
701 |                                   output_file,
702 |                                   is_training=True):
703 |         """Save a set of `InputFeature`s to a TFRecord file."""
704 |         def create_int_feature(values):
705 |             return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
706 | 
707 |         def create_float_feature(values):
708 |             return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
709 | 
710 |         with tf.python_io.TFRecordWriter(output_file) as writer:
711 |             for feature in features:
712 |                 feature_dict = collections.OrderedDict()
713 |                 feature_dict["unique_id"] = create_int_feature([feature.unique_id])
714 |                 feature_dict["input_ids"] = create_int_feature(feature.input_ids)
715 |                 feature_dict["input_mask"] = create_float_feature(feature.input_mask)
716 |                 feature_dict["p_mask"] = create_float_feature(feature.p_mask)
717 |                 feature_dict["segment_ids"] = create_int_feature(feature.segment_ids)
718 |                 feature_dict["cls_index"] = create_int_feature([feature.cls_index])
719 | 
720 |                 if is_training:
721 |                     feature_dict["start_position"] = create_int_feature([feature.start_position])
722 |                     feature_dict["end_position"] = create_int_feature([feature.end_position])
723 |                     feature_dict["is_impossible"] = create_float_feature([1 if feature.is_impossible else 0])
724 | 
725 |                 tf_example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
726 |                 writer.write(tf_example.SerializeToString())
727 | 
728 |     def save_features_as_pickle(self,
729 |                                 features,
730 |                                 output_file):
731 |         """Save a set of `InputFeature`s to a Pickle file."""
732 |         with open(output_file, 'wb') as file:
733 |             pickle.dump(features, file)
734 | 
735 |     def load_features_from_pickle(self,
736 |                                   input_file):
737 |         """Load a set of `InputFeature`s from a Pickle file."""
738 |         if not os.path.exists(input_file):
739 |             raise FileNotFoundError("feature file not found: {0}".format(input_file))
740 | 
741 |         with open(input_file, 'rb') as file:
742 |             features = pickle.load(file)
743 |             return features
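
To sanity-check what `save_features_as_tfrecord` wrote, a record can be decoded back without building a graph. A minimal sketch using the TF 1.x API; the path is a placeholder:

    import tensorflow as tf

    # Read back the first serialized record and inspect it against the schema above.
    record_iterator = tf.python_io.tf_record_iterator("output/train-squad.tfrecord")  # placeholder path
    example = tf.train.Example.FromString(next(record_iterator))

    feature_map = example.features.feature
    print(feature_map["unique_id"].int64_list.value)        # e.g. [1000000000]
    print(len(feature_map["input_ids"].int64_list.value))   # should equal max_seq_length
    print(feature_map["input_mask"].float_list.value[:5])   # floats, matching create_float_feature
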
"cls_index": tf.FixedLenFeature([], tf.int64), 762 | } 763 | 764 | if is_training: 765 | name_to_features["start_position"] = tf.FixedLenFeature([], tf.int64) 766 | name_to_features["end_position"] = tf.FixedLenFeature([], tf.int64) 767 | name_to_features["is_impossible"] = tf.FixedLenFeature([], tf.float32) 768 | 769 | def _decode_record(record, 770 | name_to_features): 771 | """Decodes a record to a TensorFlow example.""" 772 | example = tf.parse_single_example(record, name_to_features) 773 | 774 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. So cast all int64 to int32. 775 | for name in list(example.keys()): 776 | t = example[name] 777 | if t.dtype == tf.int64: 778 | t = tf.to_int32(t) 779 | example[name] = t 780 | 781 | return example 782 | 783 | def input_fn(params): 784 | """The actual input function.""" 785 | batch_size = params["batch_size"] 786 | 787 | # For training, we want a lot of parallel reading and shuffling. 788 | # For eval, we want no shuffling and parallel reading doesn't matter. 789 | d = tf.data.TFRecordDataset(input_file) 790 | 791 | if is_training: 792 | d = d.repeat() 793 | d = d.shuffle(buffer_size=shuffle_buffer, seed=np.random.randint(10000)) 794 | 795 | d = d.apply(tf.contrib.data.map_and_batch( 796 | lambda record: _decode_record(record, name_to_features), 797 | batch_size=batch_size, 798 | num_parallel_batches=num_threads, 799 | drop_remainder=drop_remainder)) 800 | 801 | return d.prefetch(1024) 802 | 803 | return input_fn 804 | 805 | @staticmethod 806 | def get_serving_input_fn(seq_length): 807 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 808 | def serving_input_fn(): 809 | with tf.variable_scope("serving"): 810 | features = { 811 | 'unique_id': tf.placeholder(tf.int32, [None], name='unique_id'), 812 | 'input_ids': tf.placeholder(tf.int32, [None, seq_length], name='input_ids'), 813 | 'input_mask': tf.placeholder(tf.float32, [None, seq_length], name='input_mask'), 814 | 'p_mask': tf.placeholder(tf.float32, [None, seq_length], name='p_mask'), 815 | 'segment_ids': tf.placeholder(tf.int32, [None, seq_length], name='segment_ids'), 816 | 'cls_index': tf.placeholder(tf.int32, [None], name='cls_index'), 817 | } 818 | 819 | return tf.estimator.export.build_raw_serving_input_receiver_fn(features)() 820 | 821 | return serving_input_fn 822 | 823 | class XLNetModelBuilder(object): 824 | """Default model builder for XLNet""" 825 | def __init__(self, 826 | model_config, 827 | use_tpu=False): 828 | """Construct XLNet model builder""" 829 | self.model_config = model_config 830 | self.use_tpu = use_tpu 831 | 832 | def _generate_masked_data(self, 833 | input_data, 834 | input_mask): 835 | """Generate masked data""" 836 | return input_data * input_mask + MIN_FLOAT * (1 - input_mask) 837 | 838 | def _generate_onehot_label(self, 839 | input_data, 840 | input_depth): 841 | """Generate one-hot label""" 842 | return tf.one_hot(input_data, depth=input_depth, on_value=1.0, off_value=0.0, dtype=tf.float32) 843 | 844 | def _compute_loss(self, 845 | label, 846 | label_mask, 847 | predict, 848 | predict_mask, 849 | label_smoothing=0.0): 850 | """Compute optimization loss""" 851 | masked_predict = self._generate_masked_data(predict, predict_mask) 852 | masked_label = tf.cast(label, dtype=tf.int32) * tf.cast(label_mask, dtype=tf.int32) 853 | 854 | if label_smoothing > 1e-10: 855 | onehot_label = self._generate_onehot_label(masked_label, tf.shape(masked_predict)[-1]) 856 | onehot_label = (onehot_label * (1 - label_smoothing) + 857 | 
822 | 
823 | class XLNetModelBuilder(object):
824 |     """Default model builder for XLNet"""
825 |     def __init__(self,
826 |                  model_config,
827 |                  use_tpu=False):
828 |         """Construct XLNet model builder"""
829 |         self.model_config = model_config
830 |         self.use_tpu = use_tpu
831 | 
832 |     def _generate_masked_data(self,
833 |                               input_data,
834 |                               input_mask):
835 |         """Generate masked data"""
836 |         return input_data * input_mask + MIN_FLOAT * (1 - input_mask)
837 | 
838 |     def _generate_onehot_label(self,
839 |                                input_data,
840 |                                input_depth):
841 |         """Generate one-hot label"""
842 |         return tf.one_hot(input_data, depth=input_depth, on_value=1.0, off_value=0.0, dtype=tf.float32)
843 | 
844 |     def _compute_loss(self,
845 |                       label,
846 |                       label_mask,
847 |                       predict,
848 |                       predict_mask,
849 |                       label_smoothing=0.0):
850 |         """Compute optimization loss"""
851 |         masked_predict = self._generate_masked_data(predict, predict_mask)
852 |         masked_label = tf.cast(label, dtype=tf.int32) * tf.cast(label_mask, dtype=tf.int32)
853 | 
854 |         if label_smoothing > 1e-10:
855 |             onehot_label = self._generate_onehot_label(masked_label, tf.shape(masked_predict)[-1])
856 |             onehot_label = (onehot_label * (1 - label_smoothing) +
857 |                 label_smoothing / tf.cast(tf.shape(masked_predict)[-1], dtype=tf.float32)) * predict_mask
858 |             loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=onehot_label, logits=masked_predict)
859 |         else:
860 |             loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=masked_label, logits=masked_predict)
861 | 
862 |         return loss
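
Two identities worth seeing in the helpers above: `_generate_masked_data` drives masked logits toward `MIN_FLOAT` so softmax assigns them essentially zero probability, and the label-smoothing branch mixes the one-hot target with uniform mass. A NumPy check, using `-1e30` as a stand-in for the module-level `MIN_FLOAT` constant defined earlier in this file:

    import numpy as np

    MIN_FLOAT = -1e30  # stand-in for the module-level constant

    logits = np.array([2.0, 1.0, 0.5, 0.1])
    mask = np.array([1.0, 1.0, 0.0, 0.0])            # 1 = keep, 0 = mask out

    masked = logits * mask + MIN_FLOAT * (1 - mask)  # same formula as _generate_masked_data
    probs = np.exp(masked - masked.max())
    probs /= probs.sum()
    print(probs.round(3))                            # masked positions get ~0 probability

    # Label smoothing: one-hot target mixed with uniform mass, then re-masked.
    label, depth, eps = 0, 4, 0.1
    onehot = np.eye(depth)[label]
    smoothed = (onehot * (1 - eps) + eps / depth) * mask
    print(smoothed)                                  # [0.925, 0.025, 0.0, 0.0]
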
name="end_project") # [b,l,h] --> [b,l,1] 925 | 926 | end_result = tf.squeeze(end_result, axis=-1) # [b,l,1] --> [b,l] 927 | end_result = self._generate_masked_data(end_result, end_result_mask) # [b,l], [b,l] --> [b,l] 928 | end_prob = tf.nn.softmax(end_result, axis=-1) # [b,l] 929 | else: 930 | # During inference, compute the end logits based on beam search 931 | start_index = self._generate_onehot_label(start_top_index, seq_len) # [b,k] --> [b,k,l] 932 | feat_result = tf.matmul(start_index, output_result) # [b,k,l], [b,l,h] --> [b,k,h] 933 | feat_result = tf.expand_dims(feat_result, axis=1) # [b,k,h] --> [b,1,k,h] 934 | feat_result = tf.tile(feat_result, multiples=[1,seq_len,1,1]) # [b,1,k,h] --> [b,l,k,h] 935 | 936 | end_result = tf.expand_dims(output_result, axis=-2) # [b,l,h] --> [b,l,1,h] 937 | end_result = tf.tile(end_result, multiples=[1,1,FLAGS.start_n_top,1]) # [b,l,1,h] --> [b,l,k,h] 938 | end_result = tf.concat([end_result, feat_result], axis=-1) # [b,l,k,h], [b,l,k,h] --> [b,l,k,2h] 939 | end_result_mask = tf.expand_dims(1 - p_mask, axis=1) # [b,l] --> [b,1,l] 940 | end_result_mask = tf.tile(end_result_mask, multiples=[1,FLAGS.start_n_top,1]) # [b,1,l] --> [b,k,l] 941 | 942 | end_result = tf.layers.dense(end_result, units=self.model_config.d_model, activation=tf.tanh, 943 | use_bias=True, kernel_initializer=initializer, bias_initializer=tf.zeros_initializer, 944 | kernel_regularizer=None, bias_regularizer=None, trainable=True, name="end_modeling") # [b,l,k,2h] --> [b,l,k,h] 945 | 946 | end_result = tf.contrib.layers.layer_norm(end_result, center=True, scale=True, activation_fn=None, 947 | reuse=None, begin_norm_axis=-1, begin_params_axis=-1, trainable=True, scope="end_norm") # [b,l,k,h] --> [b,l,k,h] 948 | 949 | end_result = tf.layers.dense(end_result, units=1, activation=None, 950 | use_bias=True, kernel_initializer=initializer, bias_initializer=tf.zeros_initializer, 951 | kernel_regularizer=None, bias_regularizer=None, trainable=True, name="end_project") # [b,l,k,h] --> [b,l,k,1] 952 | 953 | end_result = tf.transpose(tf.squeeze(end_result, axis=-1), perm=[0,2,1]) # [b,l,k,1] --> [b,k,l] 954 | end_result = self._generate_masked_data(end_result, end_result_mask) # [b,k,l], [b,k,l] --> [b,k,l] 955 | end_prob = tf.nn.softmax(end_result, axis=-1) # [b,k,l] 956 | 957 | end_top_prob, end_top_index = tf.nn.top_k(end_prob, k=FLAGS.end_n_top) # [b,k,l] --> [b,k,k], [b,k,k] 958 | predicts["end_prob"] = end_top_prob 959 | predicts["end_index"] = end_top_index 960 | 961 | with tf.variable_scope("answer", reuse=tf.AUTO_REUSE): 962 | cls_index = self._generate_onehot_label(tf.expand_dims(cls_index, axis=-1), seq_len) # [b] --> [b,1,l] 963 | feat_result = tf.matmul(tf.expand_dims(start_prob, axis=1), output_result) # [b,l], [b,l,h] --> [b,1,h] 964 | 965 | answer_result = tf.matmul(cls_index, output_result) # [b,1,l], [b,l,h] --> [b,1,h] 966 | answer_result = tf.squeeze(tf.concat([feat_result, answer_result], axis=-1), axis=1) # [b,1,h], [b,1,h] --> [b,2h] 967 | answer_result_mask = tf.reduce_max(1 - p_mask, axis=-1) # [b,l] --> [b] 968 | 969 | answer_result = tf.layers.dense(answer_result, units=self.model_config.d_model, activation=tf.tanh, 970 | use_bias=True, kernel_initializer=initializer, bias_initializer=tf.zeros_initializer, 971 | kernel_regularizer=None, bias_regularizer=None, trainable=True, name="answer_modeling") # [b,2h] --> [b,h] 972 | 973 | answer_result = tf.layers.dropout(answer_result, 974 | rate=FLAGS.dropout, seed=np.random.randint(10000), training=is_training) # [b,h] 
1004 | 
1005 |     def get_model_fn(self):
1006 |         """Returns `model_fn` closure for TPUEstimator."""
1007 |         def model_fn(features,
1008 |                      labels,
1009 |                      mode,
1010 |                      params):  # pylint: disable=unused-argument
1011 |             """The `model_fn` for TPUEstimator."""
1012 |             tf.logging.info("*** Features ***")
1013 |             for name in sorted(features.keys()):
1014 |                 tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))
1015 | 
1016 |             is_training = (mode == tf.estimator.ModeKeys.TRAIN)
1017 | 
1018 |             unique_id = features["unique_id"]
1019 |             input_ids = features["input_ids"]
1020 |             input_mask = features["input_mask"]
1021 |             p_mask = features["p_mask"]
1022 |             segment_ids = features["segment_ids"]
1023 |             cls_index = features["cls_index"]
1024 | 
1025 |             if is_training:
1026 |                 start_position = features["start_position"]
1027 |                 end_position = features["end_position"]
1028 |                 is_impossible = features["is_impossible"]
1029 |             else:
1030 |                 start_position = None
1031 |                 end_position = None
1032 |                 is_impossible = None
1033 | 
1034 |             loss, predicts = self._create_model(is_training, input_ids, input_mask,
1035 |                 p_mask, segment_ids, cls_index, start_position, end_position, is_impossible)
1036 | 
1037 |             scaffold_fn = model_utils.init_from_checkpoint(FLAGS)
1038 | 
1039 |             output_spec = None
1040 |             if is_training:
1041 |                 train_op, _, _ = model_utils.get_train_op(FLAGS, loss)
1042 |                 output_spec = tf.contrib.tpu.TPUEstimatorSpec(
1043 |                     mode=mode,
1044 |                     loss=loss,
1045 |                     train_op=train_op,
1046 |                     scaffold_fn=scaffold_fn)
1047 |             else:
1048 |                 output_spec = tf.contrib.tpu.TPUEstimatorSpec(
1049 |                     mode=mode,
1050 |                     predictions={
1051 |                         "unique_id": unique_id,
1052 |                         "answer_prob": predicts["answer_prob"],
1053 |                         "start_prob": predicts["start_prob"],
1054 |                         "start_index": predicts["start_index"],
1055 |                         "end_prob": predicts["end_prob"],
1056 |                         "end_index": predicts["end_index"]
1057 |                     },
1058 |                     scaffold_fn=scaffold_fn)
1059 | 
1060 |             return output_spec
1061 | 
1062 |         return model_fn
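
In PREDICT mode, `estimator.predict` yields one dict per feature with exactly the keys wired above; their shapes follow `start_n_top`/`end_n_top` (`k` below). This fake record illustrates the contract; the values are made up, and `k=5` is assumed to mirror the usual XLNet setting:

    import numpy as np

    k, seq_length = 5, 512  # assumed to mirror FLAGS.start_n_top / FLAGS.end_n_top / FLAGS.max_seq_length
    fake_result = {
        "unique_id": np.int32(1000000000),
        "answer_prob": np.float32(0.83),                              # sigmoid no-answer score from the <cls> head
        "start_prob": np.random.rand(k).astype(np.float32),           # [k] top-k start probabilities
        "start_index": np.array([17, 3, 42, 8, 25], dtype=np.int32),  # [k] their token positions
        "end_prob": np.random.rand(k, k).astype(np.float32),          # [k,k] row i conditions on start candidate i
        "end_index": np.random.randint(0, seq_length, size=(k, k)).astype(np.int32),
    }
    print(fake_result["end_index"][0])  # end candidates paired with the best start candidate
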
1063 | 
1064 | class XLNetPredictProcessor(object):
1065 |     """Default predict processor for XLNet"""
1066 |     def __init__(self,
1067 |                  output_dir,
1068 |                  n_best_size,
1069 |                  start_n_top,
1070 |                  end_n_top,
1071 |                  max_answer_length,
1072 |                  tokenizer,
1073 |                  predict_tag=None):
1074 |         """Construct XLNet predict processor"""
1075 |         self.n_best_size = n_best_size
1076 |         self.start_n_top = start_n_top
1077 |         self.end_n_top = end_n_top
1078 |         self.max_answer_length = max_answer_length
1079 |         self.tokenizer = tokenizer
1080 | 
1081 |         predict_tag = predict_tag if predict_tag else str(time.time())
1082 |         self.output_summary = os.path.join(output_dir, "predict.{0}.summary.json".format(predict_tag))
1083 |         self.output_detail = os.path.join(output_dir, "predict.{0}.detail.json".format(predict_tag))
1084 | 
1085 |     def _write_to_json(self,
1086 |                        data_list,
1087 |                        data_path):
1088 |         data_folder = os.path.dirname(data_path)
1089 |         if not os.path.exists(data_folder):
1090 |             os.mkdir(data_folder)
1091 | 
1092 |         with open(data_path, "w") as file:
1093 |             json.dump(data_list, file, indent=4)
1094 | 
1095 |     def _write_to_text(self,
1096 |                        data_list,
1097 |                        data_path):
1098 |         data_folder = os.path.dirname(data_path)
1099 |         if not os.path.exists(data_folder):
1100 |             os.mkdir(data_folder)
1101 | 
1102 |         with open(data_path, "w") as file:
1103 |             for data in data_list:
1104 |                 file.write("{0}\n".format(data))
1105 | 
1106 |     def process(self,
1107 |                 examples,
1108 |                 features,
1109 |                 results):
1110 |         qas_id_to_features = {}
1111 |         unique_id_to_feature = {}
1112 |         for feature in features:
1113 |             if feature.qas_id not in qas_id_to_features:
1114 |                 qas_id_to_features[feature.qas_id] = []
1115 | 
1116 |             qas_id_to_features[feature.qas_id].append(feature)
1117 |             unique_id_to_feature[feature.unique_id] = feature
1118 | 
1119 |         unique_id_to_result = {}
1120 |         for result in results:
1121 |             unique_id_to_result[result.unique_id] = result
1122 | 
1123 |         predict_summary_list = []
1124 |         predict_detail_list = []
1125 |         num_example = len(examples)
1126 |         for (example_idx, example) in enumerate(examples):
1127 |             if example_idx % 1000 == 0:
1128 |                 tf.logging.info('Updating {0}/{1} example with predict'.format(example_idx, num_example))
1129 | 
1130 |             if example.qas_id not in qas_id_to_features:
1131 |                 tf.logging.warning('No feature found for example: {0}'.format(example.qas_id))
1132 |                 continue
1133 | 
1134 |             example_answer_prob = MAX_FLOAT
1135 |             example_all_predicts = []
1136 |             example_features = qas_id_to_features[example.qas_id]
1137 |             for example_feature in example_features:
1138 |                 if example_feature.unique_id not in unique_id_to_result:
1139 |                     tf.logging.warning('No result found for feature: {0}'.format(example_feature.unique_id))
1140 |                     continue
1141 | 
1142 |                 example_result = unique_id_to_result[example_feature.unique_id]
1143 |                 example_answer_prob = min(example_answer_prob, float(example_result.answer_prob))
1144 |                 for i in range(self.start_n_top):
1145 |                     start_prob = example_result.start_prob[i]
1146 |                     start_index = example_result.start_index[i]
1147 | 
1148 |                     for j in range(self.end_n_top):
1149 |                         end_prob = example_result.end_prob[i][j]
1150 |                         end_index = example_result.end_index[i][j]
1151 | 
1152 |                         answer_length = end_index - start_index + 1
1153 |                         if end_index < start_index or answer_length > self.max_answer_length:
1154 |                             continue
1155 | 
1156 |                         if start_index > example_feature.para_length or end_index > example_feature.para_length:
1157 |                             continue
1158 | 
1159 |                         if start_index not in example_feature.token2doc_index:
1160 |                             continue
1161 | 
1162 |                         example_all_predicts.append({
1163 |                             "unique_id": example_result.unique_id,
1164 |                             "start_prob": start_prob,
1165 |                             "start_index": start_index,
1166 |                             "end_prob": end_prob,
1167 |                             "end_index": end_index,
1168 |                             "predict_score": np.log(start_prob) + np.log(end_prob)
1169 |                         })
1170 | 
1171 |             example_all_predicts = sorted(example_all_predicts, key=lambda x: x["predict_score"], reverse=True)
1172 | 
1173 |             is_visited = set()
1174 |             example_top_predicts = []
1175 |             for example_predict in example_all_predicts:
1176 |                 if len(example_top_predicts) >= self.n_best_size:
1177 |                     break
1178 | 
1179 |                 example_feature = unique_id_to_feature[example_predict["unique_id"]]
1180 |                 predict_start = example_feature.token2char_raw_start_index[example_predict["start_index"]]
1181 |                 predict_end = example_feature.token2char_raw_end_index[example_predict["end_index"]]
1182 |                 predict_text = example.paragraph_text[predict_start:predict_end + 1].strip()
1183 | 
1184 |                 if predict_text in is_visited:
1185 |                     continue
1186 | 
1187 |                 is_visited.add(predict_text)
1188 | 
1189 |                 example_top_predicts.append({
1190 |                     "predict_text": predict_text,
1191 |                     "start_prob": float(example_predict["start_prob"]),
1192 |                     "end_prob": float(example_predict["end_prob"]),
1193 |                     "predict_score": float(example_predict["predict_score"])
1194 |                 })
1195 | 
1196 |             if len(example_top_predicts) == 0:
1197 |                 example_top_predicts.append({
1198 |                     "predict_text": "",
1199 |                     "start_prob": 0.0,
1200 |                     "end_prob": 0.0,
1201 |                     "predict_score": 0.0
1202 |                 })
1203 | 
1204 |             example_best_predict = example_top_predicts[0]
1205 | 
1206 |             predict_summary_list.append({
1207 |                 "qas_id": example.qas_id,
1208 |                 "answer_prob": example_answer_prob,
1209 |                 "start_prob": example_best_predict["start_prob"],
1210 |                 "end_prob": example_best_predict["end_prob"],
1211 |                 "predict_text": example_best_predict["predict_text"]
1212 |             })
1213 | 
1214 |             predict_detail_list.append({
1215 |                 "qas_id": example.qas_id,
1216 |                 "answer_prob": example_answer_prob,
1217 |                 "best_predict": example_best_predict,
1218 |                 "top_predicts": example_top_predicts,
1219 |             })
1220 | 
1221 |         self._write_to_json(predict_summary_list, self.output_summary)
1222 |         self._write_to_json(predict_detail_list, self.output_detail)
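
The span selection in `process` ranks every surviving (start, end) pair by `log(p_start) + log(p_end)` and keeps the first `n_best_size` distinct answer strings. A framework-free toy of that ranking and deduplication:

    import numpy as np

    n_best_size = 2
    candidates = [  # (start_prob, end_prob, text) after the length/position filters above
        (0.60, 0.50, "the Broncos"),
        (0.60, 0.30, "the Broncos"),   # same text with a lower score: dropped by the dedup step
        (0.25, 0.45, "Broncos"),
    ]
    scored = sorted(candidates, key=lambda c: np.log(c[0]) + np.log(c[1]), reverse=True)

    is_visited, top_predicts = set(), []
    for start_prob, end_prob, text in scored:
        if len(top_predicts) >= n_best_size:
            break
        if text in is_visited:
            continue
        is_visited.add(text)
        top_predicts.append((text, float(np.log(start_prob) + np.log(end_prob))))
    print(top_predicts)  # [('the Broncos', -1.20...), ('Broncos', -2.18...)]
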
1223 | 
1224 | def main(_):
1225 |     tf.logging.set_verbosity(tf.logging.INFO)
1226 | 
1227 |     np.random.seed(FLAGS.random_seed)
1228 | 
1229 |     if not os.path.exists(FLAGS.output_dir):
1230 |         os.mkdir(FLAGS.output_dir)
1231 | 
1232 |     task_name = FLAGS.task_name.lower()
1233 |     data_pipeline = SquadPipeline(
1234 |         data_dir=FLAGS.data_dir,
1235 |         task_name=task_name)
1236 | 
1237 |     model_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
1238 | 
1239 |     model_builder = XLNetModelBuilder(
1240 |         model_config=model_config,
1241 |         use_tpu=FLAGS.use_tpu)
1242 | 
1243 |     model_fn = model_builder.get_model_fn()
1244 | 
1245 |     # If TPU is not available, this will fall back to normal Estimator on CPU or GPU.
1246 |     tpu_config = model_utils.configure_tpu(FLAGS)
1247 | 
1248 |     estimator = tf.contrib.tpu.TPUEstimator(
1249 |         use_tpu=FLAGS.use_tpu,
1250 |         model_fn=model_fn,
1251 |         config=tpu_config,
1252 |         export_to_tpu=FLAGS.use_tpu,
1253 |         train_batch_size=FLAGS.train_batch_size,
1254 |         predict_batch_size=FLAGS.predict_batch_size)
1255 | 
1256 |     tokenizer = XLNetTokenizer(
1257 |         sp_model_file=FLAGS.spiece_model_file,
1258 |         lower_case=FLAGS.lower_case)
1259 | 
1260 |     example_processor = XLNetExampleProcessor(
1261 |         max_seq_length=FLAGS.max_seq_length,
1262 |         max_query_length=FLAGS.max_query_length,
1263 |         doc_stride=FLAGS.doc_stride,
1264 |         tokenizer=tokenizer)
1265 | 
1266 |     if FLAGS.do_train:
1267 |         train_examples = data_pipeline.get_train_examples()
1268 | 
1269 |         tf.logging.info("***** Run training *****")
1270 |         tf.logging.info("  Num examples = %d", len(train_examples))
1271 |         tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
1272 |         tf.logging.info("  Num steps = %d", FLAGS.train_steps)
1273 | 
1274 |         train_record_file = os.path.join(FLAGS.output_dir, "train-{0}.tfrecord".format(task_name))
1275 |         if not os.path.exists(train_record_file) or FLAGS.overwrite_data:
1276 |             train_features = example_processor.convert_examples_to_features(train_examples, True)
1277 |             np.random.shuffle(train_features)
1278 |             example_processor.save_features_as_tfrecord(train_features, train_record_file, True)
1279 | 
1280 |         train_input_fn = XLNetInputBuilder.get_input_fn(train_record_file, FLAGS.max_seq_length, True, True, FLAGS.shuffle_buffer)
1281 |         estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
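
Note that `--train_steps` is a raw optimizer-step count, not an epoch count, so the usual conversion is worth keeping in mind. The numbers below are purely illustrative:

    # Illustrative epoch -> step arithmetic for estimator.train above (numbers are made up).
    num_train_features = 130000   # doc-span features, not raw questions; long paragraphs yield several each
    train_batch_size = 48
    num_epochs = 3

    steps_per_epoch = num_train_features // train_batch_size  # 2708
    train_steps = steps_per_epoch * num_epochs                # 8124 -> value to pass as --train_steps
    print(steps_per_epoch, train_steps)
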
1282 | 
1283 |     if FLAGS.do_predict:
1284 |         predict_examples = data_pipeline.get_dev_examples()
1285 | 
1286 |         tf.logging.info("***** Run prediction *****")
1287 |         tf.logging.info("  Num examples = %d", len(predict_examples))
1288 |         tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)
1289 | 
1290 |         predict_record_file = os.path.join(FLAGS.output_dir, "dev-{0}.tfrecord".format(task_name))
1291 |         predict_pickle_file = os.path.join(FLAGS.output_dir, "dev-{0}.pkl".format(task_name))
1292 |         if not os.path.exists(predict_record_file) or not os.path.exists(predict_pickle_file) or FLAGS.overwrite_data:
1293 |             predict_features = example_processor.convert_examples_to_features(predict_examples, False)
1294 |             example_processor.save_features_as_tfrecord(predict_features, predict_record_file, False)
1295 |             example_processor.save_features_as_pickle(predict_features, predict_pickle_file)
1296 |         else:
1297 |             predict_features = example_processor.load_features_from_pickle(predict_pickle_file)
1298 | 
1299 |         predict_input_fn = XLNetInputBuilder.get_input_fn(predict_record_file, FLAGS.max_seq_length, False, False)
1300 |         results = estimator.predict(input_fn=predict_input_fn)
1301 | 
1302 |         predict_results = [OutputResult(
1303 |             unique_id=result["unique_id"],
1304 |             answer_prob=result["answer_prob"],
1305 |             start_prob=result["start_prob"].tolist(),
1306 |             start_index=result["start_index"].tolist(),
1307 |             end_prob=result["end_prob"].tolist(),
1308 |             end_index=result["end_index"].tolist()
1309 |         ) for result in results]
1310 | 
1311 |         predict_processor = XLNetPredictProcessor(
1312 |             output_dir=FLAGS.output_dir,
1313 |             n_best_size=FLAGS.n_best_size,
1314 |             start_n_top=FLAGS.start_n_top,
1315 |             end_n_top=FLAGS.end_n_top,
1316 |             max_answer_length=FLAGS.max_answer_length,
1317 |             tokenizer=tokenizer,
1318 |             predict_tag=FLAGS.predict_tag)
1319 | 
1320 |         predict_processor.process(predict_examples, predict_features, predict_results)
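
`OutputResult` is defined earlier in this file; for reference, the comprehension above populates a record of this shape. This is a sketch of the assumed definition, inferred from the fields consumed here and in `XLNetPredictProcessor`, not the original:

    import collections

    # Assumed shape; the real definition lives earlier in run_squad.py.
    OutputResult = collections.namedtuple("OutputResult",
        ["unique_id", "answer_prob", "start_prob", "start_index", "end_prob", "end_index"])

    result = OutputResult(
        unique_id=1000000000,
        answer_prob=0.87,
        start_prob=[0.6, 0.2, 0.1, 0.05, 0.05],    # start_n_top entries
        start_index=[17, 3, 42, 8, 25],
        end_prob=[[0.5, 0.2, 0.1, 0.1, 0.1]] * 5,  # end_n_top entries per start candidate
        end_index=[[19, 20, 18, 21, 17]] * 5)
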
1321 | 
1322 |     if FLAGS.do_export:
1323 |         tf.logging.info("***** Run export *****")
1324 |         if not os.path.exists(FLAGS.export_dir):
1325 |             os.mkdir(FLAGS.export_dir)
1326 | 
1327 |         serving_input_fn = XLNetInputBuilder.get_serving_input_fn(FLAGS.max_seq_length)
1328 |         estimator.export_savedmodel(FLAGS.export_dir, serving_input_fn, as_text=False)
1329 | 
1330 | if __name__ == "__main__":
1331 |     flags.mark_flag_as_required("spiece_model_file")
1332 |     flags.mark_flag_as_required("model_config_path")
1333 |     flags.mark_flag_as_required("init_checkpoint")
1334 |     flags.mark_flag_as_required("data_dir")
1335 |     flags.mark_flag_as_required("output_dir")
1336 |     flags.mark_flag_as_required("model_dir")
1337 |     flags.mark_flag_as_required("export_dir")
1338 |     tf.app.run()
1339 | 
--------------------------------------------------------------------------------