├── .gitignore ├── README.md └── uqa ├── cloze2natural.py ├── data_utils.py ├── docker └── Dockerfile ├── evaluate.py ├── generate_new_qadata.py ├── multi_turn.py ├── run_squad.py ├── scripts ├── gen_refqa.sh ├── install_tools.sh ├── run_main.sh └── run_refine.sh ├── utils_squad.py ├── utils_squad_evaluate.py └── wikiref_process.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | 3 | 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
Harvesting and Refining Question-Answer Pairs for Unsupervised QA
3 | 4 |
5 | This repo contains the data, code, and models for the ACL 2020 paper ["Harvesting and Refining Question-Answer Pairs for Unsupervised QA"](https://arxiv.org/abs/2005.02925).
6 |
7 | In this work, we introduce two approaches to improve unsupervised QA. First, we harvest lexically and syntactically divergent questions from Wikipedia to automatically construct a corpus of question-answer pairs (named RefQA). Second, we take advantage of the QA model to extract more appropriate answers, which iteratively refines the RefQA data. We conduct experiments on SQuAD 1.1 and NewsQA by fine-tuning BERT without access to manually annotated data. Our approach outperforms previous unsupervised approaches by a large margin and is competitive with early supervised models.
8 |
9 | ## Environment
10 |
11 | ### With Docker
12 |
13 | The recommended way to run the code is with Docker under Linux. The Dockerfile is in `uqa/docker/Dockerfile`.
14 |
15 | ### With Pip
16 |
17 | First, you need to install PyTorch 1.1.0. Please refer to the [PyTorch installation page](https://pytorch.org/get-started/locally/#start-locally).
18 |
19 | Then, you can clone this repo and install the dependencies via `uqa/scripts/install_tools.sh`:
20 |
21 | ```bash
22 | git clone -q https://github.com/NVIDIA/apex.git
23 | cd apex ; git reset --hard 1603407bf49c7fc3da74fceb6a6c7b47fece2ef8
24 | python3 setup.py install --user --cuda_ext --cpp_ext
25 |
26 | pip install --user cython tensorboardX six numpy tqdm path.py pandas scikit-learn lmdb pyarrow py-lz4framed methodtools py-rouge pyrouge nltk
27 | python3 -c "import nltk; nltk.download('punkt')"
28 | pip install -e git://github.com/Maluuba/nlg-eval.git#egg=nlg-eval
29 |
30 | pip install --user spacy==2.2.0 pytorch-transformers==1.2.0 tensorflow-gpu==1.13.1
31 | python3 -m spacy download en
32 | pip install --user benepar[gpu]
33 | ```
34 |
35 | The mixed-precision training code requires this specific version of [NVIDIA/apex](https://github.com/NVIDIA/apex/tree/1603407bf49c7fc3da74fceb6a6c7b47fece2ef8), which only supports PyTorch < 1.2.0.
36 |
37 | ## Data and Models
38 |
39 | The format of our generated data is SQuAD-like. The data can be downloaded from [here](https://drive.google.com/open?id=18o8EjlCcimvuF0HYe8sHSu6epTqDwvp_).
40 |
41 | Links to the trained models:
42 | - [refqa-main](https://drive.google.com/open?id=1r2jgFSGtXBRTAeFzGzAwQ_BG4_Bi8v7f): the model trained on 300k RefQA examples;
43 | - [refqa-refine](https://drive.google.com/open?id=1wiAV7sYQFhXVNCuVK8kk9S114_z7Rjwc): the model trained with our refining process.
44 |
45 | ## Constructing RefQA
46 |
47 | In our released data, the `wikiref.json` file (our raw data) contains the Wikipedia statements and the corresponding cited documents (the `summary` and `document` keys of each item).
48 |
49 | You can convert the raw data into RefQA with the following commands:
50 |
51 | ```bash
52 | export REFQA_DATA_DIR=/{path_to_refqa_data}/
53 |
54 | python3 wikiref_process.py \
55 |     --input_file wikiref.json \
56 |     --output_file cloze_clause_wikiref_data.json
57 | python3 cloze2natural.py \
58 |     --input_file cloze_clause_wikiref_data.json \
59 |     --output_file refqa.json
60 | ```
61 |
62 | Note: Please make sure that the file `wikiref.json` is in the directory `$REFQA_DATA_DIR`.
63 |
64 | Then, for the subsequent refining process, you should split the generated data into several parts: a main set to train the initial QA model and the remaining parts for the refining process.
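For reference, `data_utils.py` already ships a `split_all_data` helper (its `__main__` splits `uqa_all_data.json` into a 300k main set and six 100k refine sets). Below is a minimal sketch of splitting your generated file the same way, assuming it is named `refqa.json` and lives in `$REFQA_DATA_DIR`; the output file names and sizes mirror the ones used by the training and refining steps later in this README.

```python
# Sketch (run from the uqa/ directory): split generated RefQA data into the
# files expected by the training/refining commands below. The input name
# refqa.json and the 300k/100k split sizes are illustrative defaults taken
# from this repo's data_utils.py __main__.
import os
from data_utils import split_all_data

data_dir = os.getenv("REFQA_DATA_DIR", "./data")
split_all_data(
    data_dir,
    "refqa.json",
    ["uqa_train_main.json"] + ["uqa_train_refine_%d.json" % i for i in range(6)],
    [300000] + [100000] * 6,
)
```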
65 |
66 | ## Training and Refining
67 |
68 | Before running on RefQA, you should download/move the [data](#data-and-models) and the SQuAD 1.1 dev file `dev-v1.1.json` to the directory `$REFQA_DATA_DIR`.
69 |
70 | We train our QA model using distributed and mixed-precision training on 4 P100 GPUs.
71 |
72 | ### Training the initial QA model
73 |
74 | You can fine-tune BERT-Large (WWM) on 300k RefQA examples and achieve an F1 > 65 on the SQuAD 1.1 dev set.
75 |
76 | ```bash
77 | export REFQA_DATA_DIR=/{path_to_refqa_data}/
78 | export OUTPUT_DIR=/{path_to_main_output}/
79 | export CUDA_VISIBLE_DEVICES=0,1,2,3
80 |
81 | python3 -m torch.distributed.launch --nproc_per_node=4 run_squad.py \
82 |     --model_type bert \
83 |     --model_name_or_path bert-large-uncased-whole-word-masking \
84 |     --do_train \
85 |     --do_eval \
86 |     --do_lower_case \
87 |     --train_file $REFQA_DATA_DIR/uqa_train_main.json \
88 |     --predict_file $REFQA_DATA_DIR/dev-v1.1.json \
89 |     --learning_rate 3e-5 \
90 |     --num_train_epochs 2 \
91 |     --max_seq_length 384 \
92 |     --doc_stride 128 \
93 |     --output_dir $OUTPUT_DIR \
94 |     --per_gpu_train_batch_size=6 \
95 |     --per_gpu_eval_batch_size=4 \
96 |     --seed 42 \
97 |     --fp16 \
98 |     --overwrite_output_dir \
99 |     --logging_steps 1000 \
100 |     --save_steps 1000
101 | ```
102 |
103 | ### Refining RefQA data iteratively
104 |
105 | We provide a fine-tuned checkpoint (downloadable from [here](https://drive.google.com/open?id=1r2jgFSGtXBRTAeFzGzAwQ_BG4_Bi8v7f)) used for the refining process. The refining process is conducted as follows:
106 |
107 | ```bash
108 | export REFQA_DATA_DIR=/{path_to_refqa_data}/
109 | export MAIN_MODEL_DIR=/{path_to_previous_fine-tuned_model}/
110 | export OUTPUT_DIR=/{path_to_refine_output}/
111 | export CUDA_VISIBLE_DEVICES=0,1,2,3
112 |
113 | python3 multi_turn.py \
114 |     --refine_data_dir $REFQA_DATA_DIR \
115 |     --output_dir $OUTPUT_DIR \
116 |     --model_dir $MAIN_MODEL_DIR \
117 |     --predict_file $REFQA_DATA_DIR/dev-v1.1.json \
118 |     --generate_method 2 \
119 |     --score_threshold 0.15 \
120 |     --threshold_rate 0.9 \
121 |     --seed 17 \
122 |     --fp16
123 | ```
124 |
125 | `multi_turn.py` provides the following command-line arguments:
126 |
127 | ```
128 | required arguments:
129 |   --refine_data_dir   The directory of RefQA data for refining
130 |   --model_dir         The directory of the initial checkpoint
131 |   --output_dir        The output directory
132 |   --predict_file      SQuAD or other json for predictions.
E.g., dev-v1.1.json 133 | 134 | optional arguments: 135 | --generate_method {1|2} The method of generating data for next training, 136 | 1 is using refined data only, 2 is merging refined data with filtered data (1:1 ratio) 137 | --score_threshold The threshold for filtering predicted answers 138 | --threshold_rate The decay factor for the above threshold 139 | --seed Random seed for initialization 140 | --fp16 Whether to use 16-bit (mixed) precision (through NVIDIA apex) 141 | ``` 142 | 143 | 144 | ## Citation 145 | If you find this repo useful in your research, you can cite the following paper: 146 | ``` 147 | @inproceedings{li2020refqa, 148 | title = "Harvesting and Refining Question-Answer Pairs for Unsupervised {QA}", 149 | author = "Li, Zhongli and 150 | Wang, Wenhui and 151 | Dong, Li and 152 | Wei, Furu and 153 | Xu, Ke", 154 | booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics", 155 | month = jul, 156 | year = "2020", 157 | address = "Online", 158 | publisher = "Association for Computational Linguistics", 159 | url = "https://www.aclweb.org/anthology/2020.acl-main.600", 160 | doi = "10.18653/v1/2020.acl-main.600", 161 | pages = "6719--6728" 162 | } 163 | ``` 164 | 165 | ## Acknowledgment 166 | 167 | Our code is based on [pytorch-transformers 1.2.0](https://github.com/huggingface/transformers/tree/1.2.0). We thank the authors for their wonderful open-source efforts. 168 | 169 | -------------------------------------------------------------------------------- /uqa/cloze2natural.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys, os, json, time 3 | from tqdm import tqdm 4 | import random 5 | from collections import Counter 6 | from nltk import word_tokenize 7 | import numpy as np 8 | from data_utils import reformulate_quesiton 9 | import spacy 10 | import copy 11 | 12 | mask2wh = { 13 | 'PERSONNORPORG' : 'Who', 14 | 'PLACE' : 'Where', 15 | 'THING' : 'What', 16 | 'TEMPORAL': 'When', 17 | 'NUMERIC' : ['How many','How much'] 18 | } 19 | entity_category = { 20 | 'PERSONNORPORG' : "PERSON, NORP, ORG".replace(' ','').split(','), 21 | 'PLACE' : "GPE, LOC, FAC".replace(' ','').split(','), 22 | 'THING' : 'PRODUCT, EVENT, WORK_OF_ART, LAW, LANGUAGE'.replace(' ','').split(','), 23 | 'TEMPORAL': 'TIME, DATE'.replace(' ','').split(','), 24 | 'NUMERIC' : 'PERCENT, MONEY, QUANTITY, ORDINAL, CARDINAL'.replace(' ','').split(',') 25 | } 26 | entity_type_map = {} 27 | for cate in entity_category: 28 | for item in entity_category[cate]: 29 | entity_type_map[item] = cate 30 | data_dir = os.getenv("REFQA_DATA_DIR", "./data") 31 | 32 | def identity_translate(cloze_question): 33 | if 'NUMERIC' in cloze_question: 34 | return cloze_question.replace('NUMERIC', mask2wh['NUMERIC'][int(2*random.random())]) 35 | else: 36 | for mask in mask2wh: 37 | if mask in cloze_question: 38 | return cloze_question.replace(mask, mask2wh[mask]) 39 | raise Exception('\'{}\' should have one specific masked tag.'.format(cloze_question)) 40 | 41 | def word_shuffle(tokens, word_shuffle_param): 42 | length = len(tokens) 43 | noise = np.random.uniform(0, word_shuffle_param, size=(length ) ) 44 | word_idx = np.array([1.0*i for i in range(length)]) 45 | 46 | scores = word_idx + noise 47 | scores += 1e-6 * np.arange(length) 48 | permutation = scores.argsort() 49 | new_s = [ tokens[idx] for idx in permutation ] 50 | return new_s 51 | 52 | def word_dropout(tokens, word_dropout_param): 53 | length = len(tokens) 54 | 
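# word_dropout: drop each token independently with probability word_dropout_param;
# together with word_shuffle and word_mask it forms the noisy-cloze translation (method 1).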
if word_dropout_param == 0: 55 | return tokens 56 | assert 0 < word_dropout_param < 1 57 | 58 | keep = np.random.rand(length) >= word_dropout_param 59 | #if length: 60 | # keep[0] = 1 61 | new_s = [ w for j, w in enumerate(tokens) if keep[j] ] 62 | return new_s 63 | 64 | def word_mask(tokens, word_mask_param, mask_str='[MASK]'): 65 | length = len(tokens) 66 | if word_mask_param == 0: 67 | return tokens 68 | assert 0 < word_mask_param < 1 69 | 70 | keep = np.random.rand(length) >= word_mask_param 71 | #if length: 72 | # keep[0] = 1 73 | new_s = [ w if keep[j] else mask_str for j, w in enumerate(tokens)] 74 | return new_s 75 | 76 | def noisy_clozes_translate(cloze_question, params=[2, 0.2, 0.1]): 77 | wh = None 78 | for mask in mask2wh: 79 | if mask in cloze_question: 80 | cloze_question = cloze_question.replace(mask,'') 81 | wh = mask2wh[mask] 82 | break 83 | if isinstance(wh , list): 84 | wh = wh[int(2*random.random())] 85 | 86 | tokens = word_tokenize(cloze_question) 87 | tokens = word_shuffle(tokens, params[0]) 88 | tokens = word_dropout(tokens, params[1]) 89 | tokens = word_mask(tokens, params[2]) 90 | return wh+' '+(' '.join(tokens)) 91 | 92 | def cloze_to_natural_questions(input_data, method): 93 | 94 | natural_data = [] 95 | q_count = 0 96 | 97 | parser = spacy.load("en", disable=['ner', 'tagger']) 98 | 99 | for entry in tqdm(input_data, desc="cloze"): 100 | parags = [] 101 | for paragraph in entry['paragraphs']: 102 | qas = [] 103 | for qa in paragraph['qas']: 104 | qa['question'] = qa['question'].replace('PERSON/NORP/ORG', 'PERSONNORPORG') 105 | try: 106 | if method == 0: 107 | qa['question'] = identity_translate(qa['question']) 108 | elif method == 1: 109 | qa['question'] = noisy_clozes_translate(qa['question']) 110 | elif method == 2: 111 | qa['question'] = identity_translate(reformulate_quesiton(qa['question'], parser, reform_version=1) ) 112 | else: 113 | raise NotImplementedError() 114 | except Exception as e: 115 | print(qa['question']) 116 | print(repr(e)) 117 | continue 118 | 119 | qas.append(qa) 120 | paragraph["qas"] = qas 121 | parags.append(paragraph) 122 | q_count += len(qas) 123 | entry["paragraphs"] = parags 124 | natural_data.append(entry) 125 | #if q_count > 10: 126 | # break 127 | 128 | print('Questions Number', q_count) 129 | return {"version": "v2.0", 'data': natural_data} 130 | 131 | 132 | def filter_data_given_qids(input_data_, qids): 133 | input_data = copy.deepcopy(input_data_) 134 | qids = sorted(qids, key=lambda x: int(x.strip().split('_')[-1])) 135 | q_count = 0 136 | new_data = [] 137 | for entry in tqdm(input_data, desc='filter'): 138 | paras = [] 139 | for paragraph in entry['paragraphs']: 140 | qas = [] 141 | for qa in paragraph['qas']: 142 | if q_count < len(qids) and qa['id'] == qids[q_count]: 143 | qas.append(qa) 144 | q_count += 1 145 | if len(qas) == 0: 146 | continue 147 | paragraph['qas'] = qas 148 | paras.append(paragraph) 149 | if len(paras) == 0: 150 | continue 151 | entry['paragraphs'] = paras 152 | new_data.append(entry) 153 | return new_data 154 | 155 | def main(input_file, output_file, method): 156 | input_file = os.path.join(data_dir, input_file) 157 | with open(input_file, "r", encoding='utf-8') as reader: 158 | input_data = json.load(reader)["data"] 159 | natural_data = cloze_to_natural_questions(input_data, method) 160 | json.dump(natural_data, open(os.path.join(data_dir, output_file) , 'w', encoding='utf-8'), indent=4) 161 | 162 | 163 | if __name__ == '__main__': 164 | import argparse 165 | parser = argparse.ArgumentParser() 166 | 
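# --method selects the cloze-to-natural translation used in cloze_to_natural_questions:
# 0 = identity wh-substitution, 1 = noisy clozes, 2 = dependency reformulation + identity (default).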
parser.add_argument("--input_file", default=None, type=str) 167 | parser.add_argument("--output_file", default=None, type=str) 168 | parser.add_argument("--method", default=2, type=int) 169 | args = parser.parse_args() 170 | main(args.input_file, args.output_file, args.method) -------------------------------------------------------------------------------- /uqa/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys, os, json, time 3 | import numpy as np 4 | from tqdm import tqdm 5 | import random 6 | import logging 7 | from collections import Counter 8 | import spacy 9 | from spacy.tokens import Token 10 | import copy 11 | 12 | logger = logging.getLogger(__name__) 13 | ANSWER_TYPE = ['PERSONNORPORG', 'PLACE', 'THING', 'TEMPORAL', 'NUMERIC'] 14 | 15 | Token.set_extension('lefts', default=[]) 16 | Token.set_extension('rights', default=[]) 17 | Token.set_extension('relative_position', default=0) 18 | 19 | 20 | def tokenize(text): 21 | def is_whitespace(c): 22 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 23 | return True 24 | return False 25 | 26 | prev_is_whitespace = True 27 | doc_tokens = [] 28 | 29 | for c in text: 30 | if is_whitespace(c): 31 | prev_is_whitespace = True 32 | else: 33 | if prev_is_whitespace: 34 | doc_tokens.append(c) 35 | else: 36 | doc_tokens[-1] += c 37 | prev_is_whitespace = False 38 | return doc_tokens 39 | 40 | def parsing_tree_dfs(node): 41 | N = len(node._.lefts) + len(node._.rights) 42 | if N == 0: 43 | return node.text 44 | 45 | text = '' 46 | for child in node._.lefts: 47 | text += parsing_tree_dfs(child)+' ' 48 | text += node.text 49 | for child in node._.rights: 50 | text += ' '+parsing_tree_dfs(child) 51 | return text 52 | 53 | 54 | def reform_tree(node): 55 | #print(node.text, node.head, node.text in ANSWER_TYPE) 56 | if node.text in ANSWER_TYPE: 57 | node._.lefts = [] 58 | return True 59 | flag = False 60 | res = None 61 | for child in node._.lefts: 62 | flag |= reform_tree(child) 63 | if flag: 64 | node._.lefts.remove(child) 65 | node._.lefts = [child] + node._.lefts 66 | break 67 | if not flag: 68 | for child in node._.rights: 69 | flag |= reform_tree(child) 70 | if flag: 71 | node._.rights.remove(child) 72 | node._.lefts = [child] + node._.lefts 73 | break 74 | return flag 75 | 76 | 77 | # 对cloze进行变化,prepend answer related words 78 | def reformulate_quesiton(question, parser, reform_version=1): 79 | 80 | doc = parser(question) 81 | roots = [] 82 | for token in doc: 83 | token._.lefts = [child for child in token.lefts] 84 | token._.rights = [child for child in token.rights] 85 | if token.dep_ == 'ROOT': 86 | roots.append(token) 87 | #print(token.text, token.head.text, [child for child in token.children]) 88 | ### reformulate ### 89 | for root in roots: 90 | if reform_version == 1: 91 | result = reform_tree(root) 92 | else: 93 | result = False 94 | if result: 95 | roots.remove(root) 96 | roots = [root] + roots 97 | ### tree to seqence ### 98 | new_question = '' 99 | for root in roots: 100 | new_question += ' ' + parsing_tree_dfs(root) 101 | return new_question.strip() 102 | 103 | def reformulate_demo(): 104 | parser = spacy.load("en", disable=['ner','tagger']) 105 | #with open('/home/zhongli/projects/XLMRAW/data/mono/cl2/news.cl2', 'r', encoding='utf-8') as f: 106 | # questions = [ line.strip().replace('PERSON/NORP/ORG' ,'PERSONNORPORG') for line in f ] 107 | #f = open('/home/zhongli/projects/XLMRAW/data/mono/cl2r1/news.cl2r1', 'w', 
encoding='utf-8') 108 | questions = ['What Guillermo crashed a Matt Damon interview , about his upcoming movie THING'] 109 | qs = [] 110 | for qu in tqdm(questions[:10], desc='reform demo'): 111 | tokens = qu.split(' ') 112 | wh = tokens[0] 113 | q_text = ' '.join(tokens[1:]) 114 | print(q_text) 115 | q_text = reformulate_quesiton(q_text, parser, 1) 116 | print(q_text) 117 | print('----------------------') 118 | qu_new = wh + ' ' + q_text 119 | qs.append(qu_new) 120 | 121 | def data_check(input_file): 122 | with open(input_file, "r", encoding='utf-8') as reader: 123 | input_data = json.load(reader)["data"] 124 | q_count = 0 125 | err = 0 126 | for entry in input_data: 127 | for paragraph in entry['paragraphs']: 128 | context = paragraph['context'] 129 | for qa in paragraph['qas']: 130 | q_count += 1 131 | answer_text = qa['answers'][0]['text'] 132 | answer_start= qa['answers'][0]['answer_start'] 133 | if not context[answer_start:].startswith(answer_text): 134 | err += 1 135 | if err == 0: 136 | print(input_file, 'is correct.') 137 | else: 138 | print(input_file, 'has %d problems.'%err) 139 | print('Number of Question:', q_count) 140 | 141 | def data_sample_v2(input_file, sample_number, balance=False, output_file=None): 142 | with open(input_file, "r", encoding='utf-8') as reader: 143 | input_data = json.load(reader)["data"] 144 | sample_data = [] 145 | qids = [] 146 | q_count = 0 147 | 148 | if balance: 149 | whs = ['How', 'Who', 'When', 'What','Where'] 150 | wh_qids = {} 151 | for wh in whs: 152 | wh_qids[wh] = [] 153 | for entry in input_data: 154 | for paragraph in entry['paragraphs']: 155 | for qa in paragraph['qas']: 156 | q_tokens = qa['question'].split() 157 | qid = qa['id'] 158 | q_wh = None 159 | for wh in whs: 160 | if wh in q_tokens: 161 | q_wh = wh 162 | break 163 | if q_wh is not None: 164 | wh_qids[q_wh].append(qid) 165 | balance_number = int(sample_number / len(whs)) 166 | for wh in whs: 167 | if len(wh_qids[wh]) < balance_number: 168 | print(wh, 'quesitons not enough.') 169 | random.shuffle(wh_qids[wh]) 170 | #print(balance_number, len(balance_number)) 171 | qids += wh_qids[wh][:balance_number] 172 | 173 | else: 174 | for entry in input_data: 175 | for paragraph in entry['paragraphs']: 176 | for qa in paragraph['qas']: 177 | qids.append(qa['id']) 178 | random.shuffle(qids) 179 | qids = qids[:sample_number] 180 | 181 | qids = [ int(qid.split('_')[-1]) for qid in qids ] 182 | qids = sorted(qids) 183 | qids.append(-1) 184 | 185 | for entry in tqdm(input_data, desc="sample"): 186 | parags = [] 187 | for paragraph in entry['paragraphs']: 188 | qas = [] 189 | for qa in paragraph['qas']: 190 | qid = int(qa['id'].split('_')[-1]) 191 | if qid == qids[q_count]: 192 | qa['question'] = qa['question'].replace('PERSON/NORP/ORG', 'PERSONNORPORG') 193 | qas.append(qa) 194 | q_count += 1 195 | if len(qas) == 0: 196 | continue 197 | paragraph["qas"] = qas 198 | parags.append(paragraph) 199 | entry["paragraphs"] = parags 200 | sample_data.append(entry) 201 | 202 | print('Questions Number', q_count) 203 | if output_file is None: 204 | output_file = '/'.join(input_file.split('/')[:-1]) + '/' 205 | if balance: 206 | output_file += 'balanced_' 207 | output_file += 'sample_%dw-'%(int(sample_number/10000))+input_file.split('/')[-1] 208 | print('Saving to',output_file) 209 | json.dump({"version": "v2.0", 'data': sample_data}, open(output_file ,'w',encoding='utf-8')) 210 | 211 | def filter_data_given_qids(input_data_, qids, is_sorted=False): 212 | input_data = copy.deepcopy(input_data_) 213 | if not 
is_sorted: 214 | qids = sorted(qids, key=lambda x: int(x.strip().split('_')[-1])) 215 | q_count = 0 216 | new_data = [] 217 | for entry in tqdm(input_data, desc='filter'): 218 | paras = [] 219 | for paragraph in entry['paragraphs']: 220 | qas = [] 221 | for qa in paragraph['qas']: 222 | if q_count < len(qids) and qa['id'] == qids[q_count]: 223 | qas.append(qa) 224 | q_count += 1 225 | if len(qas) == 0: 226 | continue 227 | paragraph['qas'] = qas 228 | paras.append(paragraph) 229 | if len(paras) == 0: 230 | continue 231 | entry['paragraphs'] = paras 232 | new_data.append(entry) 233 | return new_data 234 | 235 | def data_split(input_file, data_size): 236 | with open(input_file, "r", encoding='utf-8') as reader: 237 | input_data = json.load(reader)["data"] 238 | sample_data = [] 239 | qids = [] 240 | q_count = 0 241 | for entry in input_data: 242 | for paragraph in entry['paragraphs']: 243 | for qa in paragraph['qas']: 244 | qids.append(qa['id']) 245 | random.shuffle(qids) 246 | 247 | num = 0 248 | while num*data_size < len(qids): 249 | nqids = qids[num*data_size: min(len(qids), (num+1)*data_size)] 250 | new_data = filter_data_given_qids(input_data, nqids) 251 | output_file = '/'.join(input_file.split('/')[:-1]) + '/' 252 | output_file += ('%d_'%num)+input_file.split('/')[-1] 253 | json.dump({"version": "v2.0", 'data': new_data}, open(output_file ,'w',encoding='utf-8')) 254 | print(output_file, len(nqids)) 255 | data_check(output_file) 256 | num += 1 257 | 258 | 259 | def data_concat(files, output_file): 260 | all_data = [] 261 | for input_file in files: 262 | with open(input_file, "r", encoding='utf-8') as reader: 263 | input_data = json.load(reader)["data"] 264 | all_data += input_data 265 | json.dump({"version": "v2.0", 'data': all_data}, open(output_file ,'w',encoding='utf-8')) 266 | 267 | def split_all_data(data_dir, input_file, output_files, output_sizes): 268 | input_file = os.path.join(data_dir, input_file) 269 | with open(input_file, "r", encoding='utf-8') as reader: 270 | input_data = json.load(reader)["data"] 271 | 272 | qids = [] 273 | for entry in input_data: 274 | for paragraph in entry['paragraphs']: 275 | for qa in paragraph['qas']: 276 | qids.append(qa['id']) 277 | 278 | q_pos = 0 279 | assert len(output_files) == len(output_sizes) 280 | 281 | for output_file, data_size in zip(output_files, output_sizes): 282 | output_file = os.path.join(data_dir, output_file) 283 | nqids = qids[q_pos: q_pos+data_size] 284 | q_pos += data_size 285 | new_data = filter_data_given_qids(input_data, nqids) 286 | json.dump({"version": "v2.0", 'data': new_data}, open(output_file ,'w',encoding='utf-8')) 287 | data_check(output_file) 288 | 289 | def recover_wikiref(): 290 | input_file = "../uqa_all_data.json" 291 | with open(input_file, "r", encoding='utf-8') as reader: 292 | input_data = json.load(reader)["data"] 293 | nlp = spacy.load("en_core_web_sm", disable=['ner', 'tagger']) 294 | 295 | wikiref_data = {} 296 | 297 | for entry in tqdm(input_data, desc="Recover Wikiref"): 298 | title = entry['title'] 299 | for paragraph in entry['paragraphs']: 300 | context = paragraph['context'] 301 | doc = nlp(context[len(title):].strip()) 302 | sents = [title] + [sent.text.strip() for sent in doc.sents] 303 | for qa in paragraph['qas']: 304 | qid = qa['id'] 305 | summary = qa['summary'] 306 | uid = qid.split('_')[0] 307 | wikiref_data[uid] = { 308 | "uid": uid, 309 | "document": sents, 310 | "summary": summary 311 | } 312 | wikiref = [wikiref_data[key] for key in wikiref_data] 313 | print(len(wikiref)) 314 | 
json.dump(wikiref, open("../wikiref.json", "w", encoding='utf-8'), indent=4) 315 | 316 | if __name__ == '__main__': 317 | split_all_data('../', 'uqa_all_data.json', 318 | ['uqa_train_main.json'] + ['uqa_train_refine_%d.json'%i for i in range(6)], 319 | [300000] + [100000 for _ in range(6)]) 320 | 321 | -------------------------------------------------------------------------------- /uqa/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-devel 2 | 3 | RUN apt-get update; apt-get install -y vim wget 4 | RUN git clone -q https://github.com/NVIDIA/apex.git 5 | RUN cd apex ; git reset --hard 1603407bf49c7fc3da74fceb6a6c7b47fece2ef8 ;python setup.py install --user --cuda_ext --cpp_ext 6 | 7 | RUN pip install --user cython tensorboardX six numpy tqdm path.py pandas scikit-learn lmdb pyarrow py-lz4framed methodtools py-rouge pyrouge nltk 8 | RUN python -c "import nltk; nltk.download('punkt')" 9 | RUN pip install -e git://github.com/Maluuba/nlg-eval.git#egg=nlg-eval 10 | 11 | RUN pip install --user spacy==2.2.0 pytorch-transformers==1.2.0 tensorflow-gpu==1.13.1 12 | RUN python -m spacy download en 13 | RUN pip install --user benepar[gpu] 14 | 15 | WORKDIR /workspace 16 | RUN chmod -R a+w /workspace -------------------------------------------------------------------------------- /uqa/evaluate.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """ 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | def remove_articles(text): 14 | return re.sub(r'\b(a|an|the)\b', ' ', text) 15 | 16 | def white_space_fix(text): 17 | return ' '.join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | 29 | def f1_score(prediction, ground_truth): 30 | prediction_tokens = normalize_answer(prediction).split() 31 | ground_truth_tokens = normalize_answer(ground_truth).split() 32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 33 | num_same = sum(common.values()) 34 | if num_same == 0: 35 | return 0 36 | precision = 1.0 * num_same / len(prediction_tokens) 37 | recall = 1.0 * num_same / len(ground_truth_tokens) 38 | f1 = (2 * precision * recall) / (precision + recall) 39 | return f1 40 | 41 | 42 | def exact_match_score(prediction, ground_truth): 43 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 44 | 45 | 46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 47 | scores_for_ground_truths = [] 48 | for ground_truth in ground_truths: 49 | score = metric_fn(prediction, ground_truth) 50 | scores_for_ground_truths.append(score) 51 | return max(scores_for_ground_truths) 52 | 53 | 54 | def evaluate(dataset, predictions): 55 | f1 = exact_match = total = 0 56 | for article in dataset: 57 | for paragraph in article['paragraphs']: 58 | for qa in paragraph['qas']: 59 | total += 1 60 | if qa['id'] not in predictions: 61 | message = 'Unanswered question ' + qa['id'] + \ 62 | ' will receive score 0.' 
63 | print(message, file=sys.stderr) 64 | continue 65 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 66 | prediction = predictions[qa['id']] 67 | exact_match += metric_max_over_ground_truths( 68 | exact_match_score, prediction, ground_truths) 69 | f1 += metric_max_over_ground_truths( 70 | f1_score, prediction, ground_truths) 71 | 72 | exact_match = 100.0 * exact_match / total 73 | f1 = 100.0 * f1 / total 74 | 75 | return {'exact_match': exact_match, 'f1': f1} 76 | 77 | 78 | def evaluate_each_qtype(dataset, predictions): 79 | result = {} 80 | for wh in ['What', 'How', 'When', 'Where', 'Who', 'Other']: 81 | result[wh] = {'f1': 0, 'total': 0} 82 | 83 | other_wh = [] 84 | for article in dataset: 85 | for paragraph in article['paragraphs']: 86 | for qa in paragraph['qas']: 87 | wh = qa['question'].split(' ')[0] 88 | if wh not in result: 89 | other_wh.append(wh) 90 | wh = 'Other' 91 | result[wh]['total'] += 1 92 | 93 | if qa['id'] not in predictions: 94 | message = 'Unanswered question ' + qa['id'] + \ 95 | ' will receive score 0.' 96 | print(message, file=sys.stderr) 97 | continue 98 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 99 | prediction = predictions[qa['id']] 100 | result[wh]['f1'] += metric_max_over_ground_truths( 101 | f1_score, prediction, ground_truths) 102 | #print(Counter(other_wh)) 103 | total = 0 104 | for wh in result: 105 | total += result[wh]['total'] 106 | for wh in result: 107 | result[wh]['f1'] /= result[wh]['total'] 108 | print(wh, 'F1', result[wh]['f1'], 'Rate', result[wh]['total'] / total) 109 | return result 110 | 111 | def evaluate_what(dataset, predictions): 112 | result = {} 113 | for article in dataset: 114 | for paragraph in article['paragraphs']: 115 | for qa in paragraph['qas']: 116 | wh, token = qa['question'].split(' ')[:2] 117 | if wh.lower() != 'what': 118 | continue 119 | if token not in result: 120 | result[token] = {'f1': 0, 'total': 0} 121 | result[token]['total'] += 1 122 | if qa['id'] not in predictions: 123 | message = 'Unanswered question ' + qa['id'] + \ 124 | ' will receive score 0.' 
125 | print(message, file=sys.stderr) 126 | continue 127 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 128 | prediction = predictions[qa['id']] 129 | result[token]['f1'] += metric_max_over_ground_truths( 130 | f1_score, prediction, ground_truths) 131 | #print(Counter(other_wh)) 132 | total = 0 133 | for wh in result: 134 | total += result[wh]['total'] 135 | for wh in sorted(list(result.keys()) , key=lambda x: result[x]['total'], reverse=True)[:20]: 136 | result[wh]['f1'] /= result[wh]['total'] 137 | print(wh, 'F1', '%.4f'%result[wh]['f1'], 'Num', result[wh]['total'] , 'Rate', '%.4f'%(result[wh]['total'] / total)) 138 | 139 | if __name__ == '__main__': 140 | expected_version = '1.1' 141 | parser = argparse.ArgumentParser( 142 | description='Evaluation for SQuAD ' + expected_version) 143 | parser.add_argument('dataset_file', help='Dataset file') 144 | parser.add_argument('prediction_file', help='Prediction File') 145 | args = parser.parse_args() 146 | with open(args.dataset_file) as dataset_file: 147 | dataset_json = json.load(dataset_file) 148 | if (dataset_json['version'] != expected_version): 149 | print('Evaluation expects v-' + expected_version + 150 | ', but got dataset with v-' + dataset_json['version'], 151 | file=sys.stderr) 152 | dataset = dataset_json['data'] 153 | with open(args.prediction_file) as prediction_file: 154 | predictions = json.load(prediction_file) 155 | print(json.dumps(evaluate(dataset, predictions))) 156 | evaluate_each_qtype(dataset, predictions) 157 | evaluate_what(dataset, predictions) -------------------------------------------------------------------------------- /uqa/generate_new_qadata.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys, os, json, time 3 | import numpy as np 4 | from tqdm import tqdm 5 | import random 6 | import logging 7 | from collections import Counter 8 | import spacy 9 | import copy 10 | import benepar 11 | import nltk 12 | from data_utils import reformulate_quesiton, data_check, filter_data_given_qids 13 | from cloze2natural import identity_translate 14 | from wikiref_process import get_clause_v2 15 | import argparse 16 | 17 | nltk.download('punkt') 18 | benepar.download('benepar_en2') 19 | 20 | entity_category = { 21 | 'PERSONNORPORG' : "PERSON, NORP, ORG".replace(' ','').split(','), 22 | 'PLACE' : "GPE, LOC, FAC".replace(' ','').split(','), 23 | 'THING' : 'PRODUCT, EVENT, WORK_OF_ART, LAW, LANGUAGE'.replace(' ','').split(','), 24 | 'TEMPORAL': 'TIME, DATE'.replace(' ','').split(','), 25 | 'NUMERIC' : 'PERCENT, MONEY, QUANTITY, ORDINAL, CARDINAL'.replace(' ','').split(',') 26 | } 27 | entity_type_map = {} 28 | for cate in entity_category: 29 | for item in entity_category[cate]: 30 | entity_type_map[item] = cate 31 | 32 | def get_ref_data(data_path='./data/wikiref.json'): 33 | with open(data_path, 'r', encoding='utf-8') as reader: 34 | data = json.load(reader) 35 | refdata = {} 36 | for item in data: 37 | refdata[item['uid']] = item['summary'] 38 | return refdata 39 | refdata = get_ref_data(os.path.join(os.getenv("REFQA_DATA_DIR", "./data"), 'wikiref.json')) 40 | 41 | spacy_ner = spacy.load("en", disable=['parser', 'tagger']) 42 | spacy_tagger = spacy.load("en", disable=['ner', 'parser']) 43 | spacy_parser = spacy.load("en", disable=['ner', 'tagger']) 44 | bene_parser = benepar.Parser("benepar_en2") 45 | 46 | def get_new_question(context, answer, answer_start, summary, qtype): 47 | sentences = [] 48 | for sent in summary: 49 | if answer in 
sent: 50 | sentences.append(sent) 51 | if len(sentences) == 0 or answer_start == -1: 52 | return None 53 | 54 | doc = spacy_parser(context) 55 | context_sent = None 56 | char_cnt = 0 57 | for sent_item in doc.sents: 58 | sent = sent_item.text 59 | if char_cnt <= answer_start < char_cnt + len(sent): 60 | context_sent = sent 61 | break 62 | else: 63 | char_cnt += len(sent) 64 | while char_cnt < len(context) and context[char_cnt] == ' ': 65 | char_cnt += 1 66 | 67 | if context_sent is None: 68 | return None 69 | 70 | c_tokens = [] 71 | c_doc = spacy_tagger(context_sent) 72 | for token in c_doc: 73 | if not token.is_stop: 74 | c_tokens.append(token.lemma_) 75 | 76 | result = [] 77 | for sent in sentences: 78 | sent_doc = spacy_tagger(sent) 79 | score = 0 80 | for token in sent_doc: 81 | if token.is_stop: 82 | continue 83 | if token.lemma_ in c_tokens: 84 | score += 1 85 | result.append([score, sent]) 86 | result = sorted(result, key=lambda x: x[0]) 87 | sentence = result[-1][1] 88 | cloze_text = None 89 | for clause in get_clause_v2(sentence, bene_parser): 90 | if answer in clause: 91 | cloze_text = clause.replace(answer, qtype, 1) 92 | break 93 | if cloze_text is None: 94 | return None 95 | 96 | new_question = identity_translate(reformulate_quesiton(cloze_text , spacy_parser, reform_version=1) ) 97 | if new_question.startswith('Wh') or new_question.startswith('How'): 98 | return new_question 99 | else: 100 | return None 101 | 102 | def get_answer_start(context, answer, orig_doc_start): 103 | begin_index = len(' '.join(context.split(' ')[:orig_doc_start])) 104 | answer_index = context.find(answer, begin_index) 105 | return answer_index 106 | 107 | def generate(input_file, nbest_file, output_file, remove_em_answer=False, hard_em=False, score_lower_bound=0.5, debug=False): 108 | with open(input_file, "r", encoding='utf-8') as reader: 109 | input_data = json.load(reader)["data"] 110 | with open(nbest_file, "r", encoding='utf-8') as reader: 111 | nbest_data = json.load(reader) 112 | 113 | q_count = 0 114 | 115 | new_data = [] 116 | for entry in (input_data if not debug else tqdm(input_data, desc='generate')): 117 | paras = [] 118 | for paragraph in entry['paragraphs']: 119 | context = paragraph['context'] 120 | qas = [] 121 | for qa in paragraph['qas']: 122 | answer_text = qa['answers'][0]['text'] 123 | qid = qa['id'] 124 | cnt = 0 125 | for ans in (nbest_data[qid][:1] if hard_em else nbest_data[qid]): 126 | if ans['probability'] < score_lower_bound: 127 | continue 128 | new_qa = copy.deepcopy(qa) 129 | new_qa['id'] = qa['id']+'_%d'%cnt 130 | ans['text'] = ans['text'].strip() 131 | 132 | if debug: 133 | new_qa['orig_question'] = qa['question'] 134 | new_qa['orig_answer'] = answer_text 135 | new_qa['predict_answer'] = ans['text'] 136 | new_qa['summary'] = refdata[qid.split('_')[0]] 137 | 138 | if (answer_text == ans['text']) or (ans['text'] in answer_text): 139 | if remove_em_answer and answer_text == ans['text']: 140 | continue 141 | else: 142 | qas.append(new_qa) 143 | else: 144 | new_qa['answers'][0]['text'] = ans['text'] 145 | new_qa['answers'][0]['answer_start'] = get_answer_start(context, ans['text'], ans['orig_doc_start']) 146 | prev_qtype = qa['answers'][0]['type'] 147 | new_qa['question'] = get_new_question(context, ans['text'], new_qa['answers'][0]['answer_start'], refdata[qid.split('_')[0]], entity_type_map[prev_qtype]) 148 | 149 | if (new_qa['question'] is None) or (new_qa['answers'][0]['answer_start'] == -1): 150 | continue 151 | qas.append(new_qa) 152 | cnt += 1 153 | 154 | if 
len(qas) == 0: 155 | continue 156 | q_count += len(qas) 157 | paragraph['qas'] = qas 158 | paras.append(paragraph) 159 | if len(paras) == 0: 160 | continue 161 | entry['paragraphs'] = paras 162 | new_data.append(entry) 163 | 164 | print('New Questions', q_count) 165 | 166 | json.dump({"version": "v2.0", 'data': new_data}, open(output_file, 'w', encoding='utf-8'), indent=4) 167 | 168 | def generate2(input_file, nbest_file, output_file , hard_em=False, score_lower_bound=0.5, debug=False): 169 | with open(input_file, "r", encoding='utf-8') as reader: 170 | input_data = json.load(reader)["data"] 171 | with open(nbest_file, "r", encoding='utf-8') as reader: 172 | nbest_data = json.load(reader) 173 | 174 | q_count = 0 175 | em_qids = [] 176 | 177 | new_data = [] 178 | for entry in copy.deepcopy(input_data): 179 | paras = [] 180 | for paragraph in entry['paragraphs']: 181 | context = paragraph['context'] 182 | qas = [] 183 | for qa in paragraph['qas']: 184 | answer_text = qa['answers'][0]['text'] 185 | qid = qa['id'] 186 | cnt = 0 187 | for ans in (nbest_data[qid][:1] if hard_em else nbest_data[qid]): 188 | if ans['probability'] < score_lower_bound: 189 | continue 190 | new_qa = copy.deepcopy(qa) 191 | new_qa['id'] = qa['id']+'_%d'%cnt 192 | ans['text'] = ans['text'].strip() 193 | if debug: 194 | new_qa['orig_question'] = qa['question'] 195 | new_qa['orig_answer'] = answer_text 196 | new_qa['predict_answer'] = ans['text'] 197 | new_qa['summary'] = refdata[qid.split('_')[0]] 198 | 199 | if (answer_text == ans['text']) or (ans['text'] in answer_text): 200 | if answer_text == ans['text']: 201 | em_qids.append(qid) 202 | else: 203 | qas.append(new_qa) 204 | else: 205 | new_qa['answers'][0]['text'] = ans['text'] 206 | new_qa['answers'][0]['answer_start'] = get_answer_start(context, ans['text'], ans['orig_doc_start']) 207 | prev_qtype = qa['answers'][0]['type'] 208 | new_qa['question'] = get_new_question(context, ans['text'], new_qa['answers'][0]['answer_start'], refdata[qid.split('_')[0]], entity_type_map[prev_qtype]) 209 | 210 | if (new_qa['question'] is None) or (new_qa['answers'][0]['answer_start'] == -1): 211 | continue 212 | qas.append(new_qa) 213 | cnt += 1 214 | 215 | if len(qas) == 0: 216 | continue 217 | q_count += len(qas) 218 | paragraph['qas'] = qas 219 | paras.append(paragraph) 220 | if len(paras) == 0: 221 | continue 222 | entry['paragraphs'] = paras 223 | new_data.append(entry) 224 | 225 | random.shuffle(em_qids) 226 | em_qids = em_qids[:q_count] 227 | new_data += filter_data_given_qids(input_data, em_qids) 228 | q_count += len(em_qids) 229 | print('New Questions', q_count) 230 | 231 | json.dump({"version": "v2.0", 'data': new_data}, open(output_file, 'w', encoding='utf-8')) 232 | 233 | 234 | if __name__=='__main__': 235 | parser = argparse.ArgumentParser() 236 | parser.add_argument("--output", default=None, type=str, 237 | help="output_file") 238 | parser.add_argument("--input", default=None, type=str, 239 | help="input_file") 240 | parser.add_argument("--nbest", default=None, type=str, 241 | help="nbest_file") 242 | parser.add_argument("--generate_method", default=-1, type=int, 243 | help="The method of generating new qa data.") 244 | parser.add_argument("--score_threshold", default=0.3, type=float, 245 | help="The threshold of generating new qa data.") 246 | parser.add_argument("--seed", default=42, type=int, 247 | help="random seed") 248 | args = parser.parse_args() 249 | random.seed(args.seed) 250 | 251 | if args.generate_method == 1: 252 | generate(args.input, args.nbest, 
args.output, remove_em_answer=True, score_lower_bound=args.score_threshold) 253 | elif args.generate_method == 2: 254 | generate2(args.input, args.nbest, args.output, score_lower_bound=args.score_threshold) 255 | elif args.generate_method == 3: 256 | generate(args.input, args.nbest, args.output, remove_em_answer=True, hard_em=True, score_lower_bound=args.score_threshold) 257 | elif args.generate_method == 4: 258 | generate2(args.input, args.nbest, args.output, hard_em=True, score_lower_bound=args.score_threshold) 259 | else: 260 | pass -------------------------------------------------------------------------------- /uqa/multi_turn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys, os, json, time 3 | import numpy as np 4 | from tqdm import tqdm 5 | import random 6 | import logging 7 | import argparse 8 | import subprocess 9 | 10 | from evaluate import evaluate 11 | 12 | logger = logging.getLogger('multi-turn') 13 | logger.setLevel(logging.DEBUG) 14 | console_handler = logging.StreamHandler() 15 | console_handler.setLevel(logging.INFO) 16 | logger.addHandler(console_handler) 17 | 18 | def get_nbest_file(model_dir, dev_file, params): 19 | command = 'python run_squad.py --model_type bert \ 20 | --model_name_or_path bert-large-uncased-whole-word-masking \ 21 | --do_eval --do_lower_case --train_file %s \ 22 | --predict_file %s --max_seq_length 384 --doc_stride 128 --output_dir %s \ 23 | --per_gpu_eval_batch_size=12 --eval_prefix dev' % (params.predict_file, dev_file, model_dir) 24 | if params.fp16: 25 | command += ' --fp16' 26 | nbest_file = os.path.join(model_dir, 'nbest_predictions_dev.json') 27 | if params.debug and os.path.isfile(nbest_file): 28 | logger.info('%s already existed and we use it.'%nbest_file) 29 | else: 30 | logger.info('Generating nbest file...') 31 | subprocess.Popen(command, shell=True).wait() 32 | 33 | if not os.path.isfile(nbest_file): 34 | logger.error('Nbest file %s is not found.'%nbest_file) 35 | exit() 36 | logger.info('Got nbest file %s'%nbest_file) 37 | return nbest_file 38 | 39 | def get_new_train_file(dev_file, nbest_file, model_dir, params): 40 | new_train_file = os.path.join(model_dir, 'train_data_for_next_turn.json') 41 | command = 'python generate_new_qadata.py --input %s --nbest %s --output %s \ 42 | --generate_method %d --score_threshold %.4f --seed %d'%(dev_file, nbest_file, new_train_file, params.generate_method, params.score_threshold, params.seed) 43 | subprocess.Popen(command, shell=True).wait() 44 | if not os.path.isfile(new_train_file): 45 | logger.error('New train file %s is not found.'%new_train_file) 46 | exit() 47 | logger.info('Got new train file %s'%new_train_file) 48 | return new_train_file 49 | 50 | 51 | def do_evaluate(dataset_file, prediction_file): 52 | with open(dataset_file) as df: 53 | dataset_json = json.load(df) 54 | dataset = dataset_json['data'] 55 | with open(prediction_file) as pf: 56 | predictions = json.load(pf) 57 | return evaluate(dataset, predictions) 58 | 59 | def train_model(train_file, model_dir, output_dir, params): 60 | command = 'python -m torch.distributed.launch --nproc_per_node=4 run_squad.py \ 61 | --model_type bert --model_name_or_path %s --do_train --do_eval --do_lower_case \ 62 | --train_file %s --predict_file %s \ 63 | --learning_rate 3e-5 --num_train_epochs 1.0 --max_seq_length 384 --doc_stride 128 \ 64 | --output_dir %s --per_gpu_eval_batch_size=6 --per_gpu_train_batch_size=6 --seed %d \ 65 | --logging_steps 1000 --save_steps 1000 
--eval_all_checkpoints \ 66 | --overwrite_output_dir --overwrite_cache'%(model_dir, train_file, params.predict_file, output_dir, params.seed) 67 | if params.fp16: 68 | command += ' --fp16' 69 | subprocess.Popen(command, shell=True).wait() 70 | 71 | # select best model for next turn 72 | new_model_dir = output_dir 73 | score = do_evaluate(params.predict_file, os.path.join(output_dir, 'predictions_.json'))['f1'] 74 | 75 | for filename in os.listdir(output_dir): 76 | if (not filename.startswith('predictions_')) or (filename == 'predictions_.json'): 77 | continue 78 | new_score = do_evaluate(params.predict_file, os.path.join(output_dir, filename))['f1'] 79 | if new_score > score: 80 | score = new_score 81 | ckpt = filename.replace('.json', '').replace('predictions_', 'checkpoint-') 82 | new_model_dir = os.path.join(output_dir, ckpt) 83 | subprocess.Popen('cp %s/vocab.txt %s'%(output_dir, new_model_dir), shell=True).wait() 84 | subprocess.Popen('cp %s/special_tokens_map.json %s'%(output_dir, new_model_dir), shell=True).wait() 85 | subprocess.Popen('cp %s/added_tokens.json %s'%(output_dir, new_model_dir), shell=True).wait() 86 | subprocess.Popen('cp %s/%s %s/predictions_.json'%(output_dir, filename, new_model_dir), shell=True).wait() 87 | 88 | 89 | return new_model_dir, score 90 | 91 | 92 | 93 | def main(params): 94 | dev_data_name = os.path.join(args.refine_data_dir, 'uqa_train_refine_%d.json') 95 | 96 | model_dir = os.path.join(params.output_dir, 'init') 97 | if not os.path.exists(model_dir): 98 | subprocess.Popen('mkdir -p %s'%model_dir, shell=True).wait() 99 | logger.info('Copy model from %s to %s.'%(params.model_dir, model_dir)) 100 | subprocess.Popen('cp %s/vocab.txt %s'%(params.model_dir, model_dir), shell=True).wait() 101 | subprocess.Popen('cp %s/special_tokens_map.json %s'%(params.model_dir, model_dir), shell=True).wait() 102 | subprocess.Popen('cp %s/added_tokens.json %s'%(params.model_dir, model_dir), shell=True).wait() 103 | subprocess.Popen('cp %s/config.json %s'%(params.model_dir, model_dir), shell=True).wait() 104 | subprocess.Popen('cp %s/training_args.bin %s'%(params.model_dir, model_dir), shell=True).wait() 105 | subprocess.Popen('cp %s/predictions_.json %s'%(params.model_dir, model_dir), shell=True).wait() 106 | subprocess.Popen('cp %s/pytorch_model.bin %s'%(params.model_dir, model_dir), shell=True).wait() 107 | 108 | if params.debug: 109 | subprocess.Popen('cp %s/nbest_predictions_6_no_train_eval2.json %s/nbest_predictions_dev.json'%(params.model_dir, model_dir), shell=True).wait() 110 | 111 | if os.path.exists(os.path.join(model_dir, 'predictions_.json')): 112 | current_score = do_evaluate(params.predict_file, os.path.join(model_dir, 'predictions_.json'))['f1'] 113 | else: 114 | current_score = 0.0 115 | 116 | order = [1, 3, 2, 4, 5, 0] 117 | if params.debug: 118 | order = [6] + order 119 | 120 | for step, idx in enumerate(order): 121 | logger.info('-'*80) 122 | logger.info('Prepare for turn_%d / Current f1 %.2f/ Current model %s'%(step, current_score, model_dir)) 123 | dev_file = dev_data_name % idx 124 | output_dir = os.path.join(params.output_dir, 'turn_%d'%step) 125 | if not os.path.exists(output_dir): 126 | subprocess.Popen('mkdir -p %s'%output_dir, shell=True).wait() 127 | 128 | nbest_file = get_nbest_file(model_dir, dev_file, params) 129 | new_train_file = get_new_train_file(dev_file, nbest_file, model_dir, params) 130 | 131 | new_model_dir, new_score = train_model(new_train_file, model_dir, output_dir, params) 132 | 133 | if new_score > current_score: 134 | 
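# Keep the better checkpoint (and its dev F1) as the starting model for the next refining turn.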
model_dir = new_model_dir 135 | current_score = new_score 136 | logger.info('Find better model %s and f1 is %.2f'%(model_dir, current_score)) 137 | 138 | params.score_threshold = params.score_threshold * params.threshold_rate 139 | 140 | 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--refine_data_dir", default=None, type=str, required=True, 146 | help="RefQA data for refining.") 147 | parser.add_argument("--output_dir", default=None, type=str, required=True, 148 | help="The output directory.") 149 | parser.add_argument("--model_dir", default=None, type=str, required=True, 150 | help="The init model directory.") 151 | parser.add_argument("--predict_file", default='dev-v1.1.json', type=str, 152 | help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 153 | parser.add_argument("--generate_method", default=1, type=int, 154 | help="The method of generating new qa data.") 155 | parser.add_argument("--score_threshold", default=0.3, type=float, 156 | help="The threshold of generating new qa data.") 157 | parser.add_argument("--threshold_rate", default=1.0, type=float, 158 | help="The change rate of the threshold") 159 | parser.add_argument("--seed", default=42, type=int, 160 | help="seed") 161 | parser.add_argument("--fp16", action='store_true', 162 | help="fp16 training") 163 | parser.add_argument("--debug", action='store_true', 164 | help="debug training") 165 | args = parser.parse_args() 166 | args.output_dir = args.output_dir.replace( 167 | '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) 168 | main(args) 169 | -------------------------------------------------------------------------------- /uqa/run_squad.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet).""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import logging 22 | import os 23 | import random 24 | import glob 25 | 26 | import numpy as np 27 | import torch 28 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 29 | TensorDataset) 30 | from torch.utils.data.distributed import DistributedSampler 31 | from tqdm import tqdm, trange 32 | 33 | from tensorboardX import SummaryWriter 34 | 35 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig, 36 | BertForQuestionAnswering, BertTokenizer, 37 | XLMConfig, XLMForQuestionAnswering, 38 | XLMTokenizer, XLNetConfig, 39 | XLNetForQuestionAnswering, 40 | XLNetTokenizer) 41 | 42 | from pytorch_transformers import AdamW, WarmupLinearSchedule 43 | 44 | from utils_squad import (read_squad_examples, convert_examples_to_features, 45 | RawResult, write_predictions, 46 | RawResultExtended, write_predictions_extended) 47 | 48 | # The follwing import is the official SQuAD evaluation script (2.0). 49 | # You can remove it from the dependencies if you are using this script outside of the library 50 | # We've added it here for automated tests (see examples/test_examples.py file) 51 | from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad 52 | 53 | logger = logging.getLogger(__name__) 54 | 55 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \ 56 | for conf in (BertConfig, XLNetConfig, XLMConfig)), ()) 57 | 58 | MODEL_CLASSES = { 59 | 'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer), 60 | 'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), 61 | 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), 62 | } 63 | 64 | def set_seed(args): 65 | random.seed(args.seed) 66 | np.random.seed(args.seed) 67 | torch.manual_seed(args.seed) 68 | if args.n_gpu > 0: 69 | torch.cuda.manual_seed_all(args.seed) 70 | 71 | def to_list(tensor): 72 | return tensor.detach().cpu().tolist() 73 | 74 | def train(args, train_dataset, model, tokenizer): 75 | """ Train the model """ 76 | if args.local_rank in [-1, 0]: 77 | tb_writer = SummaryWriter() 78 | 79 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 80 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 81 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 82 | 83 | if args.max_steps > 0: 84 | t_total = args.max_steps 85 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 86 | else: 87 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 88 | 89 | if t_total < args.min_steps: 90 | t_total = args.min_steps 91 | args.num_train_epochs = args.min_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 92 | # Prepare optimizer and schedule (linear warmup and decay) 93 | no_decay = ['bias', 'LayerNorm.weight'] 94 | optimizer_grouped_parameters = [ 95 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 96 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 97 | ] 98 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 99 | scheduler = WarmupLinearSchedule(optimizer, 
warmup_steps=args.warmup_steps, t_total=t_total) 100 | if args.fp16: 101 | try: 102 | from apex import amp 103 | except ImportError: 104 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 105 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 106 | 107 | # multi-gpu training (should be after apex fp16 initialization) 108 | if args.n_gpu > 1: 109 | model = torch.nn.DataParallel(model) 110 | 111 | # Distributed training (should be after apex fp16 initialization) 112 | if args.local_rank != -1: 113 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 114 | output_device=args.local_rank, 115 | find_unused_parameters=True) 116 | 117 | # Train! 118 | logger.info("***** Running training *****") 119 | logger.info(" Num examples = %d", len(train_dataset)) 120 | logger.info(" Num Epochs = %d", args.num_train_epochs) 121 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 122 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", 123 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 124 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 125 | logger.info(" Total optimization steps = %d", t_total) 126 | 127 | global_step = 0 128 | tr_loss, logging_loss = 0.0, 0.0 129 | model.zero_grad() 130 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) 131 | set_seed(args) # Added here for reproductibility (even between python 2 and 3) 132 | for _ in train_iterator: 133 | #epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 134 | for step, batch in enumerate(train_dataloader): 135 | model.train() 136 | batch = tuple(t.to(args.device) for t in batch) 137 | inputs = {'input_ids': batch[0], 138 | 'attention_mask': batch[1], 139 | 'token_type_ids': None if args.model_type == 'xlm' else batch[2], 140 | 'start_positions': batch[3], 141 | 'end_positions': batch[4]} 142 | if args.model_type in ['xlnet', 'xlm']: 143 | inputs.update({'cls_index': batch[5], 144 | 'p_mask': batch[6]}) 145 | outputs = model(**inputs) 146 | loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) 147 | 148 | if args.n_gpu > 1: 149 | loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training 150 | if args.gradient_accumulation_steps > 1: 151 | loss = loss / args.gradient_accumulation_steps 152 | 153 | if args.fp16: 154 | with amp.scale_loss(loss, optimizer) as scaled_loss: 155 | scaled_loss.backward() 156 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 157 | else: 158 | loss.backward() 159 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 160 | 161 | tr_loss += loss.item() 162 | if (step + 1) % args.gradient_accumulation_steps == 0: 163 | scheduler.step() # Update learning rate schedule 164 | optimizer.step() 165 | model.zero_grad() 166 | global_step += 1 167 | 168 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 169 | # Log metrics 170 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 171 | results = evaluate(args, model, tokenizer) 172 | for key, value in results.items(): 173 | 
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
174 | logger.info('Eval F1 {}'.format(results['f1']))
175 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
176 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
177 | logger.info('Step {}, LR {}, Loss {}'.format(global_step,scheduler.get_lr()[0],(tr_loss - logging_loss)/args.logging_steps))
178 | logging_loss = tr_loss
179 |
180 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
181 | # Save model checkpoint
182 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
183 | if not os.path.exists(output_dir):
184 | os.makedirs(output_dir)
185 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
186 | model_to_save.save_pretrained(output_dir)
187 | torch.save(args, os.path.join(output_dir, 'training_args.bin'))
188 | logger.info("Saving model checkpoint to %s", output_dir)
189 | if args.few_shot:
190 | args.save_steps = args.save_steps * 2
191 | if args.save_steps > args.logging_steps:
192 | args.save_steps = args.logging_steps
193 |
194 |
195 | if args.max_steps > 0 and global_step > args.max_steps:
196 | # epoch_iterator is commented out above, so calling epoch_iterator.close() here would raise a NameError; just break
197 | break
198 | if args.max_steps > 0 and global_step > args.max_steps:
199 | train_iterator.close()
200 | break
201 |
202 | if args.local_rank in [-1, 0]:
203 | tb_writer.close()
204 |
205 | return global_step, tr_loss / global_step
206 |
207 |
208 | def evaluate(args, model, tokenizer, prefix=""):
209 | dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
210 |
211 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
212 | os.makedirs(args.output_dir)
213 |
214 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
215 | # Note that DistributedSampler samples randomly
216 | eval_sampler = SequentialSampler(dataset)
217 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
218 |
219 | # Eval!
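# The loop below runs the model in inference mode and collects the raw start/end
# logits for every feature; write_predictions / write_predictions_extended then
# turn them into text predictions, which are scored with the official SQuAD script.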
220 | logger.info("***** Running evaluation {} *****".format(prefix)) 221 | logger.info(" Num examples = %d", len(dataset)) 222 | logger.info(" Batch size = %d", args.eval_batch_size) 223 | all_results = [] 224 | for batch in eval_dataloader: 225 | model.eval() 226 | batch = tuple(t.to(args.device) for t in batch) 227 | with torch.no_grad(): 228 | inputs = {'input_ids': batch[0], 229 | 'attention_mask': batch[1], 230 | 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids 231 | } 232 | example_indices = batch[3] 233 | if args.model_type in ['xlnet', 'xlm']: 234 | inputs.update({'cls_index': batch[4], 235 | 'p_mask': batch[5]}) 236 | outputs = model(**inputs) 237 | 238 | for i, example_index in enumerate(example_indices): 239 | eval_feature = features[example_index.item()] 240 | unique_id = int(eval_feature.unique_id) 241 | if args.model_type in ['xlnet', 'xlm']: 242 | # XLNet uses a more complex post-processing procedure 243 | result = RawResultExtended(unique_id = unique_id, 244 | start_top_log_probs = to_list(outputs[0][i]), 245 | start_top_index = to_list(outputs[1][i]), 246 | end_top_log_probs = to_list(outputs[2][i]), 247 | end_top_index = to_list(outputs[3][i]), 248 | cls_logits = to_list(outputs[4][i])) 249 | else: 250 | result = RawResult(unique_id = unique_id, 251 | start_logits = to_list(outputs[0][i]), 252 | end_logits = to_list(outputs[1][i])) 253 | all_results.append(result) 254 | 255 | # Compute predictions 256 | output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) 257 | output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix)) 258 | if args.version_2_with_negative: 259 | output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix)) 260 | else: 261 | output_null_log_odds_file = None 262 | 263 | if args.model_type in ['xlnet', 'xlm']: 264 | # XLNet uses a more complex post-processing procedure 265 | write_predictions_extended(examples, features, all_results, args.n_best_size, 266 | args.max_answer_length, output_prediction_file, 267 | output_nbest_file, output_null_log_odds_file, args.predict_file, 268 | model.config.start_n_top, model.config.end_n_top, 269 | args.version_2_with_negative, tokenizer, args.verbose_logging) 270 | else: 271 | write_predictions(examples, features, all_results, args.n_best_size, 272 | args.max_answer_length, args.do_lower_case, output_prediction_file, 273 | output_nbest_file, output_null_log_odds_file, args.verbose_logging, 274 | args.version_2_with_negative, args.null_score_diff_threshold) 275 | 276 | # Evaluate with the official SQuAD script 277 | evaluate_options = EVAL_OPTS(data_file=args.predict_file, 278 | pred_file=output_prediction_file, 279 | na_prob_file=output_null_log_odds_file) 280 | results = evaluate_on_squad(evaluate_options) 281 | return results 282 | 283 | 284 | def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False): 285 | # Load data features from cache or dataset file 286 | input_file = args.predict_file if evaluate else args.train_file 287 | cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}_{}'.format( 288 | 'dev' if evaluate else 'train', str(args.max_seq_length), list(filter(None, args.model_name_or_path.split('/'))).pop(), 289 | list(filter(None, input_file.split('/'))).pop())) 290 | if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples: 291 | logger.info("Loading features from 
cached file %s", cached_features_file) 292 | features = torch.load(cached_features_file) 293 | else: 294 | logger.info("Creating features from dataset file at %s", input_file) 295 | examples = read_squad_examples(input_file=input_file, 296 | is_training=not evaluate, 297 | version_2_with_negative=args.version_2_with_negative) 298 | features = convert_examples_to_features(examples=examples, 299 | tokenizer=tokenizer, 300 | max_seq_length=args.max_seq_length, 301 | doc_stride=args.doc_stride, 302 | max_query_length=args.max_query_length, 303 | is_training=not evaluate) 304 | if (args.local_rank in [-1, 0]) and (not evaluate): 305 | logger.info("Saving features into cached file %s", cached_features_file) 306 | torch.save(features, cached_features_file) 307 | 308 | # Convert to Tensors and build dataset 309 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 310 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 311 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 312 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) 313 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) 314 | if evaluate: 315 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 316 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 317 | all_example_index, all_cls_index, all_p_mask) 318 | else: 319 | all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) 320 | all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) 321 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 322 | all_start_positions, all_end_positions, 323 | all_cls_index, all_p_mask) 324 | 325 | if output_examples: 326 | return dataset, examples, features 327 | return dataset 328 | 329 | 330 | def main(): 331 | parser = argparse.ArgumentParser() 332 | 333 | ## Required parameters 334 | parser.add_argument("--train_file", default=None, type=str, required=True, 335 | help="SQuAD json for training. E.g., train-v1.1.json") 336 | parser.add_argument("--predict_file", default=None, type=str, required=True, 337 | help="SQuAD json for predictions. 
E.g., dev-v1.1.json or test-v1.1.json") 338 | parser.add_argument("--model_type", default=None, type=str, required=True, 339 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 340 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 341 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) 342 | parser.add_argument("--output_dir", default=None, type=str, required=True, 343 | help="The output directory where the model checkpoints and predictions will be written.") 344 | 345 | ## Other parameters 346 | parser.add_argument("--eval_prefix", default="", type=str, 347 | help="Evaluate prefix") 348 | 349 | parser.add_argument("--config_name", default="", type=str, 350 | help="Pretrained config name or path if not the same as model_name") 351 | parser.add_argument("--tokenizer_name", default="", type=str, 352 | help="Pretrained tokenizer name or path if not the same as model_name") 353 | parser.add_argument("--cache_dir", default="", type=str, 354 | help="Where do you want to store the pre-trained models downloaded from s3") 355 | 356 | parser.add_argument('--version_2_with_negative', action='store_true', 357 | help='If true, the SQuAD examples contain some that do not have an answer.') 358 | parser.add_argument('--null_score_diff_threshold', type=float, default=0.0, 359 | help="If null_score - best_non_null is greater than the threshold predict null.") 360 | 361 | parser.add_argument("--max_seq_length", default=384, type=int, 362 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 363 | "longer than this will be truncated, and sequences shorter than this will be padded.") 364 | parser.add_argument("--doc_stride", default=128, type=int, 365 | help="When splitting up a long document into chunks, how much stride to take between chunks.") 366 | parser.add_argument("--max_query_length", default=64, type=int, 367 | help="The maximum number of tokens for the question. 
Questions longer than this will "
368 | "be truncated to this length.")
369 | parser.add_argument("--do_train", action='store_true',
370 | help="Whether to run training.")
371 | parser.add_argument("--do_eval", action='store_true',
372 | help="Whether to run eval on the dev set.")
373 | parser.add_argument("--evaluate_during_training", action='store_true',
374 | help="Run evaluation during training at each logging step.")
375 | parser.add_argument("--do_lower_case", action='store_true',
376 | help="Set this flag if you are using an uncased model.")
377 |
378 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
379 | help="Batch size per GPU/CPU for training.")
380 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
381 | help="Batch size per GPU/CPU for evaluation.")
382 | parser.add_argument("--learning_rate", default=5e-5, type=float,
383 | help="The initial learning rate for Adam.")
384 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
385 | help="Number of updates steps to accumulate before performing a backward/update pass.")
386 | parser.add_argument("--weight_decay", default=0.0, type=float,
387 | help="Weight decay if we apply some.")
388 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
389 | help="Epsilon for Adam optimizer.")
390 | parser.add_argument("--max_grad_norm", default=1.0, type=float,
391 | help="Max gradient norm.")
392 | parser.add_argument("--num_train_epochs", default=3.0, type=float,
393 | help="Total number of training epochs to perform.")
394 | parser.add_argument("--max_steps", default=-1, type=int,
395 | help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
396 | parser.add_argument("--min_steps", default=-1, type=int,
397 | help="If > 0: set the minimum number of total training steps to perform (num_train_epochs is increased if needed).")
398 | parser.add_argument("--warmup_steps", default=0, type=int,
399 | help="Linear warmup over warmup_steps.")
400 | parser.add_argument("--n_best_size", default=20, type=int,
401 | help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
402 | parser.add_argument("--max_answer_length", default=30, type=int,
403 | help="The maximum length of an answer that can be generated. This is needed because the start "
404 | "and end predictions are not conditioned on one another.")
405 | parser.add_argument("--verbose_logging", action='store_true',
406 | help="If true, all of the warnings related to data processing will be printed. 
" 407 | "A number of warnings are expected for a normal SQuAD evaluation.") 408 | 409 | parser.add_argument('--logging_steps', type=int, default=50, 410 | help="Log every X updates steps.") 411 | parser.add_argument('--save_steps', type=int, default=50, 412 | help="Save checkpoint every X updates steps.") 413 | parser.add_argument('--few_shot', action='store_true', help='few-shot learning') 414 | parser.add_argument("--eval_all_checkpoints", action='store_true', 415 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") 416 | parser.add_argument("--no_cuda", action='store_true', 417 | help="Whether not to use CUDA when available") 418 | parser.add_argument('--overwrite_output_dir', action='store_true', 419 | help="Overwrite the content of the output directory") 420 | parser.add_argument('--overwrite_cache', action='store_true', 421 | help="Overwrite the cached training and evaluation sets") 422 | parser.add_argument('--seed', type=int, default=42, 423 | help="random seed for initialization") 424 | 425 | parser.add_argument("--local_rank", type=int, default=-1, 426 | help="local_rank for distributed training on gpus") 427 | parser.add_argument('--fp16', action='store_true', 428 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 429 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 430 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 431 | "See details at https://nvidia.github.io/apex/amp.html") 432 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 433 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 434 | args = parser.parse_args() 435 | args.output_dir = args.output_dir.replace( 436 | '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) 437 | 438 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: 439 | raise ValueError("Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format(args.output_dir)) 440 | 441 | # Setup distant debugging if needed 442 | if args.server_ip and args.server_port: 443 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 444 | import ptvsd 445 | print("Waiting for debugger attach") 446 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 447 | ptvsd.wait_for_attach() 448 | 449 | if args.few_shot: 450 | args.save_steps = 64 451 | 452 | # Setup CUDA, GPU & distributed training 453 | if args.local_rank == -1 or args.no_cuda: 454 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 455 | args.n_gpu = torch.cuda.device_count() 456 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 457 | torch.cuda.set_device(args.local_rank) 458 | device = torch.device("cuda", args.local_rank) 459 | torch.distributed.init_process_group(backend='nccl') 460 | args.n_gpu = 1 461 | args.device = device 462 | 463 | # Setup logging 464 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 465 | datefmt = '%m/%d/%Y %H:%M:%S', 466 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 467 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 468 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 469 | 470 | # Set seed 471 | set_seed(args) 472 | 473 | # Load pretrained model and tokenizer 474 | if args.local_rank not in [-1, 0]: 475 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 476 | 477 | args.model_type = args.model_type.lower() 478 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 479 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) 480 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case) 481 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config) 482 | 483 | if args.local_rank == 0: 484 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 485 | 486 | model.to(args.device) 487 | 488 | logger.info("Training/evaluation parameters %s", args) 489 | 490 | # Training 491 | if args.do_train: 492 | train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) 493 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 494 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 495 | 496 | 497 | # Save the trained model and the tokenizer 498 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 499 | # Create output directory if needed 500 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 501 | os.makedirs(args.output_dir) 502 | 503 | logger.info("Saving model checkpoint to %s", args.output_dir) 504 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 
505 | # They can then be reloaded using `from_pretrained()` 506 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 507 | model_to_save.save_pretrained(args.output_dir) 508 | tokenizer.save_pretrained(args.output_dir) 509 | 510 | # Good practice: save your training arguments together with the trained model 511 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) 512 | 513 | # Load a trained model and vocabulary that you have fine-tuned 514 | model = model_class.from_pretrained(args.output_dir) 515 | tokenizer = tokenizer_class.from_pretrained(args.output_dir) 516 | model.to(args.device) 517 | 518 | 519 | # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory 520 | results = {} 521 | if args.do_eval and args.local_rank in [-1, 0]: 522 | checkpoints = [args.output_dir] 523 | if args.eval_all_checkpoints: 524 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 525 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs 526 | 527 | if len(checkpoints) > 1: 528 | final_model = checkpoints[-1] 529 | checkpoints = checkpoints[:-1] 530 | #print(checkpoints) 531 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split('-')[-1])) 532 | checkpoints.append(final_model) 533 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 534 | 535 | 536 | for checkpoint in checkpoints: 537 | # Reload the model 538 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" 539 | model = model_class.from_pretrained(checkpoint) 540 | model.to(args.device) 541 | 542 | # Evaluate 543 | if global_step.isdigit(): 544 | global_step = args.eval_prefix + global_step 545 | else: 546 | global_step = args.eval_prefix 547 | result = evaluate(args, model, tokenizer, prefix=global_step) 548 | logger.info("Step {}, Result {}".format(global_step, result)) 549 | 550 | result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items()) 551 | results.update(result) 552 | 553 | logger.info("Results: {}".format(results)) 554 | 555 | return results 556 | 557 | 558 | if __name__ == "__main__": 559 | main() -------------------------------------------------------------------------------- /uqa/scripts/gen_refqa.sh: -------------------------------------------------------------------------------- 1 | export REFQA_DATA_DIR=/root/data/refqa 2 | 3 | cd .. 
4 | 5 | python3 wikiref_process.py --input_file wikiref.json --output_file cloze_clause_wikiref_data.json 6 | python3 cloze2natural.py --input_file cloze_clause_wikiref_data.json --output_file refqa.json 7 | -------------------------------------------------------------------------------- /uqa/scripts/install_tools.sh: -------------------------------------------------------------------------------- 1 | git clone -q https://github.com/NVIDIA/apex.git 2 | cd apex ; git reset --hard 1603407bf49c7fc3da74fceb6a6c7b47fece2ef8 3 | python setup.py install --user --cuda_ext --cpp_ext 4 | 5 | pip install --user cython tensorboardX six numpy tqdm path.py pandas scikit-learn lmdb pyarrow py-lz4framed methodtools py-rouge pyrouge nltk 6 | python -c "import nltk; nltk.download('punkt')" 7 | pip install -e git://github.com/Maluuba/nlg-eval.git#egg=nlg-eval 8 | 9 | pip install --user spacy==2.2.0 pytorch-transformers==1.2.0 tensorflow-gpu==1.13.1 10 | python -m spacy download en 11 | pip install --user benepar[gpu] -------------------------------------------------------------------------------- /uqa/scripts/run_main.sh: -------------------------------------------------------------------------------- 1 | export REFQA_DATA_DIR=/root/data/refqa 2 | export PYTORCH_PRETRAINED_BERT_CACHE=/root/pretrained_weights 3 | export OUTPUT_DIR=/root/model_outputs/refqa_main_model_output 4 | 5 | cd ../ 6 | 7 | python -m torch.distributed.launch --nproc_per_node=4 run_squad.py \ 8 | --model_type bert \ 9 | --model_name_or_path bert-large-uncased-whole-word-masking \ 10 | --do_train \ 11 | --do_eval \ 12 | --do_lower_case \ 13 | --train_file $REFQA_DATA_DIR/uqa_train_main.json \ 14 | --predict_file $REFQA_DATA_DIR/dev-v1.1.json \ 15 | --learning_rate 3e-5 \ 16 | --num_train_epochs 2 \ 17 | --max_seq_length 384 \ 18 | --doc_stride 128 \ 19 | --output_dir $OUTPUT_DIR \ 20 | --per_gpu_train_batch_size=6 \ 21 | --per_gpu_eval_batch_size=4 \ 22 | --seed 42 \ 23 | --fp16 \ 24 | --overwrite_output_dir \ 25 | --logging_steps 1000 \ 26 | --save_steps 1000 ; 27 | -------------------------------------------------------------------------------- /uqa/scripts/run_refine.sh: -------------------------------------------------------------------------------- 1 | export REFQA_DATA_DIR=/root/data/refqa 2 | export PYTORCH_PRETRAINED_BERT_CACHE=/root/pretrained_weights 3 | export MAIN_MODEL_DIR=/root/model_outputs/best_main_model 4 | export OUTPUT_DIR=/root/model_outputs/refqa_refine_model_output 5 | 6 | cd .. 7 | 8 | python multi_turn.py \ 9 | --refine_data_dir $REFQA_DATA_DIR \ 10 | --output_dir $OUTPUT_DIR \ 11 | --model_dir $MAIN_MODEL_DIR \ 12 | --predict_file $REFQA_DATA_DIR/dev-v1.1.json \ 13 | --generate_method 2 \ 14 | --score_threshold 0.15 \ 15 | --threshold_rate 0.9 \ 16 | --seed 17 \ 17 | --fp16 -------------------------------------------------------------------------------- /uqa/utils_squad.py: -------------------------------------------------------------------------------- 1 | 2 | # coding=utf-8 3 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ Load SQuAD dataset. """ 18 | 19 | from __future__ import absolute_import, division, print_function 20 | 21 | import json 22 | import logging 23 | import math 24 | import collections 25 | from io import open 26 | 27 | from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize 28 | 29 | # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method) 30 | from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | class SquadExample(object): 36 | """ 37 | A single training/test example for the Squad dataset. 38 | For examples without an answer, the start and end position are -1. 39 | """ 40 | 41 | def __init__(self, 42 | qas_id, 43 | question_text, 44 | doc_tokens, 45 | orig_answer_text=None, 46 | start_position=None, 47 | end_position=None, 48 | is_impossible=None): 49 | self.qas_id = qas_id 50 | self.question_text = question_text 51 | self.doc_tokens = doc_tokens 52 | self.orig_answer_text = orig_answer_text 53 | self.start_position = start_position 54 | self.end_position = end_position 55 | self.is_impossible = is_impossible 56 | 57 | def __str__(self): 58 | return self.__repr__() 59 | 60 | def __repr__(self): 61 | s = "" 62 | s += "qas_id: %s" % (self.qas_id) 63 | s += ", question_text: %s" % ( 64 | self.question_text) 65 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 66 | if self.start_position: 67 | s += ", start_position: %d" % (self.start_position) 68 | if self.end_position: 69 | s += ", end_position: %d" % (self.end_position) 70 | if self.is_impossible: 71 | s += ", is_impossible: %r" % (self.is_impossible) 72 | return s 73 | 74 | 75 | class InputFeatures(object): 76 | """A single set of features of data.""" 77 | 78 | def __init__(self, 79 | unique_id, 80 | example_index, 81 | doc_span_index, 82 | tokens, 83 | token_to_orig_map, 84 | token_is_max_context, 85 | input_ids, 86 | input_mask, 87 | segment_ids, 88 | cls_index, 89 | p_mask, 90 | paragraph_len, 91 | start_position=None, 92 | end_position=None, 93 | is_impossible=None): 94 | self.unique_id = unique_id 95 | self.example_index = example_index 96 | self.doc_span_index = doc_span_index 97 | self.tokens = tokens 98 | self.token_to_orig_map = token_to_orig_map 99 | self.token_is_max_context = token_is_max_context 100 | self.input_ids = input_ids 101 | self.input_mask = input_mask 102 | self.segment_ids = segment_ids 103 | self.cls_index = cls_index 104 | self.p_mask = p_mask 105 | self.paragraph_len = paragraph_len 106 | self.start_position = start_position 107 | self.end_position = end_position 108 | self.is_impossible = is_impossible 109 | 110 | 111 | def read_squad_examples(input_file, is_training, version_2_with_negative): 112 | """Read a SQuAD json file into a list of SquadExample.""" 113 | with open(input_file, "r", encoding='utf-8') as reader: 114 | input_data = json.load(reader)["data"] 115 | 116 | def is_whitespace(c): 117 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 
118 | return True 119 | return False 120 | 121 | examples = [] 122 | for entry in input_data: 123 | for paragraph in entry["paragraphs"]: 124 | paragraph_text = paragraph["context"] 125 | doc_tokens = [] 126 | char_to_word_offset = [] 127 | prev_is_whitespace = True 128 | for c in paragraph_text: 129 | if is_whitespace(c): 130 | prev_is_whitespace = True 131 | else: 132 | if prev_is_whitespace: 133 | doc_tokens.append(c) 134 | else: 135 | doc_tokens[-1] += c 136 | prev_is_whitespace = False 137 | char_to_word_offset.append(len(doc_tokens) - 1) 138 | 139 | for qa in paragraph["qas"]: 140 | qas_id = qa["id"] 141 | question_text = qa["question"] 142 | start_position = None 143 | end_position = None 144 | orig_answer_text = None 145 | is_impossible = False 146 | if is_training: 147 | if version_2_with_negative: 148 | is_impossible = qa["is_impossible"] 149 | if (len(qa["answers"]) != 1) and (not is_impossible): 150 | raise ValueError( 151 | "For training, each question should have exactly 1 answer.") 152 | if not is_impossible: 153 | answer = qa["answers"][0] 154 | orig_answer_text = answer["text"] 155 | answer_offset = answer["answer_start"] 156 | answer_length = len(orig_answer_text) 157 | start_position = char_to_word_offset[answer_offset] 158 | end_position = char_to_word_offset[answer_offset + answer_length - 1] 159 | # Only add answers where the text can be exactly recovered from the 160 | # document. If this CAN'T happen it's likely due to weird Unicode 161 | # stuff so we will just skip the example. 162 | # 163 | # Note that this means for training mode, every example is NOT 164 | # guaranteed to be preserved. 165 | actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) 166 | cleaned_answer_text = " ".join( 167 | whitespace_tokenize(orig_answer_text)) 168 | if actual_text.find(cleaned_answer_text) == -1: 169 | logger.warning("Could not find answer: '%s' vs. 
'%s'", 170 | actual_text, cleaned_answer_text) 171 | continue 172 | else: 173 | start_position = -1 174 | end_position = -1 175 | orig_answer_text = "" 176 | 177 | example = SquadExample( 178 | qas_id=qas_id, 179 | question_text=question_text, 180 | doc_tokens=doc_tokens, 181 | orig_answer_text=orig_answer_text, 182 | start_position=start_position, 183 | end_position=end_position, 184 | is_impossible=is_impossible) 185 | examples.append(example) 186 | return examples 187 | 188 | 189 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 190 | doc_stride, max_query_length, is_training, 191 | cls_token_at_end=False, 192 | cls_token='[CLS]', sep_token='[SEP]', pad_token=0, 193 | sequence_a_segment_id=0, sequence_b_segment_id=1, 194 | cls_token_segment_id=0, pad_token_segment_id=0, 195 | mask_padding_with_zero=True): 196 | """Loads a data file into a list of `InputBatch`s.""" 197 | 198 | unique_id = 1000000000 199 | # cnt_pos, cnt_neg = 0, 0 200 | # max_N, max_M = 1024, 1024 201 | # f = np.zeros((max_N, max_M), dtype=np.float32) 202 | 203 | features = [] 204 | for (example_index, example) in enumerate(examples): 205 | 206 | # if example_index % 100 == 0: 207 | # logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg) 208 | 209 | query_tokens = tokenizer.tokenize(example.question_text) 210 | 211 | if len(query_tokens) > max_query_length: 212 | query_tokens = query_tokens[0:max_query_length] 213 | 214 | tok_to_orig_index = [] 215 | orig_to_tok_index = [] 216 | all_doc_tokens = [] 217 | for (i, token) in enumerate(example.doc_tokens): 218 | orig_to_tok_index.append(len(all_doc_tokens)) 219 | sub_tokens = tokenizer.tokenize(token) 220 | for sub_token in sub_tokens: 221 | tok_to_orig_index.append(i) 222 | all_doc_tokens.append(sub_token) 223 | 224 | tok_start_position = None 225 | tok_end_position = None 226 | if is_training and example.is_impossible: 227 | tok_start_position = -1 228 | tok_end_position = -1 229 | if is_training and not example.is_impossible: 230 | tok_start_position = orig_to_tok_index[example.start_position] 231 | if example.end_position < len(example.doc_tokens) - 1: 232 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 233 | else: 234 | tok_end_position = len(all_doc_tokens) - 1 235 | (tok_start_position, tok_end_position) = _improve_answer_span( 236 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, 237 | example.orig_answer_text) 238 | 239 | # The -3 accounts for [CLS], [SEP] and [SEP] 240 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 241 | 242 | # We can have documents that are longer than the maximum sequence length. 243 | # To deal with this we do a sliding window approach, where we take chunks 244 | # of the up to our max length with a stride of `doc_stride`. 
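# For illustration (hypothetical numbers): with max_seq_length=384 and a 10-token
# question, max_tokens_for_doc = 384 - 10 - 3 = 371; a 600-token document with
# doc_stride=128 then yields doc spans (start=0, length=371), (start=128, length=371)
# and (start=256, length=344), so every token falls into at least one span.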
245 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 246 | "DocSpan", ["start", "length"]) 247 | doc_spans = [] 248 | start_offset = 0 249 | while start_offset < len(all_doc_tokens): 250 | length = len(all_doc_tokens) - start_offset 251 | if length > max_tokens_for_doc: 252 | length = max_tokens_for_doc 253 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 254 | if start_offset + length == len(all_doc_tokens): 255 | break 256 | start_offset += min(length, doc_stride) 257 | 258 | for (doc_span_index, doc_span) in enumerate(doc_spans): 259 | tokens = [] 260 | token_to_orig_map = {} 261 | token_is_max_context = {} 262 | segment_ids = [] 263 | 264 | # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) 265 | # Original TF implem also keep the classification token (set to 0) (not sure why...) 266 | p_mask = [] 267 | 268 | # CLS token at the beginning 269 | if not cls_token_at_end: 270 | tokens.append(cls_token) 271 | segment_ids.append(cls_token_segment_id) 272 | p_mask.append(0) 273 | cls_index = 0 274 | 275 | # Query 276 | for token in query_tokens: 277 | tokens.append(token) 278 | segment_ids.append(sequence_a_segment_id) 279 | p_mask.append(1) 280 | 281 | # SEP token 282 | tokens.append(sep_token) 283 | segment_ids.append(sequence_a_segment_id) 284 | p_mask.append(1) 285 | 286 | # Paragraph 287 | for i in range(doc_span.length): 288 | split_token_index = doc_span.start + i 289 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 290 | 291 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 292 | split_token_index) 293 | token_is_max_context[len(tokens)] = is_max_context 294 | tokens.append(all_doc_tokens[split_token_index]) 295 | segment_ids.append(sequence_b_segment_id) 296 | p_mask.append(0) 297 | paragraph_len = doc_span.length 298 | 299 | # SEP token 300 | tokens.append(sep_token) 301 | segment_ids.append(sequence_b_segment_id) 302 | p_mask.append(1) 303 | 304 | # CLS token at the end 305 | if cls_token_at_end: 306 | tokens.append(cls_token) 307 | segment_ids.append(cls_token_segment_id) 308 | p_mask.append(0) 309 | cls_index = len(tokens) - 1 # Index of classification token 310 | 311 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 312 | 313 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 314 | # tokens are attended to. 315 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 316 | 317 | # Zero-pad up to the sequence length. 318 | while len(input_ids) < max_seq_length: 319 | input_ids.append(pad_token) 320 | input_mask.append(0 if mask_padding_with_zero else 1) 321 | segment_ids.append(pad_token_segment_id) 322 | p_mask.append(1) 323 | 324 | assert len(input_ids) == max_seq_length 325 | assert len(input_mask) == max_seq_length 326 | assert len(segment_ids) == max_seq_length 327 | 328 | span_is_impossible = example.is_impossible 329 | start_position = None 330 | end_position = None 331 | if is_training and not span_is_impossible: 332 | # For training, if our document chunk does not contain an annotation 333 | # we throw it out, since there is nothing to predict. 
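# Note: a training chunk whose gold answer lies outside the current window is
# dropped entirely by the `continue` below (no feature is created for it).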
334 | doc_start = doc_span.start 335 | doc_end = doc_span.start + doc_span.length - 1 336 | out_of_span = False 337 | if not (tok_start_position >= doc_start and 338 | tok_end_position <= doc_end): 339 | out_of_span = True 340 | if out_of_span: 341 | start_position = 0 342 | end_position = 0 343 | span_is_impossible = True 344 | continue 345 | else: 346 | doc_offset = len(query_tokens) + 2 347 | start_position = tok_start_position - doc_start + doc_offset 348 | end_position = tok_end_position - doc_start + doc_offset 349 | 350 | if is_training and span_is_impossible: 351 | start_position = cls_index 352 | end_position = cls_index 353 | 354 | if example_index < 0: 355 | logger.info("*** Example ***") 356 | logger.info("unique_id: %s" % (unique_id)) 357 | logger.info("example_index: %s" % (example_index)) 358 | logger.info("doc_span_index: %s" % (doc_span_index)) 359 | logger.info("tokens: %s" % " ".join(tokens)) 360 | logger.info("token_to_orig_map: %s" % " ".join([ 361 | "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()])) 362 | logger.info("token_is_max_context: %s" % " ".join([ 363 | "%d:%s" % (x, y) for (x, y) in token_is_max_context.items() 364 | ])) 365 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 366 | logger.info( 367 | "input_mask: %s" % " ".join([str(x) for x in input_mask])) 368 | logger.info( 369 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 370 | if is_training and span_is_impossible: 371 | logger.info("impossible example") 372 | if is_training and not span_is_impossible: 373 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 374 | logger.info("start_position: %d" % (start_position)) 375 | logger.info("end_position: %d" % (end_position)) 376 | logger.info( 377 | "answer: %s" % (answer_text)) 378 | if example_index and example_index % 10000 == 0: 379 | logger.info("Converting to features: %d finished." % (example_index)) 380 | features.append( 381 | InputFeatures( 382 | unique_id=unique_id, 383 | example_index=example_index, 384 | doc_span_index=doc_span_index, 385 | tokens=tokens, 386 | token_to_orig_map=token_to_orig_map, 387 | token_is_max_context=token_is_max_context, 388 | input_ids=input_ids, 389 | input_mask=input_mask, 390 | segment_ids=segment_ids, 391 | cls_index=cls_index, 392 | p_mask=p_mask, 393 | paragraph_len=paragraph_len, 394 | start_position=start_position, 395 | end_position=end_position, 396 | is_impossible=span_is_impossible)) 397 | unique_id += 1 398 | 399 | return features 400 | 401 | 402 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 403 | orig_answer_text): 404 | """Returns tokenized answer spans that better match the annotated answer.""" 405 | 406 | # The SQuAD annotations are character based. We first project them to 407 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 408 | # often find a "better match". For example: 409 | # 410 | # Question: What year was John Smith born? 411 | # Context: The leader was John Smith (1895-1943). 412 | # Answer: 1895 413 | # 414 | # The original whitespace-tokenized answer will be "(1895-1943).". However 415 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 416 | # the exact answer, 1895. 417 | # 418 | # However, this is not always possible. Consider the following: 419 | # 420 | # Question: What country is the top exporter of electornics? 421 | # Context: The Japanese electronics industry is the lagest in the world. 
422 | # Answer: Japan 423 | # 424 | # In this case, the annotator chose "Japan" as a character sub-span of 425 | # the word "Japanese". Since our WordPiece tokenizer does not split 426 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 427 | # in SQuAD, but does happen. 428 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 429 | 430 | for new_start in range(input_start, input_end + 1): 431 | for new_end in range(input_end, new_start - 1, -1): 432 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 433 | if text_span == tok_answer_text: 434 | return (new_start, new_end) 435 | 436 | return (input_start, input_end) 437 | 438 | 439 | def _check_is_max_context(doc_spans, cur_span_index, position): 440 | """Check if this is the 'max context' doc span for the token.""" 441 | 442 | # Because of the sliding window approach taken to scoring documents, a single 443 | # token can appear in multiple documents. E.g. 444 | # Doc: the man went to the store and bought a gallon of milk 445 | # Span A: the man went to the 446 | # Span B: to the store and bought 447 | # Span C: and bought a gallon of 448 | # ... 449 | # 450 | # Now the word 'bought' will have two scores from spans B and C. We only 451 | # want to consider the score with "maximum context", which we define as 452 | # the *minimum* of its left and right context (the *sum* of left and 453 | # right context will always be the same, of course). 454 | # 455 | # In the example the maximum context for 'bought' would be span C since 456 | # it has 1 left context and 3 right context, while span B has 4 left context 457 | # and 0 right context. 458 | best_score = None 459 | best_span_index = None 460 | for (span_index, doc_span) in enumerate(doc_spans): 461 | end = doc_span.start + doc_span.length - 1 462 | if position < doc_span.start: 463 | continue 464 | if position > end: 465 | continue 466 | num_left_context = position - doc_span.start 467 | num_right_context = end - position 468 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 469 | if best_score is None or score > best_score: 470 | best_score = score 471 | best_span_index = span_index 472 | 473 | return cur_span_index == best_span_index 474 | 475 | 476 | RawResult = collections.namedtuple("RawResult", 477 | ["unique_id", "start_logits", "end_logits"]) 478 | 479 | def write_predictions(all_examples, all_features, all_results, n_best_size, 480 | max_answer_length, do_lower_case, output_prediction_file, 481 | output_nbest_file, output_null_log_odds_file, verbose_logging, 482 | version_2_with_negative, null_score_diff_threshold): 483 | """Write final predictions to the json file and log-odds of null if needed.""" 484 | logger.info("Writing predictions to: %s" % (output_prediction_file)) 485 | logger.info("Writing nbest to: %s" % (output_nbest_file)) 486 | 487 | example_index_to_features = collections.defaultdict(list) 488 | for feature in all_features: 489 | example_index_to_features[feature.example_index].append(feature) 490 | 491 | unique_id_to_result = {} 492 | for result in all_results: 493 | unique_id_to_result[result.unique_id] = result 494 | 495 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 496 | "PrelimPrediction", 497 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 498 | 499 | all_predictions = collections.OrderedDict() 500 | all_nbest_json = collections.OrderedDict() 501 | scores_diff_json = collections.OrderedDict() 502 | 503 | for (example_index, 
example) in enumerate(all_examples): 504 | features = example_index_to_features[example_index] 505 | 506 | prelim_predictions = [] 507 | # keep track of the minimum score of null start+end of position 0 508 | score_null = 1000000 # large and positive 509 | min_null_feature_index = 0 # the paragraph slice with min null score 510 | null_start_logit = 0 # the start logit at the slice with min null score 511 | null_end_logit = 0 # the end logit at the slice with min null score 512 | for (feature_index, feature) in enumerate(features): 513 | result = unique_id_to_result[feature.unique_id] 514 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 515 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 516 | # if we could have irrelevant answers, get the min score of irrelevant 517 | if version_2_with_negative: 518 | feature_null_score = result.start_logits[0] + result.end_logits[0] 519 | if feature_null_score < score_null: 520 | score_null = feature_null_score 521 | min_null_feature_index = feature_index 522 | null_start_logit = result.start_logits[0] 523 | null_end_logit = result.end_logits[0] 524 | for start_index in start_indexes: 525 | for end_index in end_indexes: 526 | # We could hypothetically create invalid predictions, e.g., predict 527 | # that the start of the span is in the question. We throw out all 528 | # invalid predictions. 529 | if start_index >= len(feature.tokens): 530 | continue 531 | if end_index >= len(feature.tokens): 532 | continue 533 | if start_index not in feature.token_to_orig_map: 534 | continue 535 | if end_index not in feature.token_to_orig_map: 536 | continue 537 | if not feature.token_is_max_context.get(start_index, False): 538 | continue 539 | if end_index < start_index: 540 | continue 541 | length = end_index - start_index + 1 542 | if length > max_answer_length: 543 | continue 544 | prelim_predictions.append( 545 | _PrelimPrediction( 546 | feature_index=feature_index, 547 | start_index=start_index, 548 | end_index=end_index, 549 | start_logit=result.start_logits[start_index], 550 | end_logit=result.end_logits[end_index])) 551 | if version_2_with_negative: 552 | prelim_predictions.append( 553 | _PrelimPrediction( 554 | feature_index=min_null_feature_index, 555 | start_index=0, 556 | end_index=0, 557 | start_logit=null_start_logit, 558 | end_logit=null_end_logit)) 559 | prelim_predictions = sorted( 560 | prelim_predictions, 561 | key=lambda x: (x.start_logit + x.end_logit), 562 | reverse=True) 563 | 564 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 565 | "NbestPrediction", ["text", "start_logit", "end_logit", "orig_doc_start"]) 566 | 567 | seen_predictions = {} 568 | nbest = [] 569 | for pred in prelim_predictions: 570 | if len(nbest) >= n_best_size: 571 | break 572 | feature = features[pred.feature_index] 573 | if pred.start_index > 0: # this is a non-null prediction 574 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 575 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 576 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 577 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 578 | tok_text = " ".join(tok_tokens) 579 | 580 | # De-tokenize WordPieces that have been split off. 
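# e.g. the WordPiece tokens ["un", "##answer", "##able"] were joined above into
# "un ##answer ##able"; the replacements below turn this back into "unanswerable".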
581 | tok_text = tok_text.replace(" ##", "") 582 | tok_text = tok_text.replace("##", "") 583 | 584 | # Clean whitespace 585 | tok_text = tok_text.strip() 586 | tok_text = " ".join(tok_text.split()) 587 | orig_text = " ".join(orig_tokens) 588 | 589 | final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) 590 | #ans_pos_str = "