├── docs
│   ├── xlnet.coqa.png
│   ├── xlnet.quac.png
│   ├── squad.example.png
│   ├── xlnet.squad.v1.png
│   ├── xlnet.squad.v2.png
│   ├── _config.yml
│   └── index.md
├── .gitmodules
├── tool
│   ├── convert_squad.py
│   ├── convert_quac.py
│   ├── convert_coqa.py
│   ├── eval_quac.py
│   ├── eval_coqa.py
│   └── eval_squad.py
├── .gitignore
├── codalab
│   ├── run_coqa.codalab.sh
│   └── run_quac.codalab.sh
├── run_squad.sh
├── run_coqa.sh
├── run_quac.sh
├── README.md
├── LICENSE
└── run_squad.py
-------------------------------------------------------------------------------- /docs/xlnet.coqa.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/xlnet.coqa.png
-------------------------------------------------------------------------------- /docs/xlnet.quac.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/xlnet.quac.png
-------------------------------------------------------------------------------- /.gitmodules: --------------------------------------------------------------------------------
1 | [submodule "xlnet"]
2 | path = xlnet
3 | url = https://github.com/zihangdai/xlnet.git
4 | 
-------------------------------------------------------------------------------- /docs/squad.example.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/squad.example.png
-------------------------------------------------------------------------------- /docs/xlnet.squad.v1.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/xlnet.squad.v1.png
-------------------------------------------------------------------------------- /docs/xlnet.squad.v2.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevezheng23/mrc_tf/HEAD/docs/xlnet.squad.v2.png
-------------------------------------------------------------------------------- /docs/_config.yml: --------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
2 | title: Machine Reading Comprehension (MRC)
3 | description: This is a Machine Reading Comprehension (MRC) experiment project, which includes implementations and experiments for various MRC tasks (e.g., SQuAD, CoQA, QuAC)
-------------------------------------------------------------------------------- /tool/convert_squad.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | def add_arguments(parser): 5 | parser.add_argument("--input_file", help="path to input file", required=True) 6 | parser.add_argument("--span_file", help="path to answer span file", required=True) 7 | parser.add_argument("--prob_file", help="path to no-answer probability file", required=True) 8 | parser.add_argument("--prob_thres", help="threshold of no-answer probability", required=False, default=1.0, type=float) 9 | 10 | def convert_squad(input_file, 11 | span_file, 12 | prob_file, 13 | prob_thres): 14 | with open(input_file, "r") as file: 15 | input_data = json.load(file) 16 | 17 | span_dict = {} 18 | prob_dict = {} 19 | for data in input_data: 20 | qas_id = data["qas_id"] 21 | predict_text = data["predict_text"] 22 | answer_prob = data["answer_prob"] if "answer_prob" in data else 0.0 23 | 24 | span_dict[qas_id] = predict_text if answer_prob < prob_thres else "" 25 | prob_dict[qas_id] = answer_prob 26 | 27 | with open(span_file, "w") as file: 28 | json.dump(span_dict, file, indent=4) 29 | 30 | with open(prob_file, "w") as file: 31 | json.dump(prob_dict, file, indent=4) 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | add_arguments(parser) 36 | args = parser.parse_args() 37 | convert_squad(args.input_file, args.span_file, args.prob_file, args.prob_thres) 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /tool/convert_quac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import numpy as np 5 | 6 | def add_arguments(parser): 7 | parser.add_argument("--input_file", help="path to input file", required=True) 8 | parser.add_argument("--output_file", help="path to output file", required=True) 9 | parser.add_argument("--answer_threshold", help="threshold of answer", required=False, default=0.1, type=float) 10 | 11 | def convert_quac(input_file, 12 | output_file, 13 | answer_threshold): 14 | with open(input_file, "r") as file: 15 | input_data = json.load(file) 16 | 17 | data_lookup = {} 18 | for data in input_data: 19 | qas_id = data["qas_id"] 20 | id_items = qas_id.split('#') 21 | id = id_items[0] 22 | turn_id = int(id_items[1]) 23 | 24 | no_answer = data["no_answer_score"] 25 | 26 | yes_no_list = ["y", "x", "n"] 27 | yes_no = yes_no_list[data["yes_no_id"]] 28 | 29 | follow_up_list = ["y", "m", "n"] 30 | follow_up = follow_up_list[data["follow_up_id"]] 31 | 32 | if no_answer >= answer_threshold: 33 | answer_text = "CANNOTANSWER" 34 | else: 35 | answer_text = data["predict_text"] 36 | 37 | if id not in data_lookup: 38 | data_lookup[id] = [] 39 | 40 | data_lookup[id].append({ 41 | "qas_id": qas_id, 42 | "turn_id": turn_id, 43 | "answer_text": answer_text, 44 | "yes_no": yes_no, 45 | "follow_up": follow_up 46 | }) 47 | 48 | with open(output_file, "w") as file: 49 | for id in data_lookup.keys(): 50 | data_list = sorted(data_lookup[id], key=lambda x: x["turn_id"]) 51 | 52 | output_data = json.dumps({ 53 | "best_span_str": [data["answer_text"] for data in data_list], 54 | "qid": [data["qas_id"] for data in data_list], 55 | "yesno": [data["yes_no"] for data in data_list], 56 | "followup": [data["follow_up"] for data in data_list] 57 | }) 58 | 59 | file.write("{0}\n".format(output_data)) 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | add_arguments(parser) 64 | args = parser.parse_args() 65 | convert_quac(args.input_file, args.output_file, args.answer_threshold) 66 | -------------------------------------------------------------------------------- /tool/convert_coqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | 
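# numpy's argmax is used below to pick the highest-scoring answer type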
import numpy as np 5 | 6 | from eval_coqa import CoQAEvaluator 7 | 8 | def add_arguments(parser): 9 | parser.add_argument("--input_file", help="path to input file", required=True) 10 | parser.add_argument("--output_file", help="path to output file", required=True) 11 | parser.add_argument("--answer_threshold", help="threshold of answer", required=False, default=0.1, type=float) 12 | 13 | def convert_coqa(input_file, 14 | output_file, 15 | answer_threshold): 16 | with open(input_file, "r") as file: 17 | input_data = json.load(file) 18 | 19 | output_data = [] 20 | for data in input_data: 21 | id_items = data["qas_id"].split('_') 22 | id = id_items[0] 23 | turn_id = int(id_items[1]) 24 | 25 | score_list = [data["unk_score"], data["yes_score"], data["no_score"], data["num_score"], data["opt_score"]] 26 | answer_list = ["unknown", "yes", "no", "number", "option"] 27 | 28 | score_idx = np.argmax(score_list) 29 | if score_list[score_idx] >= answer_threshold: 30 | answer = answer_list[score_idx] 31 | if answer == "number": 32 | answer_list = ["none", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"] 33 | answer = answer_list[data["num_id"]-1] 34 | elif answer == "option": 35 | answer = data["predict_text"] 36 | norm_question_tokens = CoQAEvaluator.normalize_answer(data["question_text"]).split(" ") 37 | if "or" in norm_question_tokens: 38 | index = norm_question_tokens.index("or") 39 | if index-1 >= 0 and index+1 < len(norm_question_tokens): 40 | answer_list = [norm_question_tokens[index-1], norm_question_tokens[index+1]] 41 | answer = answer_list[data["opt_id"]-1] 42 | else: 43 | answer = data["predict_text"] 44 | 45 | output_data.append({ 46 | "id": id, 47 | "turn_id": turn_id, 48 | "answer": answer 49 | }) 50 | 51 | with open(output_file, "w") as file: 52 | json.dump(output_data, file, indent=4) 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | add_arguments(parser) 57 | args = parser.parse_args() 58 | convert_coqa(args.input_file, args.output_file, args.answer_threshold) 59 | -------------------------------------------------------------------------------- /codalab/run_coqa.codalab.sh: -------------------------------------------------------------------------------- 1 | start_time=`date +%s` 2 | 3 | alias python=python3 4 | 5 | git clone --recurse-submodules https://github.com/stevezheng23/mrc_tf.git 6 | 7 | cd mrc_tf 8 | 9 | mkdir model 10 | mkdir model/xlnet 11 | wget -P model/xlnet https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip 12 | unzip model/xlnet/cased_L-24_H-1024_A-16.zip -d model/xlnet/ 13 | mv model/xlnet/xlnet_cased_L-24_H-1024_A-16 model/xlnet/cased_L-24_H-1024_A-16 14 | rm model/xlnet/cased_L-24_H-1024_A-16.zip 15 | 16 | mkdir data 17 | mkdir data/coqa 18 | cp ../coqa-dev-v1.0.json data/coqa/dev-v1.0.json 19 | 20 | mkdir output 21 | mkdir output/coqa 22 | mkdir output/coqa/data 23 | wget -P output/coqa https://storage.googleapis.com/mrc_data/coqa/coqa_cased_L-24_H-1024_A-16.zip 24 | unzip output/coqa/coqa_cased_L-24_H-1024_A-16.zip -d output/coqa/ 25 | mv output/coqa/coqa_cased_L-24_H-1024_A-16 output/coqa/checkpoint 26 | rm output/coqa/coqa_cased_L-24_H-1024_A-16.zip 27 | 28 | CUDA_VISIBLE_DEVICES=0 python run_coqa.py \ 29 | --spiece_model_file=model/xlnet/cased_L-24_H-1024_A-16/spiece.model \ 30 | --model_config_path=model/xlnet/cased_L-24_H-1024_A-16/xlnet_config.json \ 31 | --init_checkpoint=model/xlnet/cased_L-24_H-1024_A-16/xlnet_model.ckpt \ 32 | --task_name='v1.0' \ 33 | 
--random_seed=1000 \ 34 | --predict_tag='v1.0' \ 35 | --lower_case=false \ 36 | --data_dir=data/coqa/ \ 37 | --output_dir=output/coqa/data \ 38 | --model_dir=output/coqa/checkpoint \ 39 | --export_dir=output/coqa/export \ 40 | --num_turn=-1 \ 41 | --max_seq_length=512 \ 42 | --max_query_length=128 \ 43 | --max_answer_length=16 \ 44 | --train_batch_size=48 \ 45 | --predict_batch_size=16 \ 46 | --num_hosts=1 \ 47 | --num_core_per_host=1 \ 48 | --learning_rate=2e-5 \ 49 | --train_steps=10000 \ 50 | --warmup_steps=1000 \ 51 | --save_steps=1000 \ 52 | --do_train=false \ 53 | --do_predict=true \ 54 | --do_export=false \ 55 | --overwrite_data=false 56 | 57 | python tool/convert_coqa.py \ 58 | --input_file=output/coqa/data/predict.v1.0.summary.json \ 59 | --output_file=output/coqa/data/predict.v1.0.span.json \ 60 | --answer_threshold=0.1 61 | 62 | python tool/eval_coqa.py \ 63 | --data-file=data/coqa/dev-v1.0.json \ 64 | --pred-file=output/coqa/data/predict.v1.0.span.json \ 65 | >> output/coqa/data/predict.v1.0.eval.json 66 | 67 | cp output/coqa/data/predict.v1.0.span.json ../coqa-dev-v1.0.span.json 68 | cp output/coqa/data/predict.v1.0.eval.json ../coqa-dev-v1.0.eval.json 69 | 70 | cd .. 71 | rm -r mrc_tf 72 | 73 | end_time=`date +%s` 74 | echo execution time was `expr $end_time - $start_time` s. 75 | -------------------------------------------------------------------------------- /codalab/run_quac.codalab.sh: -------------------------------------------------------------------------------- 1 | for arg in "$@" 2 | do 3 | case $arg in 4 | -i|--inputfile) 5 | INPUTFILE="$2" 6 | shift 7 | shift 8 | ;; 9 | -o|--outputfile) 10 | OUTPUTFILE="$2" 11 | shift 12 | shift 13 | ;; 14 | esac 15 | done 16 | 17 | echo "input file = ${INPUTFILE}" 18 | echo "output file = ${OUTPUTFILE}" 19 | 20 | start_time=`date +%s` 21 | 22 | alias python=python3 23 | 24 | git clone --recurse-submodules https://github.com/stevezheng23/mrc_tf.git 25 | 26 | cd mrc_tf 27 | 28 | mkdir model 29 | mkdir model/xlnet 30 | wget -P model/xlnet https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip 31 | unzip model/xlnet/cased_L-24_H-1024_A-16.zip -d model/xlnet/ 32 | mv model/xlnet/xlnet_cased_L-24_H-1024_A-16 model/xlnet/cased_L-24_H-1024_A-16 33 | rm model/xlnet/cased_L-24_H-1024_A-16.zip 34 | 35 | mkdir data 36 | mkdir data/quac 37 | cp ../${INPUTFILE} data/quac/dev-v0.2.json 38 | 39 | mkdir output 40 | mkdir output/quac 41 | mkdir output/quac/data 42 | wget -P output/quac https://storage.googleapis.com/mrc_data/quac/quac_cased_L-24_H-1024_A-16.zip 43 | unzip output/quac/quac_cased_L-24_H-1024_A-16.zip -d output/quac/ 44 | mv output/quac/quac_cased_L-24_H-1024_A-16 output/quac/checkpoint 45 | rm output/quac/quac_cased_L-24_H-1024_A-16.zip 46 | 47 | CUDA_VISIBLE_DEVICES=0 python run_quac.py \ 48 | --spiece_model_file=model/xlnet/cased_L-24_H-1024_A-16/spiece.model \ 49 | --model_config_path=model/xlnet/cased_L-24_H-1024_A-16/xlnet_config.json \ 50 | --init_checkpoint=model/xlnet/cased_L-24_H-1024_A-16/xlnet_model.ckpt \ 51 | --task_name='v0.2' \ 52 | --random_seed=1000 \ 53 | --predict_tag='v0.2' \ 54 | --lower_case=false \ 55 | --data_dir=data/quac/ \ 56 | --output_dir=output/quac/data \ 57 | --model_dir=output/quac/checkpoint \ 58 | --export_dir=output/quac/export \ 59 | --num_turn=-1 \ 60 | --max_seq_length=512 \ 61 | --max_query_length=128 \ 62 | --max_answer_length=32 \ 63 | --train_batch_size=48 \ 64 | --predict_batch_size=12 \ 65 | --num_hosts=1 \ 66 | --num_core_per_host=1 \ 67 | --learning_rate=2e-5 \ 
68 | --train_steps=8000 \ 69 | --warmup_steps=1000 \ 70 | --save_steps=1000 \ 71 | --do_train=false \ 72 | --do_predict=true \ 73 | --do_export=false \ 74 | --overwrite_data=false 75 | 76 | python tool/convert_quac.py \ 77 | --input_file=output/quac/data/predict.v0.2.summary.json \ 78 | --output_file=output/quac/data/predict.v0.2.span.json \ 79 | --answer_threshold=0.1 80 | 81 | python tool/eval_quac.py \ 82 | --val_file=data/quac/dev-v0.2.json \ 83 | --model_output=output/quac/data/predict.v0.2.span.json \ 84 | --o=output/quac/data/predict.v0.2.eval.json 85 | 86 | cp output/quac/data/predict.v0.2.span.json ../${OUTPUTFILE} 87 | 88 | cd .. 89 | rm -r mrc_tf 90 | 91 | end_time=`date +%s` 92 | echo execution time was `expr $end_time - $start_time` s. 93 | -------------------------------------------------------------------------------- /run_squad.sh: -------------------------------------------------------------------------------- 1 | for i in "$@" 2 | do 3 | case $i in 4 | -g=*|--gpudevice=*) 5 | GPUDEVICE="${i#*=}" 6 | shift 7 | ;; 8 | -n=*|--numgpus=*) 9 | NUMGPUS="${i#*=}" 10 | shift 11 | ;; 12 | -t=*|--taskname=*) 13 | TASKNAME="${i#*=}" 14 | shift 15 | ;; 16 | -r=*|--randomseed=*) 17 | RANDOMSEED="${i#*=}" 18 | shift 19 | ;; 20 | -p=*|--predicttag=*) 21 | PREDICTTAG="${i#*=}" 22 | shift 23 | ;; 24 | -m=*|--modeldir=*) 25 | MODELDIR="${i#*=}" 26 | shift 27 | ;; 28 | -d=*|--datadir=*) 29 | DATADIR="${i#*=}" 30 | shift 31 | ;; 32 | -o=*|--outputdir=*) 33 | OUTPUTDIR="${i#*=}" 34 | shift 35 | ;; 36 | --seqlen=*) 37 | SEQLEN="${i#*=}" 38 | shift 39 | ;; 40 | --querylen=*) 41 | QUERYLEN="${i#*=}" 42 | shift 43 | ;; 44 | --answerlen=*) 45 | ANSWERLEN="${i#*=}" 46 | shift 47 | ;; 48 | --batchsize=*) 49 | BATCHSIZE="${i#*=}" 50 | shift 51 | ;; 52 | --learningrate=*) 53 | LEARNINGRATE="${i#*=}" 54 | shift 55 | ;; 56 | --trainsteps=*) 57 | TRAINSTEPS="${i#*=}" 58 | shift 59 | ;; 60 | --warmupsteps=*) 61 | WARMUPSTEPS="${i#*=}" 62 | shift 63 | ;; 64 | --savesteps=*) 65 | SAVESTEPS="${i#*=}" 66 | shift 67 | ;; 68 | esac 69 | done 70 | 71 | echo "gpu device = ${GPUDEVICE}" 72 | echo "num gpus = ${NUMGPUS}" 73 | echo "task name = ${TASKNAME}" 74 | echo "random seed = ${RANDOMSEED}" 75 | echo "predict tag = ${PREDICTTAG}" 76 | echo "model dir = ${MODELDIR}" 77 | echo "data dir = ${DATADIR}" 78 | echo "output dir = ${OUTPUTDIR}" 79 | echo "seq len = ${SEQLEN}" 80 | echo "query len = ${QUERYLEN}" 81 | echo "answer len = ${ANSWERLEN}" 82 | echo "batch size = ${BATCHSIZE}" 83 | echo "learning rate = ${LEARNINGRATE}" 84 | echo "train steps = ${TRAINSTEPS}" 85 | echo "warmup steps = ${WARMUPSTEPS}" 86 | echo "save steps = ${SAVESTEPS}" 87 | 88 | alias python=python3 89 | mkdir ${OUTPUTDIR} 90 | 91 | start_time=`date +%s` 92 | 93 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_squad.py \ 94 | --spiece_model_file=${MODELDIR}/spiece.model \ 95 | --model_config_path=${MODELDIR}/xlnet_config.json \ 96 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 97 | --task_name=${TASKNAME} \ 98 | --random_seed=${RANDOMSEED} \ 99 | --predict_tag=${PREDICTTAG} \ 100 | --lower_case=false \ 101 | --data_dir=${DATADIR}/ \ 102 | --output_dir=${OUTPUTDIR}/data \ 103 | --model_dir=${OUTPUTDIR}/checkpoint \ 104 | --export_dir=${OUTPUTDIR}/export \ 105 | --max_seq_length=${SEQLEN} \ 106 | --max_query_length=${QUERYLEN} \ 107 | --max_answer_length=${ANSWERLEN} \ 108 | --train_batch_size=${BATCHSIZE} \ 109 | --predict_batch_size=${BATCHSIZE} \ 110 | --num_hosts=1 \ 111 | --num_core_per_host=${NUMGPUS} \ 112 | 
--learning_rate=${LEARNINGRATE} \ 113 | --train_steps=${TRAINSTEPS} \ 114 | --warmup_steps=${WARMUPSTEPS} \ 115 | --save_steps=${SAVESTEPS} \ 116 | --do_train=true \ 117 | --do_predict=false \ 118 | --do_export=false \ 119 | --overwrite_data=false 120 | 121 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_squad.py \ 122 | --spiece_model_file=${MODELDIR}/spiece.model \ 123 | --model_config_path=${MODELDIR}/xlnet_config.json \ 124 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 125 | --task_name=${TASKNAME} \ 126 | --random_seed=${RANDOMSEED} \ 127 | --predict_tag=${PREDICTTAG} \ 128 | --lower_case=false \ 129 | --data_dir=${DATADIR}/ \ 130 | --output_dir=${OUTPUTDIR}/data \ 131 | --model_dir=${OUTPUTDIR}/checkpoint \ 132 | --export_dir=${OUTPUTDIR}/export \ 133 | --max_seq_length=${SEQLEN} \ 134 | --max_query_length=${QUERYLEN} \ 135 | --max_answer_length=${ANSWERLEN} \ 136 | --train_batch_size=${BATCHSIZE} \ 137 | --predict_batch_size=${BATCHSIZE} \ 138 | --num_hosts=1 \ 139 | --num_core_per_host=1 \ 140 | --learning_rate=${LEARNINGRATE} \ 141 | --train_steps=${TRAINSTEPS} \ 142 | --warmup_steps=${WARMUPSTEPS} \ 143 | --save_steps=${SAVESTEPS} \ 144 | --do_train=false \ 145 | --do_predict=true \ 146 | --do_export=false \ 147 | --overwrite_data=false 148 | 149 | python tool/convert_squad.py \ 150 | --input_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.summary.json \ 151 | --span_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 152 | --prob_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.prob.json 153 | 154 | python tool/eval_squad.py \ 155 | ${DATADIR}/dev-${TASKNAME}.json \ 156 | ${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 157 | --out-file ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json \ 158 | --na-prob-file ${OUTPUTDIR}/data/predict.${PREDICTTAG}.prob.json \ 159 | --na-prob-thresh 1.0 \ 160 | --out-image-dir ${OUTPUTDIR}/data/ 161 | 162 | end_time=`date +%s` 163 | echo execution time was `expr $end_time - $start_time` s. 164 | 165 | read -n 1 -s -r -p "Press any key to continue..." 
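# Example invocation with illustrative values (paths and values below are
# placeholders mirroring the README, not part of the script itself):
# bash run_squad.sh \
#   --gpudevice=0,1,2,3 --numgpus=4 \
#   --taskname=v2.0 --randomseed=100 --predicttag=demo \
#   --modeldir=model/cased_L-24_H-1024_A-16 \
#   --datadir=data/squad/v2.0 --outputdir=output/squad/v2.0 \
#   --seqlen=512 --querylen=64 --answerlen=64 \
#   --batchsize=12 --learningrate=3e-5 \
#   --trainsteps=8000 --warmupsteps=1000 --savesteps=1000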
-------------------------------------------------------------------------------- /run_coqa.sh: -------------------------------------------------------------------------------- 1 | for i in "$@" 2 | do 3 | case $i in 4 | -g=*|--gpudevice=*) 5 | GPUDEVICE="${i#*=}" 6 | shift 7 | ;; 8 | -n=*|--numgpus=*) 9 | NUMGPUS="${i#*=}" 10 | shift 11 | ;; 12 | -t=*|--taskname=*) 13 | TASKNAME="${i#*=}" 14 | shift 15 | ;; 16 | -r=*|--randomseed=*) 17 | RANDOMSEED="${i#*=}" 18 | shift 19 | ;; 20 | -p=*|--predicttag=*) 21 | PREDICTTAG="${i#*=}" 22 | shift 23 | ;; 24 | -m=*|--modeldir=*) 25 | MODELDIR="${i#*=}" 26 | shift 27 | ;; 28 | -d=*|--datadir=*) 29 | DATADIR="${i#*=}" 30 | shift 31 | ;; 32 | -o=*|--outputdir=*) 33 | OUTPUTDIR="${i#*=}" 34 | shift 35 | ;; 36 | --numturn=*) 37 | NUMTURN="${i#*=}" 38 | shift 39 | ;; 40 | --seqlen=*) 41 | SEQLEN="${i#*=}" 42 | shift 43 | ;; 44 | --querylen=*) 45 | QUERYLEN="${i#*=}" 46 | shift 47 | ;; 48 | --answerlen=*) 49 | ANSWERLEN="${i#*=}" 50 | shift 51 | ;; 52 | --batchsize=*) 53 | BATCHSIZE="${i#*=}" 54 | shift 55 | ;; 56 | --learningrate=*) 57 | LEARNINGRATE="${i#*=}" 58 | shift 59 | ;; 60 | --trainsteps=*) 61 | TRAINSTEPS="${i#*=}" 62 | shift 63 | ;; 64 | --warmupsteps=*) 65 | WARMUPSTEPS="${i#*=}" 66 | shift 67 | ;; 68 | --savesteps=*) 69 | SAVESTEPS="${i#*=}" 70 | shift 71 | ;; 72 | --answerthreshold=*) 73 | ANSWERTHRESHOLD="${i#*=}" 74 | shift 75 | ;; 76 | esac 77 | done 78 | 79 | echo "gpu device = ${GPUDEVICE}" 80 | echo "num gpus = ${NUMGPUS}" 81 | echo "task name = ${TASKNAME}" 82 | echo "random seed = ${RANDOMSEED}" 83 | echo "predict tag = ${PREDICTTAG}" 84 | echo "model dir = ${MODELDIR}" 85 | echo "data dir = ${DATADIR}" 86 | echo "output dir = ${OUTPUTDIR}" 87 | echo "num turn = ${NUMTURN}" 88 | echo "seq len = ${SEQLEN}" 89 | echo "query len = ${QUERYLEN}" 90 | echo "answer len = ${ANSWERLEN}" 91 | echo "batch size = ${BATCHSIZE}" 92 | echo "learning rate = ${LEARNINGRATE}" 93 | echo "train steps = ${TRAINSTEPS}" 94 | echo "warmup steps = ${WARMUPSTEPS}" 95 | echo "save steps = ${SAVESTEPS}" 96 | echo "answer threshold = ${ANSWERTHRESHOLD}" 97 | 98 | alias python=python3 99 | mkdir ${OUTPUTDIR} 100 | 101 | start_time=`date +%s` 102 | 103 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_coqa.py \ 104 | --spiece_model_file=${MODELDIR}/spiece.model \ 105 | --model_config_path=${MODELDIR}/xlnet_config.json \ 106 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 107 | --task_name=${TASKNAME} \ 108 | --random_seed=${RANDOMSEED} \ 109 | --predict_tag=${PREDICTTAG} \ 110 | --lower_case=false \ 111 | --data_dir=${DATADIR}/ \ 112 | --output_dir=${OUTPUTDIR}/data \ 113 | --model_dir=${OUTPUTDIR}/checkpoint \ 114 | --export_dir=${OUTPUTDIR}/export \ 115 | --num_turn=${NUMTURN} \ 116 | --max_seq_length=${SEQLEN} \ 117 | --max_query_length=${QUERYLEN} \ 118 | --max_answer_length=${ANSWERLEN} \ 119 | --train_batch_size=${BATCHSIZE} \ 120 | --predict_batch_size=${BATCHSIZE} \ 121 | --num_hosts=1 \ 122 | --num_core_per_host=${NUMGPUS} \ 123 | --learning_rate=${LEARNINGRATE} \ 124 | --train_steps=${TRAINSTEPS} \ 125 | --warmup_steps=${WARMUPSTEPS} \ 126 | --save_steps=${SAVESTEPS} \ 127 | --do_train=true \ 128 | --do_predict=false \ 129 | --do_export=false \ 130 | --overwrite_data=false 131 | 132 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_coqa.py \ 133 | --spiece_model_file=${MODELDIR}/spiece.model \ 134 | --model_config_path=${MODELDIR}/xlnet_config.json \ 135 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 136 | --task_name=${TASKNAME} \ 137 | 
--random_seed=${RANDOMSEED} \ 138 | --predict_tag=${PREDICTTAG} \ 139 | --lower_case=false \ 140 | --data_dir=${DATADIR}/ \ 141 | --output_dir=${OUTPUTDIR}/data \ 142 | --model_dir=${OUTPUTDIR}/checkpoint \ 143 | --export_dir=${OUTPUTDIR}/export \ 144 | --num_turn=${NUMTURN} \ 145 | --max_seq_length=${SEQLEN} \ 146 | --max_query_length=${QUERYLEN} \ 147 | --max_answer_length=${ANSWERLEN} \ 148 | --train_batch_size=${BATCHSIZE} \ 149 | --predict_batch_size=${BATCHSIZE} \ 150 | --num_hosts=1 \ 151 | --num_core_per_host=1 \ 152 | --learning_rate=${LEARNINGRATE} \ 153 | --train_steps=${TRAINSTEPS} \ 154 | --warmup_steps=${WARMUPSTEPS} \ 155 | --save_steps=${SAVESTEPS} \ 156 | --do_train=false \ 157 | --do_predict=true \ 158 | --do_export=false \ 159 | --overwrite_data=false 160 | 161 | python tool/convert_coqa.py \ 162 | --input_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.summary.json \ 163 | --output_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 164 | --answer_threshold=${ANSWERTHRESHOLD} 165 | 166 | rm ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json 167 | 168 | python tool/eval_coqa.py \ 169 | --data-file=${DATADIR}/dev-${TASKNAME}.json \ 170 | --pred-file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 171 | >> ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json 172 | 173 | end_time=`date +%s` 174 | echo execution time was `expr $end_time - $start_time` s. 175 | 176 | read -n 1 -s -r -p "Press any key to continue..." -------------------------------------------------------------------------------- /run_quac.sh: -------------------------------------------------------------------------------- 1 | for i in "$@" 2 | do 3 | case $i in 4 | -g=*|--gpudevice=*) 5 | GPUDEVICE="${i#*=}" 6 | shift 7 | ;; 8 | -n=*|--numgpus=*) 9 | NUMGPUS="${i#*=}" 10 | shift 11 | ;; 12 | -t=*|--taskname=*) 13 | TASKNAME="${i#*=}" 14 | shift 15 | ;; 16 | -r=*|--randomseed=*) 17 | RANDOMSEED="${i#*=}" 18 | shift 19 | ;; 20 | -p=*|--predicttag=*) 21 | PREDICTTAG="${i#*=}" 22 | shift 23 | ;; 24 | -m=*|--modeldir=*) 25 | MODELDIR="${i#*=}" 26 | shift 27 | ;; 28 | -d=*|--datadir=*) 29 | DATADIR="${i#*=}" 30 | shift 31 | ;; 32 | -o=*|--outputdir=*) 33 | OUTPUTDIR="${i#*=}" 34 | shift 35 | ;; 36 | --numturn=*) 37 | NUMTURN="${i#*=}" 38 | shift 39 | ;; 40 | --seqlen=*) 41 | SEQLEN="${i#*=}" 42 | shift 43 | ;; 44 | --querylen=*) 45 | QUERYLEN="${i#*=}" 46 | shift 47 | ;; 48 | --answerlen=*) 49 | ANSWERLEN="${i#*=}" 50 | shift 51 | ;; 52 | --batchsize=*) 53 | BATCHSIZE="${i#*=}" 54 | shift 55 | ;; 56 | --learningrate=*) 57 | LEARNINGRATE="${i#*=}" 58 | shift 59 | ;; 60 | --trainsteps=*) 61 | TRAINSTEPS="${i#*=}" 62 | shift 63 | ;; 64 | --warmupsteps=*) 65 | WARMUPSTEPS="${i#*=}" 66 | shift 67 | ;; 68 | --savesteps=*) 69 | SAVESTEPS="${i#*=}" 70 | shift 71 | ;; 72 | --answerthreshold=*) 73 | ANSWERTHRESHOLD="${i#*=}" 74 | shift 75 | ;; 76 | esac 77 | done 78 | 79 | echo "gpu device = ${GPUDEVICE}" 80 | echo "num gpus = ${NUMGPUS}" 81 | echo "task name = ${TASKNAME}" 82 | echo "random seed = ${RANDOMSEED}" 83 | echo "predict tag = ${PREDICTTAG}" 84 | echo "model dir = ${MODELDIR}" 85 | echo "data dir = ${DATADIR}" 86 | echo "output dir = ${OUTPUTDIR}" 87 | echo "num turn = ${NUMTURN}" 88 | echo "seq len = ${SEQLEN}" 89 | echo "query len = ${QUERYLEN}" 90 | echo "answer len = ${ANSWERLEN}" 91 | echo "batch size = ${BATCHSIZE}" 92 | echo "learning rate = ${LEARNINGRATE}" 93 | echo "train steps = ${TRAINSTEPS}" 94 | echo "warmup steps = ${WARMUPSTEPS}" 95 | echo "save steps = ${SAVESTEPS}" 96 | echo "answer threshold = 
${ANSWERTHRESHOLD}" 97 | 98 | alias python=python3 99 | mkdir ${OUTPUTDIR} 100 | 101 | start_time=`date +%s` 102 | 103 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_quac.py \ 104 | --spiece_model_file=${MODELDIR}/spiece.model \ 105 | --model_config_path=${MODELDIR}/xlnet_config.json \ 106 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 107 | --task_name=${TASKNAME} \ 108 | --random_seed=${RANDOMSEED} \ 109 | --predict_tag=${PREDICTTAG} \ 110 | --lower_case=false \ 111 | --data_dir=${DATADIR}/ \ 112 | --output_dir=${OUTPUTDIR}/data \ 113 | --model_dir=${OUTPUTDIR}/checkpoint \ 114 | --export_dir=${OUTPUTDIR}/export \ 115 | --num_turn=${NUMTURN} \ 116 | --max_seq_length=${SEQLEN} \ 117 | --max_query_length=${QUERYLEN} \ 118 | --max_answer_length=${ANSWERLEN} \ 119 | --train_batch_size=${BATCHSIZE} \ 120 | --predict_batch_size=${BATCHSIZE} \ 121 | --num_hosts=1 \ 122 | --num_core_per_host=${NUMGPUS} \ 123 | --learning_rate=${LEARNINGRATE} \ 124 | --train_steps=${TRAINSTEPS} \ 125 | --warmup_steps=${WARMUPSTEPS} \ 126 | --save_steps=${SAVESTEPS} \ 127 | --do_train=true \ 128 | --do_predict=false \ 129 | --do_export=false \ 130 | --overwrite_data=false 131 | 132 | CUDA_VISIBLE_DEVICES=${GPUDEVICE} python run_quac.py \ 133 | --spiece_model_file=${MODELDIR}/spiece.model \ 134 | --model_config_path=${MODELDIR}/xlnet_config.json \ 135 | --init_checkpoint=${MODELDIR}/xlnet_model.ckpt \ 136 | --task_name=${TASKNAME} \ 137 | --random_seed=${RANDOMSEED} \ 138 | --predict_tag=${PREDICTTAG} \ 139 | --lower_case=false \ 140 | --data_dir=${DATADIR}/ \ 141 | --output_dir=${OUTPUTDIR}/data \ 142 | --model_dir=${OUTPUTDIR}/checkpoint \ 143 | --export_dir=${OUTPUTDIR}/export \ 144 | --num_turn=${NUMTURN} \ 145 | --max_seq_length=${SEQLEN} \ 146 | --max_query_length=${QUERYLEN} \ 147 | --max_answer_length=${ANSWERLEN} \ 148 | --train_batch_size=${BATCHSIZE} \ 149 | --predict_batch_size=${BATCHSIZE} \ 150 | --num_hosts=1 \ 151 | --num_core_per_host=1 \ 152 | --learning_rate=${LEARNINGRATE} \ 153 | --train_steps=${TRAINSTEPS} \ 154 | --warmup_steps=${WARMUPSTEPS} \ 155 | --save_steps=${SAVESTEPS} \ 156 | --do_train=false \ 157 | --do_predict=true \ 158 | --do_export=false \ 159 | --overwrite_data=false 160 | 161 | python tool/convert_quac.py \ 162 | --input_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.summary.json \ 163 | --output_file=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 164 | --answer_threshold=${ANSWERTHRESHOLD} 165 | 166 | rm ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json 167 | 168 | python tool/eval_quac.py \ 169 | --val_file=${DATADIR}/dev-${TASKNAME}.json \ 170 | --model_output=${OUTPUTDIR}/data/predict.${PREDICTTAG}.span.json \ 171 | --o ${OUTPUTDIR}/data/predict.${PREDICTTAG}.eval.json 172 | 173 | end_time=`date +%s` 174 | echo execution time was `expr $end_time - $start_time` s. 175 | 176 | read -n 1 -s -r -p "Press any key to continue..." -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | Machine reading comprehension (MRC), a task which asks machine to read a given context then answer questions based on its understanding, is considered one of the key problems in artificial intelligence and has significant interest from both academic and industry. Over the past few years, great progress has been made in this field, thanks to various end-to-end trained neural models and high quality datasets with large amount of examples proposed. 
3 | 
4 | ![](/docs/squad.example.png){:width="800px"}
5 | 
6 | *Figure 1: MRC example from SQuAD 2.0 dev set*
7 | 
8 | ## DataSet
9 | * [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) is a reading comprehension dataset, consisting of questions posed by crowd-workers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.
10 | * [CoQA](https://stanfordnlp.github.io/coqa/) is a large-scale dataset for building Conversational Question Answering systems. The goal of the CoQA challenge is to measure the ability of machines to understand a text passage and answer a series of interconnected questions that appear in a conversation. CoQA is pronounced as 'coca'.
11 | * [QuAC](https://quac.ai/) is a dataset for modeling, understanding, and participating in information seeking dialog. QuAC introduces challenges not found in existing machine comprehension datasets: its questions are often more open-ended, unanswerable, or only meaningful within the dialog context.
12 | 
13 | ## Experiment
14 | ### SQuAD v1.1
15 | 
16 | ![](/docs/xlnet.squad.v1.png){:width="500px"}
17 | 
18 | *Figure 2: Illustrations of fine-tuning XLNet on SQuAD v1.1 task*
19 | 
20 | | Model | Train Data | # Epoch | # Train Steps | Batch Size | Max Length | Learning Rate | EM | F1 |
21 | |:-----------------:|:----------:|:-------:|:-------------:|:----------:|:----------:|:-------------:|:--------:|:--------:|
22 | | XLNet-base | SQuAD 2.0 | ~3 | 8,000 | 48 | 512 | 3e-5 | 85.90 | 92.17 |
23 | | XLNet-large | SQuAD 2.0 | ~3 | 8,000 | 48 | 512 | 3e-5 | 88.61 | 94.28 |
24 | 
25 | *Table 1: The dev set performance of XLNet model finetuned on SQuAD v1.1 task*
26 | 
27 | ### SQuAD v2.0
28 | 
29 | ![](/docs/xlnet.squad.v2.png){:width="500px"}
30 | 
31 | *Figure 3: Illustrations of fine-tuning XLNet on SQuAD v2.0 task*
32 | 
33 | | Model | Train Data | # Epoch | # Train Steps | Batch Size | Max Length | Learning Rate | EM | F1 |
34 | |:-----------------:|:----------:|:-------:|:-------------:|:----------:|:----------:|:-------------:|:--------:|:--------:|
35 | | XLNet-base | SQuAD 2.0 | ~3 | 8,000 | 48 | 512 | 3e-5 | 80.23 | 82.90 |
36 | | XLNet-large | SQuAD 2.0 | ~3 | 8,000 | 48 | 512 | 3e-5 | 85.72 | 88.36 |
37 | 
38 | *Table 2: The dev set performance of XLNet model finetuned on SQuAD v2.0 task*
39 | 
40 | ### CoQA v1.0
41 | 
42 | ![](/docs/xlnet.coqa.png){:width="500px"}
43 | 
44 | *Figure 4: Illustrations of fine-tuning XLNet on CoQA v1.0 task*
45 | 
46 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Max Query Len | Learning Rate | EM | F1 |
47 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-------------:|:--------:|:--------:|
48 | | XLNet-base | CoQA 1.0 | 6,000 | 48 | 512 | 128 | 3e-5 | 76.4 | 84.4 |
49 | | XLNet-large | CoQA 1.0 | 6,000 | 48 | 512 | 128 | 3e-5 | 81.8 | 89.4 |
50 | 
51 | *Table 3: The dev set performance of XLNet model finetuned on CoQA v1.0 task*
52 | 
53 | ### QuAC v0.2
54 | 
55 | ![](/docs/xlnet.quac.png){:width="500px"}
56 | 
57 | *Figure 5: Illustrations of fine-tuning XLNet on QuAC v0.2 task*
58 | 
59 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Max Query Len | Learning Rate | Overall F1 | HEQQ | HEQD |
60 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-------------:|:----------:|:------:|:------:|
61 | | XLNet-base | QuAC 0.2 | 8,000 | 48 | 512 | 128 | 2e-5 | 66.4 | 62.6 | 6.8 |
62 | | XLNet-large | QuAC 0.2 | 8,000 | 48 | 512 | 128 | 2e-5 | 71.5 | 68.0 | 11.1 |
63 | 
64 | *Table 4: The dev set performance of XLNet 
model finetuned on QuAC v0.2 task* 65 | 66 | ## Reference 67 | * Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. [SQuAD: 100,000+ questions for machine comprehension of text](https://arxiv.org/abs/1606.05250) [2016] 68 | * Pranav Rajpurkar, Robin Jia, and Percy Liang. [Know what you don’t know: unanswerable questions for SQuAD](https://arxiv.org/abs/1806.03822) [2018] 69 | * Siva Reddy, Danqi Chen, Christopher D. Manning. [CoQA: A Conversational Question Answering Challenge](https://arxiv.org/abs/1808.07042) [2018] 70 | * Eunsol Choi, He He, Mohit Iyyer, Mark Yatskar, Wen-tau Yih, Yejin Choi, Percy Liang, Luke Zettlemoyer. [QuAC : Question Answering in Context](https://arxiv.org/abs/1808.07036) [2018] 71 | * Danqi Chen. [Neural reading comprehension and beyond](https://cs.stanford.edu/~danqi/papers/thesis.pdf) [2018] 72 | * Matthew E. Peters, Mark Neumann, Mohit Iyyer, Matthew Gardner, Christopher T Clark, Kenton Lee, and Luke S. Zettlemoyer. [Deep contextualized word representations](https://arxiv.org/abs/1802.05365) [2018] 73 | * Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. [Improving language understanding by generative pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) [2018] 74 | * Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei and Ilya Sutskever. [Language models are unsupervised multitask learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) [2019] 75 | * Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. [BERT: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/abs/1810.04805) [2018] 76 | * Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) [2019] 77 | * Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) [2019] 78 | * Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) [2019] 79 | * Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov and Quoc V. Le. [XLNet: Generalized autoregressive pretraining for language understanding](https://arxiv.org/abs/1906.08237) [2019] 80 | * Zihang Dai, Zhilin Yang, Yiming Yang, William W Cohen, Jaime Carbonell, Quoc V Le and Ruslan Salakhutdinov. 
[Transformer-XL: Attentive language models beyond a fixed-length context](https://arxiv.org/abs/1901.02860) [2019] 81 | -------------------------------------------------------------------------------- /tool/eval_quac.py: -------------------------------------------------------------------------------- 1 | import json, string, re 2 | from collections import Counter, defaultdict 3 | from argparse import ArgumentParser 4 | 5 | 6 | def is_overlapping(x1, x2, y1, y2): 7 | return max(x1, y1) <= min(x2, y2) 8 | 9 | def normalize_answer(s): 10 | """Lower text and remove punctuation, articles and extra whitespace.""" 11 | def remove_articles(text): 12 | return re.sub(r'\b(a|an|the)\b', ' ', text) 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | def remove_punc(text): 16 | exclude = set(string.punctuation) 17 | return ''.join(ch for ch in text if ch not in exclude) 18 | def lower(text): 19 | return text.lower() 20 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 21 | 22 | def f1_score(prediction, ground_truth): 23 | prediction_tokens = normalize_answer(prediction).split() 24 | ground_truth_tokens = normalize_answer(ground_truth).split() 25 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 26 | num_same = sum(common.values()) 27 | if num_same == 0: 28 | return 0 29 | precision = 1.0 * num_same / len(prediction_tokens) 30 | recall = 1.0 * num_same / len(ground_truth_tokens) 31 | f1 = (2 * precision * recall) / (precision + recall) 32 | return f1 33 | 34 | def exact_match_score(prediction, ground_truth): 35 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 36 | 37 | def display_counter(title, c, c2=None): 38 | print(title) 39 | for key, _ in c.most_common(): 40 | if c2: 41 | print('%s: %d / %d, %.1f%%, F1: %.1f' % ( 42 | key, c[key], sum(c.values()), c[key] * 100. / sum(c.values()), sum(c2[key]) * 100. / len(c2[key]))) 43 | else: 44 | print('%s: %d / %d, %.1f%%' % (key, c[key], sum(c.values()), c[key] * 100. / sum(c.values()))) 45 | 46 | def leave_one_out_max(prediction, ground_truths, article): 47 | if len(ground_truths) == 1: 48 | return metric_max_over_ground_truths(prediction, ground_truths, article)[1] 49 | else: 50 | t_f1 = [] 51 | # leave out one ref every time 52 | for i in range(len(ground_truths)): 53 | idxes = list(range(len(ground_truths))) 54 | idxes.pop(i) 55 | refs = [ground_truths[z] for z in idxes] 56 | t_f1.append(metric_max_over_ground_truths(prediction, refs, article)[1]) 57 | return 1.0 * sum(t_f1) / len(t_f1) 58 | 59 | 60 | def metric_max_over_ground_truths(prediction, ground_truths, article): 61 | scores_for_ground_truths = [] 62 | for ground_truth in ground_truths: 63 | score = compute_span_overlap(prediction, ground_truth, article) 64 | scores_for_ground_truths.append(score) 65 | return max(scores_for_ground_truths, key=lambda x: x[1]) 66 | 67 | 68 | def handle_cannot(refs): 69 | num_cannot = 0 70 | num_spans = 0 71 | for ref in refs: 72 | if ref == 'CANNOTANSWER': 73 | num_cannot += 1 74 | else: 75 | num_spans += 1 76 | if num_cannot >= num_spans: 77 | refs = ['CANNOTANSWER'] 78 | else: 79 | refs = [x for x in refs if x != 'CANNOTANSWER'] 80 | return refs 81 | 82 | 83 | def leave_one_out(refs): 84 | if len(refs) == 1: 85 | return 1. 
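    # otherwise, score each reference against the best match among the remaining references and average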
86 | splits = [] 87 | for r in refs: 88 | splits.append(r.split()) 89 | t_f1 = 0.0 90 | for i in range(len(refs)): 91 | m_f1 = 0 92 | for j in range(len(refs)): 93 | if i == j: 94 | continue 95 | f1_ij = f1_score(refs[i], refs[j]) 96 | if f1_ij > m_f1: 97 | m_f1 = f1_ij 98 | t_f1 += m_f1 99 | return t_f1 / len(refs) 100 | 101 | 102 | def compute_span_overlap(pred_span, gt_span, text): 103 | if gt_span == 'CANNOTANSWER': 104 | if pred_span == 'CANNOTANSWER': 105 | return 'Exact match', 1.0 106 | return 'No overlap', 0. 107 | fscore = f1_score(pred_span, gt_span) 108 | pred_start = text.find(pred_span) 109 | gt_start = text.find(gt_span) 110 | 111 | if pred_start == -1 or gt_start == -1: 112 | return 'Span indexing error', fscore 113 | 114 | pred_end = pred_start + len(pred_span) 115 | gt_end = gt_start + len(gt_span) 116 | 117 | fscore = f1_score(pred_span, gt_span) 118 | overlap = is_overlapping(pred_start, pred_end, gt_start, gt_end) 119 | 120 | if exact_match_score(pred_span, gt_span): 121 | return 'Exact match', fscore 122 | if overlap: 123 | return 'Partial overlap', fscore 124 | else: 125 | return 'No overlap', fscore 126 | 127 | 128 | def eval_fn(val_results, model_results, verbose): 129 | span_overlap_stats = Counter() 130 | sentence_overlap = 0. 131 | para_overlap = 0. 132 | total_qs = 0. 133 | f1_stats = defaultdict(list) 134 | unfiltered_f1s = [] 135 | human_f1 = [] 136 | HEQ = 0. 137 | DHEQ = 0. 138 | total_dials = 0. 139 | yes_nos = [] 140 | followups = [] 141 | unanswerables = [] 142 | for p in val_results: 143 | for par in p['paragraphs']: 144 | did = par['id'] 145 | qa_list = par['qas'] 146 | good_dial = 1. 147 | for qa in qa_list: 148 | q_idx = qa['id'] 149 | val_spans = [anss['text'] for anss in qa['answers']] 150 | val_spans = handle_cannot(val_spans) 151 | hf1 = leave_one_out(val_spans) 152 | 153 | if did not in model_results or q_idx not in model_results[did]: 154 | print(did, q_idx, 'no prediction for this dialogue id') 155 | good_dial = 0 156 | f1_stats['NO ANSWER'].append(0.0) 157 | yes_nos.append(False) 158 | followups.append(False) 159 | if val_spans == ['CANNOTANSWER']: 160 | unanswerables.append(0.0) 161 | total_qs += 1 162 | unfiltered_f1s.append(0.0) 163 | if hf1 >= args.min_f1: 164 | human_f1.append(hf1) 165 | continue 166 | 167 | pred_span, pred_yesno, pred_followup = model_results[did][q_idx] 168 | 169 | max_overlap, _ = metric_max_over_ground_truths( \ 170 | pred_span, val_spans, par['context']) 171 | max_f1 = leave_one_out_max( \ 172 | pred_span, val_spans, par['context']) 173 | unfiltered_f1s.append(max_f1) 174 | 175 | # dont eval on low agreement instances 176 | if hf1 < args.min_f1: 177 | continue 178 | 179 | human_f1.append(hf1) 180 | yes_nos.append(pred_yesno == qa['yesno']) 181 | followups.append(pred_followup == qa['followup']) 182 | if val_spans == ['CANNOTANSWER']: 183 | unanswerables.append(max_f1) 184 | if verbose: 185 | print("-" * 20) 186 | print(pred_span) 187 | print(val_spans) 188 | print(max_f1) 189 | print("-" * 20) 190 | if max_f1 >= hf1: 191 | HEQ += 1. 192 | else: 193 | good_dial = 0. 194 | span_overlap_stats[max_overlap] += 1 195 | f1_stats[max_overlap].append(max_f1) 196 | total_qs += 1. 
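            # a dialog counts toward DHEQ only if no question in it scored below human F1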
197 | DHEQ += good_dial
198 | total_dials += 1
199 | DHEQ_score = 100.0 * DHEQ / total_dials
200 | HEQ_score = 100.0 * HEQ / total_qs
201 | all_f1s = sum(f1_stats.values(), [])
202 | overall_f1 = 100.0 * sum(all_f1s) / len(all_f1s)
203 | unfiltered_f1 = 100.0 * sum(unfiltered_f1s) / len(unfiltered_f1s)
204 | yesno_score = (100.0 * sum(yes_nos) / len(yes_nos))
205 | followup_score = (100.0 * sum(followups) / len(followups))
206 | unanswerable_score = (100.0 * sum(unanswerables) / len(unanswerables))
207 | metric_json = {"unfiltered_f1": unfiltered_f1, "f1": overall_f1, "HEQ": HEQ_score, "DHEQ": DHEQ_score, "yes/no": yesno_score, "followup": followup_score, "unanswerable_acc": unanswerable_score}
208 | if verbose:
209 | print("=======================")
210 | display_counter('Overlap Stats', span_overlap_stats, f1_stats)
211 | print("=======================")
212 | print('Overall F1: %.1f' % overall_f1)
213 | print('Yes/No Accuracy : %.1f' % yesno_score)
214 | print('Followup Accuracy : %.1f' % followup_score)
215 | print('Unfiltered F1 ({0:d} questions): {1:.1f}'.format(len(unfiltered_f1s), unfiltered_f1))
216 | print('Accuracy On Unanswerable Questions: {0:.1f} % ({1:d} questions)'.format(unanswerable_score, len(unanswerables)))
217 | print('Human F1: %.1f' % (100.0 * sum(human_f1) / len(human_f1)))
218 | print('Model F1 >= Human F1 (Questions): %d / %d, %.1f%%' % (HEQ, total_qs, 100.0 * HEQ / total_qs))
219 | print('Model F1 >= Human F1 (Dialogs): %d / %d, %.1f%%' % (DHEQ, total_dials, 100.0 * DHEQ / total_dials))
220 | print("=======================")
221 | return metric_json
222 | 
223 | if __name__ == "__main__":
224 | parser = ArgumentParser()
225 | parser.add_argument('--val_file', type=str, required=True, help='file containing gold validation data')
226 | parser.add_argument('--model_output', type=str, required=True, help='Path to model output.')
227 | parser.add_argument('--o', type=str, required=False, help='Path to save score json')
228 | parser.add_argument('--min_f1', type=float, default=0.4, help='minimum human agreement F1 required to evaluate a question')
229 | parser.add_argument('--verbose', action='store_true', help='print individual scores')
230 | args = parser.parse_args()
231 | val = json.load(open(args.val_file, 'r'))['data']
232 | preds = defaultdict(dict)
233 | total = 0
234 | val_total = 0
235 | for line in open(args.model_output, 'r'):
236 | if line.strip():
237 | pred_idx = json.loads(line.strip())
238 | dia_id = pred_idx['qid'][0].split("_q#")[0]
239 | for qid, qspan, qyesno, qfollowup in zip(pred_idx['qid'], pred_idx['best_span_str'], pred_idx['yesno'], pred_idx['followup']):
240 | preds[dia_id][qid] = qspan, qyesno, qfollowup
241 | total += 1
242 | for p in val:
243 | for par in p['paragraphs']:
244 | did = par['id']
245 | qa_list = par['qas']
246 | val_total += len(qa_list)
247 | metric_json = eval_fn(val, preds, args.verbose)
248 | if args.o:
249 | with open(args.o, 'w') as fout:
250 | json.dump(metric_json, fout)
251 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Machine Reading Comprehension
2 | Machine reading comprehension (MRC), a task which asks a machine to read a given context and then answer questions based on its understanding, is considered one of the key problems in artificial intelligence and has attracted significant interest from both academia and industry. 
Over the past few years, great progress has been made in this field, thanks to the many end-to-end trained neural models and high-quality, large-scale datasets that have been proposed. 3 | 

Figure 1: MRC example from SQuAD 2.0 dev set
5 | 
6 | ## Setting
7 | * Python 3.6.7
8 | * TensorFlow 1.13.1
9 | * NumPy 1.13.3
10 | * SentencePiece 0.1.82
11 | 
12 | ## DataSet
13 | * [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) is a reading comprehension dataset, consisting of questions posed by crowd-workers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.
14 | * [CoQA](https://stanfordnlp.github.io/coqa/) is a large-scale dataset for building Conversational Question Answering systems. The goal of the CoQA challenge is to measure the ability of machines to understand a text passage and answer a series of interconnected questions that appear in a conversation. CoQA is pronounced as 'coca'.
15 | * [QuAC](https://quac.ai/) is a dataset for modeling, understanding, and participating in information seeking dialog. QuAC introduces challenges not found in existing machine comprehension datasets: its questions are often more open-ended, unanswerable, or only meaningful within the dialog context.
16 | 
17 | ## Usage
18 | * Run SQuAD experiment
19 | ```bash
20 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_squad.py \
21 | --spiece_model_file=model/cased_L-24_H-1024_A-16/spiece.model \
22 | --model_config_path=model/cased_L-24_H-1024_A-16/xlnet_config.json \
23 | --init_checkpoint=model/cased_L-24_H-1024_A-16/xlnet_model.ckpt \
24 | --task_name=v2.0 \
25 | --random_seed=100 \
26 | --predict_tag=xxxxx \
27 | --data_dir=data/squad/v2.0 \
28 | --output_dir=output/squad/v2.0/data \
29 | --model_dir=output/squad/v2.0/checkpoint \
30 | --export_dir=output/squad/v2.0/export \
31 | --max_seq_length=512 \
32 | --train_batch_size=12 \
33 | --predict_batch_size=12 \
34 | --num_hosts=1 \
35 | --num_core_per_host=4 \
36 | --learning_rate=3e-5 \
37 | --train_steps=8000 \
38 | --warmup_steps=1000 \
39 | --save_steps=1000 \
40 | --do_train=true \
41 | --do_predict=true \
42 | --do_export=true \
43 | --overwrite_data=false
44 | ```
45 | * Run CoQA experiment
46 | ```bash
47 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_coqa.py \
48 | --spiece_model_file=model/cased_L-24_H-1024_A-16/spiece.model \
49 | --model_config_path=model/cased_L-24_H-1024_A-16/xlnet_config.json \
50 | --init_checkpoint=model/cased_L-24_H-1024_A-16/xlnet_model.ckpt \
51 | --task_name=v1.0 \
52 | --random_seed=100 \
53 | --predict_tag=xxxxx \
54 | --data_dir=data/coqa/v1.0 \
55 | --output_dir=output/coqa/v1.0/data \
56 | --model_dir=output/coqa/v1.0/checkpoint \
57 | --export_dir=output/coqa/v1.0/export \
58 | --max_seq_length=512 \
59 | --train_batch_size=12 \
60 | --predict_batch_size=12 \
61 | --num_hosts=1 \
62 | --num_core_per_host=4 \
63 | --learning_rate=3e-5 \
64 | --train_steps=8000 \
65 | --warmup_steps=1000 \
66 | --save_steps=1000 \
67 | --do_train=true \
68 | --do_predict=true \
69 | --do_export=true \
70 | --overwrite_data=false
71 | ```
72 | * Run QuAC experiment
73 | ```bash
74 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_quac.py \
75 | --spiece_model_file=model/cased_L-24_H-1024_A-16/spiece.model \
76 | --model_config_path=model/cased_L-24_H-1024_A-16/xlnet_config.json \
77 | --init_checkpoint=model/cased_L-24_H-1024_A-16/xlnet_model.ckpt \
78 | --task_name=v0.2 \
79 | --random_seed=100 \
80 | --predict_tag=xxxxx \
81 | --data_dir=data/quac/v0.2 \
82 | --output_dir=output/quac/v0.2/data \
83 | --model_dir=output/quac/v0.2/checkpoint \
84 | --export_dir=output/quac/v0.2/export \
85 | --max_seq_length=512 \
86 | --train_batch_size=12 \
87 | 
--predict_batch_size=12 \ 88 | --num_hosts=1 \ 89 | --num_core_per_host=4 \ 90 | --learning_rate=3e-5 \ 91 | --train_steps=8000 \ 92 | --warmup_steps=1000 \ 93 | --save_steps=1000 \ 94 | --do_train=true \ 95 | --do_predict=true \ 96 | --do_export=true \ 97 | --overwrite_data=false 98 | ``` 99 | 100 | ## Experiment 101 | ### SQuAD v1.1 102 |
Figure 2: Illustrations of fine-tuning XLNet on SQuAD v1.1 task
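As a rough sketch of what the figure depicts (not the actual code in this repo), fine-tuning for span extraction adds a small output layer on top of XLNet's final hidden states that scores every token as a candidate answer start or end:

```python
import tensorflow as tf  # TensorFlow 1.13, per the Setting section

def span_head(hidden_states):
    # hidden_states: [batch, seq_len, hidden] encoder outputs for question + context
    logits = tf.layers.dense(hidden_states, 2)              # [batch, seq_len, 2]
    start_logits, end_logits = tf.unstack(logits, axis=-1)  # [batch, seq_len] each
    return start_logits, end_logits
```

At prediction time, the best (start, end) pair under the maximum answer length is decoded from these logits.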
104 | 105 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Learning Rate | EM | F1 | 106 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:--------:|:--------:| 107 | | XLNet-base | SQuAD 2.0 | 8,000 | 48 | 512 | 3e-5 | 85.90 | 92.17 | 108 | | XLNet-large | SQuAD 2.0 | 8,000 | 48 | 512 | 3e-5 | 88.61 | 94.28 | 109 | 110 |Table 1: The dev set performance of XLNet model finetuned on SQuAD v1.1 task
111 | 112 | ### SQuAD v2.0 113 |
Figure 3: Illustrations of fine-tuning XLNet on SQuAD v2.0 task
115 | 116 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Learning Rate | EM | F1 | 117 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:--------:|:--------:| 118 | | XLNet-base | SQuAD 2.0 | 8,000 | 48 | 512 | 3e-5 | 80.23 | 82.90 | 119 | | XLNet-large | SQuAD 2.0 | 8,000 | 48 | 512 | 3e-5 | 85.72 | 88.36 | 120 | 121 |Table 2: The dev set performance of XLNet model finetuned on SQuAD v2.0 task
122 | 123 | ### CoQA v1.0 124 |
Figure 4: Illustrations of fine-tuning XLNet on CoQA v1.0 task
126 | 127 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Max Query Len | Learning Rate | EM | F1 | 128 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-------------:|:--------:|:--------:| 129 | | XLNet-base | CoQA 1.0 | 6,000 | 48 | 512 | 128 | 3e-5 | 76.4 | 84.4 | 130 | | XLNet-large | CoQA 1.0 | 6,000 | 48 | 512 | 128 | 3e-5 | 81.8 | 89.4 | 131 | 132 |Table 3: The dev set performance of XLNet model finetuned on CoQA v1.0 task
133 | 134 | ### QuAC v0.2 135 |
Figure 5: Illustrations of fine-tuning XLNet on QuAC v0.2 task
137 | 138 | | Model | Train Data | # Train Steps | Batch Size | Max Length | Max Query Len | Learning Rate | Overall F1 | HEQQ | HEQD | 139 | |:-------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-------------:|:----------:|:------:|:------:| 140 | | XLNet-base | QuAC 0.2 | 8,000 | 48 | 512 | 128 | 2e-5 | 66.4 | 62.6 | 6.8 | 141 | | XLNet-large | QuAC 0.2 | 8,000 | 48 | 512 | 128 | 2e-5 | 71.5 | 68.0 | 11.1 | 142 | 143 |Table 3: The dev set performance of XLNet model finetuned on QuAC v0.2 task
144 | 145 | ## Reference 146 | * Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. [SQuAD: 100,000+ questions for machine comprehension of text](https://arxiv.org/abs/1606.05250) [2016] 147 | * Pranav Rajpurkar, Robin Jia, and Percy Liang. [Know what you don’t know: unanswerable questions for SQuAD](https://arxiv.org/abs/1806.03822) [2018] 148 | * Siva Reddy, Danqi Chen, Christopher D. Manning. [CoQA: A Conversational Question Answering Challenge](https://arxiv.org/abs/1808.07042) [2018] 149 | * Eunsol Choi, He He, Mohit Iyyer, Mark Yatskar, Wen-tau Yih, Yejin Choi, Percy Liang, Luke Zettlemoyer. [QuAC : Question Answering in Context](https://arxiv.org/abs/1808.07036) [2018] 150 | * Danqi Chen. [Neural reading comprehension and beyond](https://cs.stanford.edu/~danqi/papers/thesis.pdf) [2018] 151 | * Matthew E. Peters, Mark Neumann, Mohit Iyyer, Matthew Gardner, Christopher T Clark, Kenton Lee, and Luke S. Zettlemoyer. [Deep contextualized word representations](https://arxiv.org/abs/1802.05365) [2018] 152 | * Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. [Improving language understanding by generative pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) [2018] 153 | * Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei and Ilya Sutskever. [Language models are unsupervised multitask learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) [2019] 154 | * Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. [BERT: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/abs/1810.04805) [2018] 155 | * Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) [2019] 156 | * Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) [2019] 157 | * Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) [2019] 158 | * Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov and Quoc V. Le. [XLNet: Generalized autoregressive pretraining for language understanding](https://arxiv.org/abs/1906.08237) [2019] 159 | * Zihang Dai, Zhilin Yang, Yiming Yang, William W Cohen, Jaime Carbonell, Quoc V Le and Ruslan Salakhutdinov. [Transformer-XL: Attentive language models beyond a fixed-length context](https://arxiv.org/abs/1901.02860) [2019] 160 | -------------------------------------------------------------------------------- /tool/eval_coqa.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for CoQA. 2 | 3 | The code is based partially on SQuAD 2.0 evaluation script. 
4 | """
5 | import argparse
6 | import json
7 | import re
8 | import string
9 | import sys
10 | 
11 | from collections import Counter, OrderedDict
12 | 
13 | OPTS = None
14 | 
15 | out_domain = ["reddit", "science"]
16 | in_domain = ["mctest", "gutenberg", "race", "cnn", "wikipedia"]
17 | domain_mappings = {"mctest":"children_stories", "gutenberg":"literature", "race":"mid-high_school", "cnn":"news", "wikipedia":"wikipedia", "science":"science", "reddit":"reddit"}
18 | 
19 | 
20 | class CoQAEvaluator():
21 | 
22 |     def __init__(self, gold_file):
23 |         self.gold_data, self.id_to_source = CoQAEvaluator.gold_answers_to_dict(gold_file)
24 | 
25 |     @staticmethod
26 |     def gold_answers_to_dict(gold_file):
27 |         dataset = json.load(open(gold_file))
28 |         gold_dict = {}
29 |         id_to_source = {}
30 |         for story in dataset['data']:
31 |             source = story['source']
32 |             story_id = story['id']
33 |             id_to_source[story_id] = source
34 |             questions = story['questions']
35 |             multiple_answers = [story['answers']]
36 |             multiple_answers += story['additional_answers'].values()
37 |             for i, qa in enumerate(questions):
38 |                 qid = qa['turn_id']
39 |                 if i + 1 != qid:
40 |                     sys.stderr.write("Turn id should match index {}: {}\n".format(i + 1, qa))
41 |                 gold_answers = []
42 |                 for answers in multiple_answers:
43 |                     answer = answers[i]
44 |                     if qid != answer['turn_id']:
45 |                         sys.stderr.write("Question turn id does not match answer: {} {}\n".format(qa, answer))
46 |                     gold_answers.append(answer['input_text'])
47 |                 key = (story_id, qid)
48 |                 if key in gold_dict:
49 |                     sys.stderr.write("Gold file has duplicate stories: {}\n".format(source))
50 |                 gold_dict[key] = gold_answers
51 |         return gold_dict, id_to_source
52 | 
53 |     @staticmethod
54 |     def preds_to_dict(pred_file):
55 |         preds = json.load(open(pred_file))
56 |         pred_dict = {}
57 |         for pred in preds:
58 |             pred_dict[(pred['id'], pred['turn_id'])] = pred['answer']
59 |         return pred_dict
60 | 
61 |     @staticmethod
62 |     def normalize_answer(s):
63 |         """Lower text and remove punctuation, articles and extra whitespace."""
64 | 
65 |         def remove_articles(text):
66 |             regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
67 |             return re.sub(regex, ' ', text)
68 | 
69 |         def white_space_fix(text):
70 |             return ' '.join(text.split())
71 | 
72 |         def remove_punc(text):
73 |             exclude = set(string.punctuation)
74 |             return ''.join(ch for ch in text if ch not in exclude)
75 | 
76 |         def lower(text):
77 |             return text.lower()
78 | 
79 |         return white_space_fix(remove_articles(remove_punc(lower(s))))
80 | 
81 |     @staticmethod
82 |     def get_tokens(s):
83 |         if not s: return []
84 |         return CoQAEvaluator.normalize_answer(s).split()
85 | 
86 |     @staticmethod
87 |     def compute_exact(a_gold, a_pred):
88 |         return int(CoQAEvaluator.normalize_answer(a_gold) == CoQAEvaluator.normalize_answer(a_pred))
89 | 
90 |     @staticmethod
91 |     def compute_f1(a_gold, a_pred):
92 |         gold_toks = CoQAEvaluator.get_tokens(a_gold)
93 |         pred_toks = CoQAEvaluator.get_tokens(a_pred)
94 |         common = Counter(gold_toks) & Counter(pred_toks)
95 |         num_same = sum(common.values())
96 |         if len(gold_toks) == 0 or len(pred_toks) == 0:
97 |             # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
98 |             return int(gold_toks == pred_toks)
99 |         if num_same == 0:
100 |             return 0
101 |         precision = 1.0 * num_same / len(pred_toks)
102 |         recall = 1.0 * num_same / len(gold_toks)
103 |         f1 = (2 * precision * recall) / (precision + recall)
104 |         return f1
105 | 
106 |     @staticmethod
107 |     def _compute_turn_score(a_gold_list, a_pred):
108 |         f1_sum = 0.0
109 |         em_sum = 0.0
110 |         if len(a_gold_list) > 1:
111 |             for i in range(len(a_gold_list)):
112 |                 # exclude the current answer
113 |                 gold_answers = a_gold_list[0:i] + a_gold_list[i + 1:]
114 |                 em_sum += max(CoQAEvaluator.compute_exact(a, a_pred) for a in gold_answers)
115 |                 f1_sum += max(CoQAEvaluator.compute_f1(a, a_pred) for a in gold_answers)
116 |         else:
117 |             em_sum += max(CoQAEvaluator.compute_exact(a, a_pred) for a in a_gold_list)
118 |             f1_sum += max(CoQAEvaluator.compute_f1(a, a_pred) for a in a_gold_list)
119 | 
120 |         return {'em': em_sum / max(1, len(a_gold_list)), 'f1': f1_sum / max(1, len(a_gold_list))}
121 | 
122 |     def compute_turn_score(self, story_id, turn_id, a_pred):
123 |         '''This is the function you are probably looking for. a_pred is the answer string your model predicted.'''
124 |         key = (story_id, turn_id)
125 |         a_gold_list = self.gold_data[key]
126 |         return CoQAEvaluator._compute_turn_score(a_gold_list, a_pred)
127 | 
128 |     def get_raw_scores(self, pred_data):
129 |         '''Returns a dict with a score for each turn prediction'''
130 |         exact_scores = {}
131 |         f1_scores = {}
132 |         for story_id, turn_id in self.gold_data:
133 |             key = (story_id, turn_id)
134 |             if key not in pred_data:
135 |                 sys.stderr.write('Missing prediction for {} and turn_id: {}\n'.format(story_id, turn_id))
136 |                 continue
137 |             a_pred = pred_data[key]
138 |             scores = self.compute_turn_score(story_id, turn_id, a_pred)
139 |             # Take max over all gold answers
140 |             exact_scores[key] = scores['em']
141 |             f1_scores[key] = scores['f1']
142 |         return exact_scores, f1_scores
143 | 
144 |     def get_raw_scores_human(self):
145 |         '''Returns a dict with a score for each turn'''
146 |         exact_scores = {}
147 |         f1_scores = {}
148 |         for story_id, turn_id in self.gold_data:
149 |             key = (story_id, turn_id)
150 |             f1_sum = 0.0
151 |             em_sum = 0.0
152 |             if len(self.gold_data[key]) > 1:
153 |                 for i in range(len(self.gold_data[key])):
154 |                     # exclude the current answer
155 |                     gold_answers = self.gold_data[key][0:i] + self.gold_data[key][i + 1:]
156 |                     em_sum += max(CoQAEvaluator.compute_exact(a, self.gold_data[key][i]) for a in gold_answers)
157 |                     f1_sum += max(CoQAEvaluator.compute_f1(a, self.gold_data[key][i]) for a in gold_answers)
158 |             else:
159 |                 exit("Gold answers should be multiple: {}={}".format(key, self.gold_data[key]))
160 |             exact_scores[key] = em_sum / len(self.gold_data[key])
161 |             f1_scores[key] = f1_sum / len(self.gold_data[key])
162 |         return exact_scores, f1_scores
163 | 
164 |     def human_performance(self):
165 |         exact_scores, f1_scores = self.get_raw_scores_human()
166 |         return self.get_domain_scores(exact_scores, f1_scores)
167 | 
168 |     def model_performance(self, pred_data):
169 |         exact_scores, f1_scores = self.get_raw_scores(pred_data)
170 |         return self.get_domain_scores(exact_scores, f1_scores)
171 | 
172 |     def get_domain_scores(self, exact_scores, f1_scores):
173 |         sources = {}
174 |         for source in in_domain + out_domain:
175 |             sources[source] = Counter()
176 | 
177 |         for story_id, turn_id in self.gold_data:
178 |             key = (story_id, turn_id)
179 |             source = self.id_to_source[story_id]
180 |             sources[source]['em_total'] += exact_scores.get(key, 0)
181 |             sources[source]['f1_total'] += f1_scores.get(key, 0)
182 |             sources[source]['turn_count'] += 1
183 | 
184 |         scores = OrderedDict()
185 |         in_domain_em_total = 0.0
186 |         in_domain_f1_total = 0.0
187 |         in_domain_turn_count = 0
188 | 
189 |         out_domain_em_total = 0.0
190 |         out_domain_f1_total = 0.0
191 |         out_domain_turn_count = 0
192 | 
193 |         for source in in_domain + out_domain:
194 |             domain = domain_mappings[source]
195 |             scores[domain] = {}
196 |             scores[domain]['em'] = round(sources[source]['em_total'] / max(1, sources[source]['turn_count']) * 100, 1)
197 |             scores[domain]['f1'] = round(sources[source]['f1_total'] / max(1, sources[source]['turn_count']) * 100, 1)
198 |             scores[domain]['turns'] = sources[source]['turn_count']
199 |             if source in in_domain:
200 |                 in_domain_em_total += sources[source]['em_total']
201 |                 in_domain_f1_total += sources[source]['f1_total']
202 |                 in_domain_turn_count += sources[source]['turn_count']
203 |             elif source in out_domain:
204 |                 out_domain_em_total += sources[source]['em_total']
205 |                 out_domain_f1_total += sources[source]['f1_total']
206 |                 out_domain_turn_count += sources[source]['turn_count']
207 | 
208 |         scores["in_domain"] = {'em': round(in_domain_em_total / max(1, in_domain_turn_count) * 100, 1),
209 |                                'f1': round(in_domain_f1_total / max(1, in_domain_turn_count) * 100, 1),
210 |                                'turns': in_domain_turn_count}
211 |         scores["out_domain"] = {'em': round(out_domain_em_total / max(1, out_domain_turn_count) * 100, 1),
212 |                                 'f1': round(out_domain_f1_total / max(1, out_domain_turn_count) * 100, 1),
213 |                                 'turns': out_domain_turn_count}
214 | 
215 |         em_total = in_domain_em_total + out_domain_em_total
216 |         f1_total = in_domain_f1_total + out_domain_f1_total
217 |         turn_count = in_domain_turn_count + out_domain_turn_count
218 |         scores["overall"] = {'em': round(em_total / max(1, turn_count) * 100, 1),
219 |                              'f1': round(f1_total / max(1, turn_count) * 100, 1),
220 |                              'turns': turn_count}
221 | 
222 |         return scores
223 | 
224 | def parse_args():
225 |     parser = argparse.ArgumentParser('Official evaluation script for CoQA.')
226 |     parser.add_argument('--data-file', dest="data_file", help='Input data JSON file.')
227 |     parser.add_argument('--pred-file', dest="pred_file", help='Model predictions.')
228 |     parser.add_argument('--out-file', '-o', metavar='eval.json',
229 |                         help='Write accuracy metrics to file (default is stdout).')
230 |     parser.add_argument('--verbose', '-v', action='store_true')
231 |     parser.add_argument('--human', dest="human", action='store_true')
232 |     if len(sys.argv) == 1:
233 |         parser.print_help()
234 |         sys.exit(1)
235 |     return parser.parse_args()
236 | 
237 | def main():
238 |     evaluator = CoQAEvaluator(OPTS.data_file)
239 | 
240 |     if OPTS.human:
241 |         print(json.dumps(evaluator.human_performance(), indent=2))
242 | 
243 |     if OPTS.pred_file:
244 |         with open(OPTS.pred_file) as f:
245 |             pred_data = CoQAEvaluator.preds_to_dict(OPTS.pred_file)
246 |         print(json.dumps(evaluator.model_performance(pred_data), indent=2))
247 | 
248 | if __name__ == '__main__':
249 |     OPTS = parse_args()
250 |     main()
251 | 
--------------------------------------------------------------------------------
/tool/eval_squad.py:
--------------------------------------------------------------------------------
1 | """Official evaluation script for SQuAD version 2.0.
2 | 
3 | In addition to basic functionality, we also compute additional statistics and
4 | plot precision-recall curves if an additional na_prob.json file is provided.
5 | This file is expected to map question IDs to the model's predicted probability
6 | that a question is unanswerable.
7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | OPTS = None 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 21 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 22 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 23 | parser.add_argument('--out-file', '-o', metavar='eval.json', 24 | help='Write accuracy metrics to file (default is stdout).') 25 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 26 | help='Model estimates of probability of no answer.') 27 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 28 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 29 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 30 | help='Save precision-recall curves to directory.') 31 | parser.add_argument('--verbose', '-v', action='store_true') 32 | if len(sys.argv) == 1: 33 | parser.print_help() 34 | sys.exit(1) 35 | return parser.parse_args() 36 | 37 | def make_qid_to_has_ans(dataset): 38 | qid_to_has_ans = {} 39 | for article in dataset: 40 | for p in article['paragraphs']: 41 | for qa in p['qas']: 42 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 43 | return qid_to_has_ans 44 | 45 | def normalize_answer(s): 46 | """Lower text and remove punctuation, articles and extra whitespace.""" 47 | def remove_articles(text): 48 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 49 | return re.sub(regex, ' ', text) 50 | def white_space_fix(text): 51 | return ' '.join(text.split()) 52 | def remove_punc(text): 53 | exclude = set(string.punctuation) 54 | return ''.join(ch for ch in text if ch not in exclude) 55 | def lower(text): 56 | return text.lower() 57 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 58 | 59 | def get_tokens(s): 60 | if not s: return [] 61 | return normalize_answer(s).split() 62 | 63 | def compute_exact(a_gold, a_pred): 64 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 65 | 66 | def compute_f1(a_gold, a_pred): 67 | gold_toks = get_tokens(a_gold) 68 | pred_toks = get_tokens(a_pred) 69 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 70 | num_same = sum(common.values()) 71 | if len(gold_toks) == 0 or len(pred_toks) == 0: 72 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 73 | return int(gold_toks == pred_toks) 74 | if num_same == 0: 75 | return 0 76 | precision = 1.0 * num_same / len(pred_toks) 77 | recall = 1.0 * num_same / len(gold_toks) 78 | f1 = (2 * precision * recall) / (precision + recall) 79 | return f1 80 | 81 | def get_raw_scores(dataset, preds): 82 | exact_scores = {} 83 | f1_scores = {} 84 | for article in dataset: 85 | for p in article['paragraphs']: 86 | for qa in p['qas']: 87 | qid = qa['id'] 88 | gold_answers = [a['text'] for a in qa['answers'] 89 | if normalize_answer(a['text'])] 90 | if not gold_answers: 91 | # For unanswerable questions, only correct answer is empty string 92 | gold_answers = [''] 93 | if qid not in preds: 94 | print('Missing prediction for %s' % qid) 95 | continue 96 | a_pred = preds[qid] 97 | # Take max over all gold answers 98 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 99 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 100 | return exact_scores, f1_scores 101 | 102 | def 
apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 103 | new_scores = {} 104 | for qid, s in scores.items(): 105 | pred_na = na_probs[qid] > na_prob_thresh 106 | if pred_na: 107 | new_scores[qid] = float(not qid_to_has_ans[qid]) 108 | else: 109 | new_scores[qid] = s 110 | return new_scores 111 | 112 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 113 | if not qid_list: 114 | total = len(exact_scores) 115 | return collections.OrderedDict([ 116 | ('exact', 100.0 * sum(exact_scores.values()) / total), 117 | ('f1', 100.0 * sum(f1_scores.values()) / total), 118 | ('total', total), 119 | ]) 120 | else: 121 | total = len(qid_list) 122 | return collections.OrderedDict([ 123 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 124 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 125 | ('total', total), 126 | ]) 127 | 128 | def merge_eval(main_eval, new_eval, prefix): 129 | for k in new_eval: 130 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 131 | 132 | def plot_pr_curve(precisions, recalls, out_image, title): 133 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 134 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 135 | plt.xlabel('Recall') 136 | plt.ylabel('Precision') 137 | plt.xlim([0.0, 1.05]) 138 | plt.ylim([0.0, 1.05]) 139 | plt.title(title) 140 | plt.savefig(out_image) 141 | plt.clf() 142 | 143 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, 144 | out_image=None, title=None): 145 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 146 | true_pos = 0.0 147 | cur_p = 1.0 148 | cur_r = 0.0 149 | precisions = [1.0] 150 | recalls = [0.0] 151 | avg_prec = 0.0 152 | for i, qid in enumerate(qid_list): 153 | if qid_to_has_ans[qid]: 154 | true_pos += scores[qid] 155 | cur_p = true_pos / float(i+1) 156 | cur_r = true_pos / float(num_true_pos) 157 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]: 158 | # i.e., if we can put a threshold after this point 159 | avg_prec += cur_p * (cur_r - recalls[-1]) 160 | precisions.append(cur_p) 161 | recalls.append(cur_r) 162 | if out_image: 163 | plot_pr_curve(precisions, recalls, out_image, title) 164 | return {'ap': 100.0 * avg_prec} 165 | 166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 167 | qid_to_has_ans, out_image_dir): 168 | if out_image_dir and not os.path.exists(out_image_dir): 169 | os.makedirs(out_image_dir) 170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 171 | if num_true_pos == 0: 172 | return 173 | pr_exact = make_precision_recall_eval( 174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 176 | title='Precision-Recall curve for Exact Match score') 177 | pr_f1 = make_precision_recall_eval( 178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 180 | title='Precision-Recall curve for F1 score') 181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 182 | pr_oracle = make_precision_recall_eval( 183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)') 186 | merge_eval(main_eval, pr_exact, 'pr_exact') 187 | merge_eval(main_eval, pr_f1, 'pr_f1') 188 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 189 | 190 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 191 | if not qid_list: 192 | return 193 | x = [na_probs[k] for k in qid_list] 194 | weights = np.ones_like(x) / float(len(x)) 195 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 196 | plt.xlabel('Model probability of no-answer') 197 | plt.ylabel('Proportion of dataset') 198 | plt.title('Histogram of no-answer probability: %s' % name) 199 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 200 | plt.clf() 201 | 202 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 203 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 204 | cur_score = num_no_ans 205 | best_score = cur_score 206 | best_thresh = 0.0 207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 208 | for i, qid in enumerate(qid_list): 209 | if qid not in scores: continue 210 | if qid_to_has_ans[qid]: 211 | diff = scores[qid] 212 | else: 213 | if preds[qid]: 214 | diff = -1 215 | else: 216 | diff = 0 217 | cur_score += diff 218 | if cur_score > best_score: 219 | best_score = cur_score 220 | best_thresh = na_probs[qid] 221 | return 100.0 * best_score / len(scores), best_thresh 222 | 223 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 224 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 225 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 226 | main_eval['best_exact'] = best_exact 227 | main_eval['best_exact_thresh'] = exact_thresh 228 | main_eval['best_f1'] = best_f1 229 | main_eval['best_f1_thresh'] = f1_thresh 230 | 231 | def main(): 232 | with open(OPTS.data_file) as f: 233 | dataset_json = json.load(f) 234 | dataset = dataset_json['data'] 235 | with open(OPTS.pred_file) as f: 236 | preds = json.load(f) 237 | if OPTS.na_prob_file: 238 | with open(OPTS.na_prob_file) as f: 239 | na_probs = json.load(f) 240 | else: 241 | na_probs = {k: 0.0 for k in preds} 242 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 243 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 244 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 245 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 246 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 247 | OPTS.na_prob_thresh) 248 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 249 | OPTS.na_prob_thresh) 250 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 251 | if has_ans_qids: 252 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 253 | merge_eval(out_eval, has_ans_eval, 'HasAns') 254 | if no_ans_qids: 255 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 256 | merge_eval(out_eval, no_ans_eval, 'NoAns') 257 | if OPTS.na_prob_file: 258 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 259 | if OPTS.na_prob_file and OPTS.out_image_dir: 260 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 261 | qid_to_has_ans, OPTS.out_image_dir) 262 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 263 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 264 | if OPTS.out_file: 265 | with open(OPTS.out_file, 'w') as f: 266 | json.dump(out_eval, f) 267 | else: 268 | 
print(json.dumps(out_eval, indent=2)) 269 | 270 | if __name__ == '__main__': 271 | OPTS = parse_args() 272 | if OPTS.out_image_dir: 273 | import matplotlib 274 | matplotlib.use('Agg') 275 | import matplotlib.pyplot as plt 276 | main() 277 | 278 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!) The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
--------------------------------------------------------------------------------
/run_squad.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import sys
6 | sys.path.append('xlnet')  # workaround for absolute imports inside the xlnet submodule
7 | 
8 | import collections
9 | import os
10 | import os.path
11 | import json
12 | import pickle
13 | import time
14 | 
15 | import tensorflow as tf
16 | import numpy as np
17 | import sentencepiece as sp
18 | 
19 | from xlnet import xlnet
20 | import function_builder
21 | import prepro_utils
22 | import model_utils
23 | 
24 | MAX_FLOAT = 1e30
25 | MIN_FLOAT = -1e30
26 | 
27 | flags = tf.flags
28 | FLAGS = flags.FLAGS
29 | 
30 | flags.DEFINE_string("data_dir", None, "Data directory where the raw data is located.")
31 | flags.DEFINE_string("output_dir", None, "Output directory where the processed data is located.")
32 | flags.DEFINE_string("model_dir", None, "Model directory where checkpoints are located.")
33 | flags.DEFINE_string("export_dir", None, "Export directory where the saved model is located.")
34 | 
35 | flags.DEFINE_string("task_name", default=None, help="The name of the task to train.")
36 | flags.DEFINE_string("model_config_path", default=None, help="Config file of the pre-trained model.")
37 | flags.DEFINE_string("init_checkpoint", default=None, help="Initial checkpoint of the pre-trained model.")
38 | flags.DEFINE_string("spiece_model_file", default=None, help="SentencePiece model path.")
39 | flags.DEFINE_bool("overwrite_data", default=False, help="If False, will use cached data if available.")
40 | flags.DEFINE_integer("random_seed", default=100, help="Random seed for weight initialization.")
41 | flags.DEFINE_string("predict_tag", None, "Tag for tracking prediction results.")
42 | 
43 | flags.DEFINE_bool("do_train", default=False, help="Whether to run training.")
44 | flags.DEFINE_bool("do_predict", default=False, help="Whether to run prediction.")
45 | flags.DEFINE_bool("do_export", default=False, help="Whether to run exporting.")
46 | 
47 | flags.DEFINE_enum("init", default="normal", enum_values=["normal", "uniform"], help="Initialization method.")
48 | flags.DEFINE_float("init_std", default=0.02, help="Initialization std when init is normal.")
49 | flags.DEFINE_float("init_range", default=0.1, help="Initialization range when init is uniform.")
50 | flags.DEFINE_bool("init_global_vars", default=False, help="If true, init all global vars. If false, init trainable vars only.")
51 | 
52 | flags.DEFINE_bool("lower_case", default=False, help="Whether to enable lower case.")
53 | flags.DEFINE_integer("doc_stride", default=128, help="Doc stride.")
54 | flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length.")
55 | flags.DEFINE_integer("max_query_length", default=64, help="Max query length.")
56 | flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.")
57 | flags.DEFINE_integer("train_batch_size", default=48, help="Total batch size for training.")
58 | flags.DEFINE_integer("predict_batch_size", default=32, help="Total batch size for prediction.")
59 | 
60 | flags.DEFINE_integer("train_steps", default=8000, help="Number of training steps.")
61 | flags.DEFINE_integer("warmup_steps", default=0, help="Number of warmup steps.")
62 | flags.DEFINE_integer("max_save", default=5, help="Max number of checkpoints to save. Use 0 to save all.")
63 | flags.DEFINE_integer("save_steps", default=1000, help="Save the model every save_steps. If None, no model will be saved.")
64 | flags.DEFINE_integer("shuffle_buffer", default=2048, help="Buffer size used for shuffle.")
65 | 
66 | flags.DEFINE_integer("n_best_size", default=5, help="N-best size for predictions.")
67 | flags.DEFINE_integer("start_n_top", default=5, help="Beam size for span start.")
68 | flags.DEFINE_integer("end_n_top", default=5, help="Beam size for span end.")
69 | flags.DEFINE_string("target_eval_key", default="best_f1", help="Use has_ans_f1 for Model I.")
70 | 
71 | flags.DEFINE_bool("use_bfloat16", default=False, help="Whether to use bfloat16.")
72 | flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.")
73 | flags.DEFINE_float("dropatt", default=0.1, help="Attention dropout rate.")
74 | flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length.")
75 | flags.DEFINE_string("summary_type", default="last", help="Method used to summarize a sequence into a vector.")
76 | 
77 | flags.DEFINE_float("learning_rate", default=3e-5, help="Initial learning rate.")
78 | flags.DEFINE_float("min_lr_ratio", default=0.0, help="Min lr ratio for cos decay.")
79 | flags.DEFINE_float("lr_layer_decay_rate", default=0.75, help="lr[L] = learning_rate, lr[l-1] = lr[l] * lr_layer_decay_rate.")
80 | flags.DEFINE_float("clip", default=1.0, help="Gradient clipping.")
81 | flags.DEFINE_float("weight_decay", default=0.00, help="Weight decay rate.")
82 | flags.DEFINE_float("adam_epsilon", default=1e-6, help="Adam epsilon.")
83 | flags.DEFINE_string("decay_method", default="poly", help="poly or cos")
84 | 
85 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
86 | flags.DEFINE_integer("num_hosts", 1, "How many TPU hosts.")
87 | flags.DEFINE_integer("num_core_per_host", 1, "Total number of TPU cores to use.")
88 | flags.DEFINE_string("tpu_job_name", None, "TPU worker job name.")
89 | flags.DEFINE_string("tpu", None, "The Cloud TPU name to use for training.")
90 | flags.DEFINE_string("tpu_zone", None, "GCE zone where the Cloud TPU is located.")
91 | flags.DEFINE_string("gcp_project", None, "Project name for the Cloud TPU-enabled project.")
92 | flags.DEFINE_string("master", None, "TensorFlow master URL.")
93 | flags.DEFINE_integer("iterations", 1000, "Number of iterations per TPU training loop.")
94 | 
95 | class InputExample(object):
96 |     """A single SQuAD example."""
97 |     def __init__(self,
98 |                  qas_id,
99 |                  question_text,
100 |                  paragraph_text,
101 |                  orig_answer_text=None,
102 |                  start_position=None,
103 |                  is_impossible=False):
104 |         self.qas_id = qas_id
105 |         self.question_text = question_text
106 |         self.paragraph_text = paragraph_text
107 |         self.orig_answer_text = orig_answer_text
108 |         self.start_position = start_position
109 |         self.is_impossible = is_impossible
110 | 
111 |     def __str__(self):
112 |         return self.__repr__()
113 | 
114 |     def __repr__(self):
115 |         s = "qas_id: %s" % (prepro_utils.printable_text(self.qas_id))
116 |         s += ", question_text: %s" % (prepro_utils.printable_text(self.question_text))
117 |         s += ", paragraph_text: [%s]" % (self.paragraph_text)
118 |         if self.start_position:
119 |             s += ", start_position: %d" % (self.start_position)
120 |         s += ", is_impossible: %r" % (self.is_impossible)
121 |         return s
122 | 
123 | class InputFeatures(object):
124 |     """A single SQuAD feature."""
125 |     def __init__(self,
126 |                  unique_id,
127 |                  qas_id,
128 |                  doc_idx,
129 |                  token2char_raw_start_index,
130 |                  token2char_raw_end_index,
131 |                  token2doc_index,
132 |                  input_ids,
133 |                  input_mask,
134 |                  p_mask,
135 |                  segment_ids,
136 |                  cls_index,
137 |
para_length, 138 | start_position=None, 139 | end_position=None, 140 | is_impossible=None): 141 | self.unique_id = unique_id 142 | self.qas_id = qas_id 143 | self.doc_idx = doc_idx 144 | self.token2char_raw_start_index = token2char_raw_start_index 145 | self.token2char_raw_end_index = token2char_raw_end_index 146 | self.token2doc_index = token2doc_index 147 | self.input_ids = input_ids 148 | self.input_mask = input_mask 149 | self.p_mask = p_mask 150 | self.segment_ids = segment_ids 151 | self.cls_index = cls_index 152 | self.para_length = para_length 153 | self.start_position = start_position 154 | self.end_position = end_position 155 | self.is_impossible = is_impossible 156 | 157 | class OutputResult(object): 158 | """A single SQuAD result.""" 159 | def __init__(self, 160 | unique_id, 161 | answer_prob, 162 | start_prob, 163 | start_index, 164 | end_prob, 165 | end_index): 166 | self.unique_id = unique_id 167 | self.answer_prob = answer_prob 168 | self.start_prob = start_prob 169 | self.start_index = start_index 170 | self.end_prob = end_prob 171 | self.end_index = end_index 172 | 173 | class SquadPipeline(object): 174 | """Pipeline for SQuAD dataset.""" 175 | def __init__(self, 176 | data_dir, 177 | task_name): 178 | self.data_dir = data_dir 179 | self.task_name = task_name 180 | 181 | def get_train_examples(self): 182 | """Gets a collection of `InputExample`s for the train set.""" 183 | data_path = os.path.join(self.data_dir, "train-{0}.json".format(self.task_name)) 184 | data_list = self._read_json(data_path) 185 | example_list = self._get_example(data_list, True) 186 | return example_list 187 | 188 | def get_dev_examples(self): 189 | """Gets a collection of `InputExample`s for the dev set.""" 190 | data_path = os.path.join(self.data_dir, "dev-{0}.json".format(self.task_name)) 191 | data_list = self._read_json(data_path) 192 | example_list = self._get_example(data_list, False) 193 | return example_list 194 | 195 | def _read_json(self, 196 | data_path): 197 | if os.path.exists(data_path): 198 | with open(data_path, "r") as file: 199 | data_list = json.load(file)["data"] 200 | return data_list 201 | else: 202 | raise FileNotFoundError("data path not found: {0}".format(data_path)) 203 | 204 | def _get_example(self, 205 | data_list, 206 | is_training): 207 | examples = [] 208 | for entry in data_list: 209 | for paragraph in entry["paragraphs"]: 210 | paragraph_text = paragraph["context"] 211 | 212 | for qa in paragraph["qas"]: 213 | qas_id = qa["id"] 214 | question_text = qa["question"] 215 | start_position = None 216 | orig_answer_text = None 217 | is_impossible = False 218 | 219 | if is_training: 220 | is_impossible = qa["is_impossible"] 221 | if (len(qa["answers"]) != 1) and (not is_impossible): 222 | raise ValueError("For training, each question should have exactly 1 answer.") 223 | 224 | if not is_impossible: 225 | answer = qa["answers"][0] 226 | orig_answer_text = answer["text"] 227 | start_position = answer["answer_start"] 228 | else: 229 | start_position = -1 230 | orig_answer_text = "" 231 | 232 | example = InputExample( 233 | qas_id=qas_id, 234 | question_text=question_text, 235 | paragraph_text=paragraph_text, 236 | orig_answer_text=orig_answer_text, 237 | start_position=start_position, 238 | is_impossible=is_impossible) 239 | 240 | examples.append(example) 241 | 242 | return examples 243 | 244 | class XLNetTokenizer(object): 245 | """Default text tokenizer for XLNet""" 246 | def __init__(self, 247 | sp_model_file, 248 | lower_case=False): 249 | """Construct XLNet tokenizer""" 
250 |         self.sp_processor = sp.SentencePieceProcessor()
251 |         self.sp_processor.Load(sp_model_file)
252 |         self.lower_case = lower_case
253 | 
254 |     def tokenize(self,
255 |                  text):
256 |         """Tokenize text for XLNet"""
257 |         processed_text = prepro_utils.preprocess_text(text, lower=self.lower_case)
258 |         tokenized_pieces = prepro_utils.encode_pieces(self.sp_processor, processed_text, return_unicode=False)
259 |         return tokenized_pieces
260 | 
261 |     def encode(self,
262 |                text):
263 |         """Encode text for XLNet"""
264 |         processed_text = prepro_utils.preprocess_text(text, lower=self.lower_case)
265 |         encoded_ids = prepro_utils.encode_ids(self.sp_processor, processed_text)
266 |         return encoded_ids
267 | 
268 |     def token_to_id(self,
269 |                     token):
270 |         """Convert token to id for XLNet"""
271 |         return self.sp_processor.PieceToId(token)
272 | 
273 |     def id_to_token(self,
274 |                     id):
275 |         """Convert id to token for XLNet"""
276 |         return self.sp_processor.IdToPiece(id)
277 | 
278 |     def tokens_to_ids(self,
279 |                       tokens):
280 |         """Convert tokens to ids for XLNet"""
281 |         return [self.sp_processor.PieceToId(token) for token in tokens]
282 | 
283 |     def ids_to_tokens(self,
284 |                       ids):
285 |         """Convert ids to tokens for XLNet"""
286 |         return [self.sp_processor.IdToPiece(id) for id in ids]
287 | 
288 | class XLNetExampleProcessor(object):
289 |     """Default example processor for XLNet"""
290 |     def __init__(self,
291 |                  max_seq_length,
292 |                  max_query_length,
293 |                  doc_stride,
294 |                  tokenizer):
295 |         """Construct XLNet example processor"""
296 |         self.special_vocab_list = ["<unk>", "<s>", "</s>", "<cls>", "<sep>", "<pad>", "<mask>", "<eod>", "<eop>"]
580 |                 segment_ids.append(self.segment_vocab_map["<p>"])
581 |                 p_mask.append(0)
582 | 
583 |             doc_para_length = len(input_tokens)
584 | 
585 |             input_tokens.append("<sep>")
586 |             segment_ids.append(self.segment_vocab_map["<p>"])
587 |             p_mask.append(1)
588 | 
589 |             # We put P before Q because during pretraining, B is always shorter than A
590 |             for query_token in query_tokens:
591 |                 input_tokens.append(query_token)
592 |                 segment_ids.append(self.segment_vocab_map["<q>"])
593 |                 p_mask.append(1)
594 | 
595 |             input_tokens.append("<sep>")
596 |             segment_ids.append(self.segment_vocab_map["<q>"])
597 |             p_mask.append(1)
598 | 
599 |             cls_index = len(input_tokens)
600 | 
601 |             input_tokens.append("<cls>")