├── .gitignore
├── .gitmodules
├── BiDAFpp
│   ├── README.md
│   ├── download.sh
│   ├── hotpot_evaluate_v1.py
│   ├── main.py
│   ├── model.py
│   ├── prepro.py
│   ├── run.py
│   ├── sp_model.py
│   └── util.py
├── LICENSE
├── README.md
├── fig
│   ├── golden-retriever-icon.png
│   └── golden-retriever.png
├── prepared_data
│   └── README.md
├── requirements.txt
├── scripts
│   ├── __init__.py
│   ├── build_qa_data.py
│   ├── build_single_hop_qa_data.py
│   ├── download_corenlp.sh
│   ├── download_elastic_6.7.sh
│   ├── download_golden_retriever_models.sh
│   ├── download_hotpotqa.sh
│   ├── download_prepared_data.sh
│   ├── download_processed_wiki.sh
│   ├── e_to_e_helpers
│   │   ├── merge_hops_results.py
│   │   ├── merge_with_es.py
│   │   └── squadify_questions.py
│   ├── eval_drqa.py
│   ├── eval_end_to_end.sh
│   ├── eval_hits.py
│   ├── eval_model2.py
│   ├── eval_model2_emf1.py
│   ├── eval_single_hop.sh
│   ├── format_result.py
│   ├── gen_hop1.py
│   ├── gen_hop2.py
│   ├── index_processed_wiki.py
│   ├── launch_elasticsearch_6.7.sh
│   ├── offline_ir_eval.py
│   ├── preprocess_hop1.py
│   ├── preprocess_hop2.py
│   ├── query_generator_study.py
│   └── query_labels_to_pred.py
├── search
│   ├── jvm.options
│   └── search.py
├── setup.sh
└── utils
    ├── constant.py
    ├── corenlp.py
    ├── general.py
    ├── io.py
    └── lcs.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # golden retriever specific
2 | data/
3 | elasticsearch-6.7.0/
4 | 
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | 
10 | # C extensions
11 | *.so
12 | 
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 | 
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 | 
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 | 
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | 
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 | 
67 | # Scrapy stuff:
68 | .scrapy
69 | 
70 | # Sphinx documentation
71 | docs/_build/
72 | 
73 | # PyBuilder
74 | target/
75 | 
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 | 
79 | # pyenv
80 | .python-version
81 | 
82 | # celery beat schedule file
83 | celerybeat-schedule
84 | 
85 | # SageMath parsed files
86 | *.sage.py
87 | 
88 | # Environments
89 | .env
90 | .venv
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 | 
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 | 
101 | # Rope project settings
102 | .ropeproject
103 | 
104 | # mkdocs documentation
105 | /site
106 | 
107 | # mypy
108 | .mypy_cache/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "DrQA"]
2 | path = DrQA
3 | url = https://github.com/facebookresearch/DrQA.git
--------------------------------------------------------------------------------
/BiDAFpp/README.md:
--------------------------------------------------------------------------------
1 | # BiDAF++ variant used for GoldEn Retriever
2 | 
3 | This repository contains the QA component of GoldEn Retriever, which is adapted from the [HotpotQA baseline model](https://github.com/hotpotqa/hotpot).
4 | 
5 | ## Requirements
6 | 
7 | Python 3, pytorch 1.1.0, stanfordnlp
8 | 
9 | To install pytorch 1.1.0, follow the instructions at https://pytorch.org/get-started/previous-versions/ . For example, with
10 | CUDA 9 and conda you can run
11 | ```
12 | conda install pytorch=1.1.0 cuda90 -c pytorch
13 | ```
14 | 
15 | To install stanfordnlp, run
16 | ```
17 | conda install stanfordnlp
18 | ```
19 | 
20 | ## Data Download and Preprocessing
21 | 
22 | Run the script to download the data, including HotpotQA data and GloVe embeddings.
23 | ```
24 | ./download.sh
25 | ```
26 | Please follow the instructions [here](https://stanfordnlp.github.io/stanfordnlp/corenlp_client.html#setup) to set up CoreNLP for tokenization.
27 | 
28 | There are four HotpotQA files:
29 | - Training set http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json
30 | - Dev set in the distractor setting http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json
31 | - Dev set in the fullwiki setting http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json This is just `hotpot_dev_distractor_v1.json` without the gold paragraphs, but instead with the top 10 paragraphs obtained using our
32 | retrieval system. If you want to use your own IR system (which is encouraged!), you can replace the paragraphs in this JSON file
33 | with your own retrieval results. Please note that the gold paragraphs might or might not be in this file because our IR system
34 | is pretty basic.
35 | - Test set in the fullwiki setting http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_test_fullwiki_v1.json Because in the fullwiki setting you only need to submit your predictions to our evaluation server without the code, we publish the test set without the answers and supporting facts. The context in the file consists of paragraphs obtained using our retrieval system, which might or might not contain the gold paragraphs. Again, you are encouraged to use your own IR system in this setting --- simply replace the paragraphs in this JSON file with your own retrieval results.
36 | 
37 | 
38 | ## JSON Format
39 | 
40 | The top-level structure of each JSON file is a list, where each entry represents a question-answer data point. Each data point is
41 | a dict with the following keys:
42 | - `_id`: a unique id for this question-answer data point. This is useful for evaluation.
43 | - `question`: a string.
44 | - `answer`: a string. The test set does not have this key.
45 | - `supporting_facts`: a list. Each entry in the list is a list with two elements `[title, sent_id]`, where `title` denotes the title of the
46 | paragraph, and `sent_id` denotes the supporting fact's id (0-based) in this paragraph. The test set does not have this key.
47 | - `context`: a list. Each entry is a paragraph, which is represented as a list with two elements `[title, sentences]`, where `sentences` is a list
48 | of strings.
49 | 
50 | There are other keys that are not used in our code, but might be used for other purposes (note that these keys are not present in the test sets, and your model should not rely on them for making predictions on the test sets):
51 | - `type`: either `comparison` or `bridge`, indicating the question type. (See our paper for more details.)
52 | - `level`: one of `easy`, `medium`, and `hard`. (See our paper for more details.)
53 | 
54 | ## Preprocessing
55 | 
56 | Preprocess the training and dev sets in the distractor setting:
57 | ```
58 | python main.py --mode prepro --data_file hotpot_train_v1.1.json --para_limit 2250 --data_split train
59 | python main.py --mode prepro --data_file hotpot_dev_distractor_v1.json --para_limit 2250 --data_split dev
60 | ```
61 | 
62 | Preprocess the dev set in the fullwiki setting:
63 | ```
64 | python main.py --mode prepro --data_file hotpot_dev_fullwiki_v1.json --data_split dev --fullwiki --para_limit 2250
65 | ```
66 | 
67 | Note that the training set has to be preprocessed before the dev sets because some vocabulary and embedding files are produced
68 | when the training set is processed.
69 | 
70 | ## Training
71 | 
72 | Train a model:
73 | ```
74 | CUDA_VISIBLE_DEVICES=0 python main.py --mode train --para_limit 2250 --batch_size 24 --init_lr 0.1 --keep_prob 1.0 \
75 | --sp_lambda 1.0
76 | ```
77 | 
78 | Our implementation supports running on multiple GPUs. Remove the `CUDA_VISIBLE_DEVICES` variable to run on all available GPUs:
79 | ```
80 | python main.py --mode train --para_limit 2250 --batch_size 24 --init_lr 0.1 --keep_prob 1.0 --sp_lambda 1.0
81 | ```
82 | 
83 | You should see the performance reach over 58 F1 on the dev set. Record the run name (something like `HOTPOT-20180924-160521`),
84 | which will be used during evaluation.
85 | 
86 | ## Local Evaluation
87 | 
88 | First, make predictions and save them to a file (set `--save` to your own run name and `--prediction_file` to your desired output file).
89 | ```
90 | CUDA_VISIBLE_DEVICES=0 python main.py --mode test --data_split dev --para_limit 2250 --batch_size 24 --init_lr 0.1 \
91 | --keep_prob 1.0 --sp_lambda 1.0 --save HOTPOT-20180924-160521 --prediction_file dev_distractor_pred.json
92 | ```
93 | 
94 | Then, call the evaluation script:
95 | ```
96 | python hotpot_evaluate_v1.py dev_distractor_pred.json hotpot_dev_distractor_v1.json
97 | ```
98 | 
99 | The same procedure can be repeated to evaluate the dev set in the fullwiki setting:
100 | ```
101 | CUDA_VISIBLE_DEVICES=0 python main.py --mode test --data_split dev --para_limit 2250 --batch_size 24 --init_lr 0.1 \
102 | --keep_prob 1.0 --sp_lambda 1.0 --save HOTPOT-20180924-160521 --prediction_file dev_fullwiki_pred.json --fullwiki
103 | python hotpot_evaluate_v1.py dev_fullwiki_pred.json hotpot_dev_fullwiki_v1.json
104 | ```
105 | 
106 | ## Prediction File Format
107 | 
108 | The prediction files `dev_distractor_pred.json` and `dev_fullwiki_pred.json` should be JSON files with the following keys (a minimal example is sketched at the end of this README):
109 | - `answer`: a dict. Each key of the dict is a QA pair id, corresponding to the field `_id` in the data JSON files. Each value of the dict is a string representing the predicted answer.
110 | - `sp`: a dict. Each key of the dict is a QA pair id, corresponding to the field `_id` in the data JSON files. Each value of the dict is a list representing the predicted supporting facts. Each entry of the list is a list with two elements `[title, sent_id]`, where `title` denotes the title of the paragraph, and `sent_id` denotes the supporting fact's id (0-based) in this paragraph.
111 | 
112 | ## Model Submission and Test Set Evaluation
113 | 
114 | We use CodaLab for test set evaluation. In the distractor setting, you must submit your code and provide a Docker environment; your code will then be run on the test set. In the fullwiki setting, you only need to submit your prediction file. See https://worksheets.codalab.org/worksheets/0xa8718c1a5e9e470e84a7d5fb3ab1dde2/ for detailed instructions.
115 | 
116 | ## License
117 | The HotpotQA dataset is distributed under the [CC BY-SA 4.0](http://creativecommons.org/licenses/by-sa/4.0/legalcode) license.
118 | The code is distributed under the Apache 2.0 license.
119 | 
120 | ## References
121 | 
122 | The preprocessing part and the data loader are adapted from https://github.com/HKUST-KnowComp/R-Net . The evaluation script is
123 | adapted from https://rajpurkar.github.io/SQuAD-explorer/ .
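
## Example Prediction File

As a concrete illustration of the prediction file format and local evaluation above, the sketch below writes a dummy prediction file and scores it with `hotpot_evaluate_v1.py`. The question id, answer string, and supporting-fact entries are hypothetical placeholders; in a real prediction the keys must match the `_id` fields of the gold file, and `main.py --mode test` normally produces this file for you.

```
import json

# Minimal sketch of the prediction file format described above.
# "example_question_id", the answer string, and the supporting facts are
# hypothetical placeholders; real keys must match the `_id` fields in the
# gold file, ideally with one entry per gold question.
prediction = {
    "answer": {
        "example_question_id": "yes",
    },
    "sp": {
        "example_question_id": [
            ["Some Paragraph Title", 0],    # [title, 0-based sentence id]
            ["Another Paragraph Title", 2],
        ],
    },
}

with open("dev_distractor_pred.json", "w") as f:
    json.dump(prediction, f)

# Score the file; equivalent to running
#   python hotpot_evaluate_v1.py dev_distractor_pred.json hotpot_dev_distractor_v1.json
# from the BiDAFpp directory. eval() prints a dict with answer, sp, and joint
# EM/F1/precision/recall.
from hotpot_evaluate_v1 import eval as hotpot_eval
hotpot_eval("dev_distractor_pred.json", "hotpot_dev_distractor_v1.json")
```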
124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /BiDAFpp/download.sh: -------------------------------------------------------------------------------- 1 | 2 | # Download Hotpot Data 3 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json 4 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json 5 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json 6 | 7 | # Download GloVe 8 | GLOVE_DIR=./ 9 | mkdir -p $GLOVE_DIR 10 | wget http://nlp.stanford.edu/data/glove.840B.300d.zip -O $GLOVE_DIR/glove.840B.300d.zip 11 | unzip $GLOVE_DIR/glove.840B.300d.zip -d $GLOVE_DIR 12 | -------------------------------------------------------------------------------- /BiDAFpp/hotpot_evaluate_v1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ujson as json 3 | import re 4 | import string 5 | from collections import Counter 6 | import pickle 7 | 8 | def normalize_answer(s): 9 | 10 | def remove_articles(text): 11 | return re.sub(r'\b(a|an|the)\b', ' ', text) 12 | 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | 26 | def f1_score(prediction, ground_truth): 27 | normalized_prediction = normalize_answer(prediction) 28 | normalized_ground_truth = normalize_answer(ground_truth) 29 | 30 | ZERO_METRIC = (0, 0, 0) 31 | 32 | #if normalized_prediction != normalized_ground_truth: 33 | # print("|{}| |{}|".format(normalized_prediction, normalized_ground_truth)) 34 | 35 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 36 | return ZERO_METRIC 37 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 38 | return ZERO_METRIC 39 | 40 | prediction_tokens = normalized_prediction.split() 41 | ground_truth_tokens = normalized_ground_truth.split() 42 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 43 | num_same = sum(common.values()) 44 | if num_same == 0: 45 | return ZERO_METRIC 46 | precision = 1.0 * num_same / len(prediction_tokens) 47 | recall = 1.0 * num_same / len(ground_truth_tokens) 48 | f1 = (2 * precision * recall) / (precision + recall) 49 | return f1, precision, recall 50 | 51 | 52 | def exact_match_score(prediction, ground_truth): 53 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 54 | 55 | def update_answer(metrics, prediction, gold): 56 | em = exact_match_score(prediction, gold) 57 | f1, prec, recall = f1_score(prediction, gold) 58 | metrics['em'] += float(em) 59 | metrics['f1'] += f1 60 | metrics['prec'] += prec 61 | metrics['recall'] += recall 62 | return em, prec, recall 63 | 64 | def update_sp(metrics, prediction, gold): 65 | cur_sp_pred = set(map(tuple, prediction)) 66 | gold_sp_pred = set(map(tuple, gold)) 67 | tp, fp, fn = 0, 0, 0 68 | for e in cur_sp_pred: 69 | if e in gold_sp_pred: 70 | tp += 1 71 | else: 72 | fp += 1 73 | for e in gold_sp_pred: 74 | if e not in cur_sp_pred: 75 | fn += 1 76 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 77 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 78 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 79 | em = 1.0 if fp 
+ fn == 0 else 0.0 80 | metrics['sp_em'] += em 81 | metrics['sp_f1'] += f1 82 | metrics['sp_prec'] += prec 83 | metrics['sp_recall'] += recall 84 | return em, prec, recall 85 | 86 | def eval(prediction_file, gold_file): 87 | with open(prediction_file) as f: 88 | prediction = json.load(f) 89 | with open(gold_file) as f: 90 | gold = json.load(f) 91 | 92 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 93 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0, 94 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0} 95 | for dp in gold: 96 | cur_id = dp['_id'] 97 | can_eval_joint = True 98 | if cur_id not in prediction['answer']: 99 | #print('missing answer {}'.format(cur_id)) 100 | can_eval_joint = False 101 | else: 102 | em, prec, recall = update_answer( 103 | metrics, prediction['answer'][cur_id], dp['answer']) 104 | if cur_id not in prediction['sp']: 105 | #print('missing sp fact {}'.format(cur_id)) 106 | can_eval_joint = False 107 | else: 108 | sp_em, sp_prec, sp_recall = update_sp( 109 | metrics, prediction['sp'][cur_id], dp['supporting_facts']) 110 | 111 | if can_eval_joint: 112 | joint_prec = prec * sp_prec 113 | joint_recall = recall * sp_recall 114 | if joint_prec + joint_recall > 0: 115 | joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall) 116 | else: 117 | joint_f1 = 0. 118 | joint_em = em * sp_em 119 | 120 | metrics['joint_em'] += joint_em 121 | metrics['joint_f1'] += joint_f1 122 | metrics['joint_prec'] += joint_prec 123 | metrics['joint_recall'] += joint_recall 124 | 125 | N = len(gold) 126 | for k in metrics.keys(): 127 | metrics[k] /= N 128 | 129 | print(metrics) 130 | 131 | if __name__ == '__main__': 132 | eval(sys.argv[1], sys.argv[2]) 133 | 134 | -------------------------------------------------------------------------------- /BiDAFpp/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from prepro import prepro 3 | from run import train, test 4 | import torch 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | 9 | glove_word_file = "glove.840B.300d.txt" 10 | 11 | hops = True 12 | 13 | word_emb_file = "word_emb{}.json".format('_hops' if hops else '') 14 | char_emb_file = "char_emb{}.json".format('_hops' if hops else '') 15 | train_eval = "train_eval{}.json".format('_hops' if hops else '') 16 | dev_eval = "dev_eval{}.json".format('_hops' if hops else '') 17 | test_eval = "test_eval{}.json".format('_hops' if hops else '') 18 | word2idx_file = "word2idx{}.json".format('_hops' if hops else '') 19 | char2idx_file = "char2idx{}.json".format('_hops' if hops else '') 20 | idx2word_file = 'idx2word{}.json'.format('_hops' if hops else '') 21 | idx2char_file = 'idx2char{}.json'.format('_hops' if hops else '') 22 | train_record_file = 'train_record{}.pkl'.format('_hops' if hops else '') 23 | dev_record_file = 'dev_record{}.pkl'.format('_hops' if hops else '') 24 | test_record_file = 'test_record{}.pkl'.format('_hops' if hops else '') 25 | 26 | 27 | parser.add_argument('--mode', type=str, default='train') 28 | parser.add_argument('--data_file', type=str) 29 | parser.add_argument('--glove_word_file', type=str, default=glove_word_file) 30 | parser.add_argument('--save', type=str, default='HOTPOT') 31 | 32 | parser.add_argument('--word_emb_file', type=str, default=word_emb_file) 33 | parser.add_argument('--char_emb_file', type=str, default=char_emb_file) 34 | parser.add_argument('--train_eval_file', type=str, default=train_eval) 35 | parser.add_argument('--dev_eval_file', type=str, 
default=dev_eval) 36 | parser.add_argument('--test_eval_file', type=str, default=test_eval) 37 | parser.add_argument('--word2idx_file', type=str, default=word2idx_file) 38 | parser.add_argument('--char2idx_file', type=str, default=char2idx_file) 39 | parser.add_argument('--idx2word_file', type=str, default=idx2word_file) 40 | parser.add_argument('--idx2char_file', type=str, default=idx2char_file) 41 | 42 | parser.add_argument('--train_record_file', type=str, default=train_record_file) 43 | parser.add_argument('--dev_record_file', type=str, default=dev_record_file) 44 | parser.add_argument('--test_record_file', type=str, default=test_record_file) 45 | 46 | parser.add_argument('--glove_char_size', type=int, default=94) 47 | parser.add_argument('--glove_word_size', type=int, default=int(2.2e6)) 48 | parser.add_argument('--glove_dim', type=int, default=300) 49 | parser.add_argument('--char_dim', type=int, default=8) 50 | 51 | parser.add_argument('--para_limit', type=int, default=1000) 52 | parser.add_argument('--ques_limit', type=int, default=80) 53 | parser.add_argument('--sent_limit', type=int, default=100) 54 | parser.add_argument('--char_limit', type=int, default=16) 55 | 56 | parser.add_argument('--batch_size', type=int, default=64) 57 | parser.add_argument('--checkpoint', type=int, default=1000) 58 | parser.add_argument('--period', type=int, default=100) 59 | parser.add_argument('--init_lr', type=float, default=0.5) 60 | parser.add_argument('--max_grad_norm', type=float, default=0.0) 61 | parser.add_argument('--keep_prob', type=float, default=0.8) 62 | parser.add_argument('--hidden', type=int, default=80) 63 | parser.add_argument('--char_hidden', type=int, default=100) 64 | parser.add_argument('--patience', type=int, default=1) 65 | parser.add_argument('--seed', type=int, default=13) 66 | 67 | parser.add_argument('--sp_lambda', type=float, default=0.0) 68 | 69 | parser.add_argument('--data_split', type=str, default='train') 70 | parser.add_argument('--fullwiki', action='store_true') 71 | parser.add_argument('--prediction_file', type=str) 72 | parser.add_argument('--sp_threshold', type=float, default=0.3) 73 | 74 | parser.add_argument('--no-cuda', dest='cuda', default=torch.cuda.is_available(), action='store_false') 75 | 76 | config = parser.parse_args() 77 | 78 | def _concat(filename): 79 | if config.fullwiki: 80 | return 'fullwiki.{}'.format(filename) 81 | return filename 82 | # config.train_record_file = _concat(config.train_record_file) 83 | config.dev_record_file = _concat(config.dev_record_file) 84 | config.test_record_file = _concat(config.test_record_file) 85 | # config.train_eval_file = _concat(config.train_eval_file) 86 | config.dev_eval_file = _concat(config.dev_eval_file) 87 | config.test_eval_file = _concat(config.test_eval_file) 88 | 89 | if config.mode == 'train': 90 | train(config) 91 | elif config.mode == 'prepro': 92 | prepro(config) 93 | elif config.mode == 'test': 94 | test(config) 95 | elif config.mode == 'count': 96 | cnt_len(config) 97 | -------------------------------------------------------------------------------- /BiDAFpp/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.autograd import Variable 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import numpy as np 7 | import math 8 | from torch.nn import init 9 | from torch.nn.utils import rnn 10 | 11 | class Model(nn.Module): 12 | def __init__(self, config, word_mat, char_mat): 13 | super().__init__() 14 | 
self.config = config 15 | self.word_dim = config.glove_dim 16 | self.word_emb = nn.Embedding(len(word_mat), len(word_mat[0]), padding_idx=0) 17 | self.word_emb.weight.data.copy_(torch.from_numpy(word_mat)) 18 | self.word_emb.weight.requires_grad = False 19 | self.char_emb = nn.Embedding(len(char_mat), len(char_mat[0]), padding_idx=0) 20 | self.char_emb.weight.data.copy_(torch.from_numpy(char_mat)) 21 | 22 | self.char_cnn = nn.Conv1d(config.char_dim, config.char_hidden, 5) 23 | self.char_hidden = config.char_hidden 24 | self.hidden = config.hidden 25 | 26 | self.rnn = EncoderRNN(config.char_hidden+self.word_dim, config.hidden, 1, True, True, 1-config.keep_prob, False) 27 | 28 | self.qc_att = BiAttention(config.hidden*2, 1-config.keep_prob) 29 | self.linear_1 = nn.Sequential( 30 | nn.Linear(config.hidden*8, config.hidden), 31 | nn.ReLU() 32 | ) 33 | 34 | self.rnn_2 = EncoderRNN(config.hidden, config.hidden, 1, False, True, 1-config.keep_prob, False) 35 | self.self_att = BiAttention(config.hidden*2, 1-config.keep_prob) 36 | self.linear_2 = nn.Sequential( 37 | nn.Linear(config.hidden*8, config.hidden), 38 | nn.ReLU() 39 | ) 40 | 41 | self.rnn_sp = EncoderRNN(config.hidden, config.hidden, 1, False, True, 1-config.keep_prob, False) 42 | self.linear_sp = nn.Linear(config.hidden*2, 1) 43 | 44 | self.rnn_start = EncoderRNN(config.hidden+1, config.hidden, 1, False, True, 1-config.keep_prob, False) 45 | self.linear_start = nn.Linear(config.hidden*2, 1) 46 | 47 | self.rnn_end = EncoderRNN(config.hidden*3+1, config.hidden, 1, False, True, 1-config.keep_prob, False) 48 | self.linear_end = nn.Linear(config.hidden*2, 1) 49 | 50 | self.rnn_type = EncoderRNN(config.hidden*3+1, config.hidden, 1, False, True, 1-config.keep_prob, False) 51 | self.linear_type = nn.Linear(config.hidden*2, 3) 52 | 53 | self.cache_S = 0 54 | 55 | def get_output_mask(self, outer): 56 | S = outer.size(1) 57 | if S <= self.cache_S: 58 | return Variable(self.cache_mask[:S, :S], requires_grad=False) 59 | self.cache_S = S 60 | np_mask = np.tril(np.triu(np.ones((S, S)), 0), 15) 61 | self.cache_mask = outer.data.new(S, S).copy_(torch.from_numpy(np_mask)) 62 | return Variable(self.cache_mask, requires_grad=False) 63 | 64 | def forward(self, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, return_yp=False): 65 | para_size, ques_size, char_size, bsz = context_idxs.size(1), ques_idxs.size(1), context_char_idxs.size(2), context_idxs.size(0) 66 | 67 | context_mask = (context_idxs > 0).float() 68 | ques_mask = (ques_idxs > 0).float() 69 | 70 | context_ch = self.char_emb(context_char_idxs.contiguous().view(-1, char_size)).view(bsz * para_size, char_size, -1) 71 | ques_ch = self.char_emb(ques_char_idxs.contiguous().view(-1, char_size)).view(bsz * ques_size, char_size, -1) 72 | 73 | context_ch = self.char_cnn(context_ch.permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(bsz, para_size, -1) 74 | ques_ch = self.char_cnn(ques_ch.permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(bsz, ques_size, -1) 75 | 76 | context_word = self.word_emb(context_idxs) 77 | ques_word = self.word_emb(ques_idxs) 78 | 79 | context_output = torch.cat([context_word, context_ch], dim=2) 80 | ques_output = torch.cat([ques_word, ques_ch], dim=2) 81 | 82 | context_output = self.rnn(context_output, context_lens) 83 | ques_output = self.rnn(ques_output) 84 | 85 | output = self.qc_att(context_output, ques_output, ques_mask) 86 | output = self.linear_1(output) 87 | 88 | output_t = self.rnn_2(output, context_lens) 89 | 
output_t = self.self_att(output_t, output_t, context_mask) 90 | output_t = self.linear_2(output_t) 91 | 92 | output = output + output_t 93 | 94 | sp_output = self.rnn_sp(output, context_lens) 95 | 96 | start_output = torch.matmul(start_mapping.permute(0, 2, 1).contiguous(), sp_output[:,:,self.hidden:]) 97 | end_output = torch.matmul(end_mapping.permute(0, 2, 1).contiguous(), sp_output[:,:,:self.hidden]) 98 | sp_output = torch.cat([start_output, end_output], dim=-1) 99 | sp_output = self.linear_sp(sp_output) 100 | sp_output_aux = Variable(sp_output.data.new(sp_output.size(0), sp_output.size(1), 1).zero_()) 101 | predict_support = torch.cat([sp_output_aux, sp_output], dim=-1).contiguous() 102 | 103 | sp_output = torch.matmul(all_mapping, sp_output) 104 | output = torch.cat([output, sp_output], dim=-1) 105 | 106 | output_start = self.rnn_start(output, context_lens) 107 | logit1 = self.linear_start(output_start).squeeze(2) - 1e30 * (1 - context_mask) 108 | output_end = torch.cat([output, output_start], dim=2) 109 | output_end = self.rnn_end(output_end, context_lens) 110 | logit2 = self.linear_end(output_end).squeeze(2) - 1e30 * (1 - context_mask) 111 | 112 | output_type = torch.cat([output, output_end], dim=2) 113 | output_type = torch.max(self.rnn_type(output_type, context_lens), 1)[0] 114 | predict_type = self.linear_type(output_type) 115 | 116 | if not return_yp: return logit1, logit2, predict_type, predict_support 117 | 118 | outer = logit1[:,:,None] + logit2[:,None] 119 | outer_mask = self.get_output_mask(outer) 120 | outer = outer - 1e30 * (1 - outer_mask[None].expand_as(outer)) 121 | yp1 = outer.max(dim=2)[0].max(dim=1)[1] 122 | yp2 = outer.max(dim=1)[0].max(dim=1)[1] 123 | return logit1, logit2, predict_type, predict_support, yp1, yp2 124 | 125 | class LockedDropout(nn.Module): 126 | def __init__(self, dropout): 127 | super().__init__() 128 | self.dropout = dropout 129 | 130 | def forward(self, x): 131 | dropout = self.dropout 132 | if not self.training: 133 | return x 134 | m = x.data.new(x.size(0), 1, x.size(2)).bernoulli_(1 - dropout) 135 | mask = Variable(m.div_(1 - dropout), requires_grad=False) 136 | mask = mask.expand_as(x) 137 | return mask * x 138 | 139 | class EncoderRNN(nn.Module): 140 | def __init__(self, input_size, num_units, nlayers, concat, bidir, dropout, return_last): 141 | super().__init__() 142 | self.rnns = [] 143 | for i in range(nlayers): 144 | if i == 0: 145 | input_size_ = input_size 146 | output_size_ = num_units 147 | else: 148 | input_size_ = num_units if not bidir else num_units * 2 149 | output_size_ = num_units 150 | self.rnns.append(nn.GRU(input_size_, output_size_, 1, bidirectional=bidir, batch_first=True)) 151 | self.rnns = nn.ModuleList(self.rnns) 152 | self.init_hidden = nn.ParameterList([nn.Parameter(torch.Tensor(2 if bidir else 1, 1, num_units).zero_()) for _ in range(nlayers)]) 153 | self.dropout = LockedDropout(dropout) 154 | self.concat = concat 155 | self.nlayers = nlayers 156 | self.return_last = return_last 157 | 158 | # self.reset_parameters() 159 | 160 | def reset_parameters(self): 161 | for rnn in self.rnns: 162 | for name, p in rnn.named_parameters(): 163 | if 'weight' in name: 164 | p.data.normal_(std=0.1) 165 | else: 166 | p.data.zero_() 167 | 168 | def get_init(self, bsz, i): 169 | return self.init_hidden[i].expand(-1, bsz, -1).contiguous() 170 | 171 | def forward(self, input, input_lengths=None): 172 | bsz, slen = input.size(0), input.size(1) 173 | output = input 174 | outputs = [] 175 | if input_lengths is not None: 176 | lens = 
input_lengths.data.cpu().numpy() 177 | for i in range(self.nlayers): 178 | hidden = self.get_init(bsz, i) 179 | output = self.dropout(output) 180 | if input_lengths is not None: 181 | output = rnn.pack_padded_sequence(output, lens, batch_first=True) 182 | output, hidden = self.rnns[i](output, hidden) 183 | if input_lengths is not None: 184 | output, _ = rnn.pad_packed_sequence(output, batch_first=True) 185 | if output.size(1) < slen: # used for parallel 186 | padding = Variable(output.data.new(1, 1, 1).zero_()) 187 | output = torch.cat([output, padding.expand(output.size(0), slen-output.size(1), output.size(2))], dim=1) 188 | if self.return_last: 189 | outputs.append(hidden.permute(1, 0, 2).contiguous().view(bsz, -1)) 190 | else: 191 | outputs.append(output) 192 | if self.concat: 193 | return torch.cat(outputs, dim=2) 194 | return outputs[-1] 195 | 196 | class BiAttention(nn.Module): 197 | def __init__(self, input_size, dropout): 198 | super().__init__() 199 | self.dropout = LockedDropout(dropout) 200 | self.input_linear = nn.Linear(input_size, 1, bias=False) 201 | self.memory_linear = nn.Linear(input_size, 1, bias=False) 202 | 203 | self.dot_scale = nn.Parameter(torch.Tensor(input_size).uniform_(1.0 / (input_size ** 0.5))) 204 | 205 | def forward(self, input, memory, mask): 206 | bsz, input_len, memory_len = input.size(0), input.size(1), memory.size(1) 207 | 208 | input = self.dropout(input) 209 | memory = self.dropout(memory) 210 | 211 | input_dot = self.input_linear(input) 212 | memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len) 213 | cross_dot = torch.bmm(input * self.dot_scale, memory.permute(0, 2, 1).contiguous()) 214 | att = input_dot + memory_dot + cross_dot 215 | att = att - 1e30 * (1 - mask[:,None]) 216 | 217 | weight_one = F.softmax(att, dim=-1) 218 | output_one = torch.bmm(weight_one, memory) 219 | weight_two = F.softmax(att.max(dim=-1)[0], dim=-1).view(bsz, 1, input_len) 220 | output_two = torch.bmm(weight_two, input) 221 | 222 | return torch.cat([input, output_one, input*output_one, output_two*output_one], dim=-1) 223 | 224 | class GateLayer(nn.Module): 225 | def __init__(self, d_input, d_output): 226 | super(GateLayer, self).__init__() 227 | self.linear = nn.Linear(d_input, d_output) 228 | self.gate = nn.Linear(d_input, d_output) 229 | self.sigmoid = nn.Sigmoid() 230 | 231 | def forward(self, input): 232 | return self.linear(input) * self.sigmoid(self.gate(input)) 233 | -------------------------------------------------------------------------------- /BiDAFpp/prepro.py: -------------------------------------------------------------------------------- 1 | import random 2 | from tqdm import tqdm 3 | import spacy 4 | import ujson as json 5 | from collections import Counter 6 | import numpy as np 7 | import os.path 8 | import argparse 9 | import torch 10 | # import pickle 11 | import torch 12 | import os 13 | from joblib import Parallel, delayed 14 | from util import NUM_OF_PARAGRAPHS, MAX_PARAGRAPH_LEN 15 | 16 | import torch 17 | from stanfordnlp.server import CoreNLPClient 18 | 19 | tokenizer_client = CoreNLPClient(annotators=['tokenize', 'ssplit'], timeout=30000, memory='16G', properties={'tokenize.ptb3Escaping': False, 'tokenize.options': "splitHyphenated=true,invertible=true", 'ssplit.eolonly': True}, threads=16) 20 | def word_tokenize(text): 21 | if isinstance(text, str): 22 | ann = tokenizer_client.annotate(text.replace('%', '%25').replace('') for token in ann.sentence[0].token] 24 | return res 25 | else: 26 | ann = tokenizer_client.annotate('\n'.join([x for x 
in text if len(x.strip())]).replace('%', '%25').replace('') for token in sentence.token] for sentence in ann.sentence] 29 | 30 | res1 = [] 31 | resi = 0 32 | for i, x in enumerate(text): 33 | if len(x.strip()) == 0: 34 | res1.append([]) 35 | else: 36 | if not all(token in x for token in res[resi]): 37 | print(x) 38 | print(res[resi]) 39 | assert all(token in x for token in res[resi]) 40 | res1.append(res[resi]) 41 | resi += 1 42 | return res1 43 | 44 | import bisect 45 | import re 46 | 47 | def find_nearest(a, target, test_func=lambda x: True): 48 | idx = bisect.bisect_left(a, target) 49 | if (0 <= idx < len(a)) and a[idx] == target: 50 | return target, 0 51 | elif idx == 0: 52 | return a[0], abs(a[0] - target) 53 | elif idx == len(a): 54 | return a[-1], abs(a[-1] - target) 55 | else: 56 | d1 = abs(a[idx] - target) if test_func(a[idx]) else 1e200 57 | d2 = abs(a[idx-1] - target) if test_func(a[idx-1]) else 1e200 58 | if d1 > d2: 59 | return a[idx-1], d2 60 | else: 61 | return a[idx], d1 62 | 63 | def fix_span(para, offsets, span): 64 | span = span.strip() 65 | parastr = "".join(para) 66 | assert span in parastr, '{}\t{}'.format(span, parastr) 67 | begins, ends = map(list, zip(*[y for x in offsets for y in x])) 68 | 69 | best_dist = 1e200 70 | best_indices = None 71 | 72 | if span == parastr: 73 | return parastr, (0, len(parastr)), 0 74 | 75 | for m in re.finditer(re.escape(span), parastr): 76 | begin_offset, end_offset = m.span() 77 | 78 | fixed_begin, d1 = find_nearest(begins, begin_offset, lambda x: x < end_offset) 79 | fixed_end, d2 = find_nearest(ends, end_offset, lambda x: x > begin_offset) 80 | 81 | if d1 + d2 < best_dist: 82 | best_dist = d1 + d2 83 | best_indices = (fixed_begin, fixed_end) 84 | if best_dist == 0: 85 | break 86 | 87 | assert best_indices is not None 88 | return parastr[best_indices[0]:best_indices[1]], best_indices, best_dist 89 | 90 | #def word_tokenize(sent): 91 | # doc = nlp(sent) 92 | # return [token.text for token in doc] 93 | 94 | 95 | def convert_idx(text, tokens): 96 | current = 0 97 | spans = [] 98 | for token in tokens: 99 | pre = current 100 | current = text.find(token, current) 101 | if current < 0: 102 | print(f'Token |{token}| not found in |{text}|') 103 | raise Exception() 104 | spans.append((current, current + len(token))) 105 | current += len(token) 106 | return spans 107 | 108 | def prepro_sent(sent): 109 | return sent 110 | # return sent.replace("''", '" ').replace("``", '" ') 111 | 112 | def _process_article(article, config): 113 | paragraphs = article['context'] 114 | # some articles in the fullwiki dev/test sets have zero paragraphs 115 | if len(paragraphs) == 0: 116 | paragraphs = [['some random title', 'some random stuff']] 117 | 118 | text_context, context_tokens, context_chars = '', [], [] 119 | offsets = [] 120 | flat_offsets = [] 121 | start_end_facts = [] # (start_token_id, end_token_id, is_sup_fact=True/False) 122 | sent2title_ids = [] 123 | 124 | def _process(sent, sent_tokens, is_sup_fact, is_title=False): 125 | nonlocal text_context, context_tokens, context_chars, offsets, start_end_facts, flat_offsets 126 | N_chars = len(text_context) 127 | 128 | sent = sent 129 | #sent_tokens = word_tokenize(sent) 130 | if is_title: 131 | sent = ' {} '.format(sent) 132 | sent_tokens = [''] + sent_tokens + [''] 133 | # Change 1: If we see a new paragraph we add an empty list (to which later we will add the paragraphs) 134 | context_tokens.append([]) 135 | context_chars.append([]) 136 | start_end_facts.append([]) 137 | if len(context_tokens[-1]) >= 
MAX_PARAGRAPH_LEN: 138 | return 139 | # truncate sentence if paragraph is too long 140 | if len(context_tokens[-1]) + len(sent_tokens) >= MAX_PARAGRAPH_LEN: 141 | sent_tokens = sent_tokens[:MAX_PARAGRAPH_LEN - len(context_tokens[-1])] 142 | sent_chars = [list(token) for token in sent_tokens] 143 | sent_spans = convert_idx(sent, sent_tokens) 144 | if len(context_tokens[-1]) + len(sent_tokens) >= MAX_PARAGRAPH_LEN: 145 | sent = sent[:sent_spans[-1][1]] 146 | 147 | sent_spans = [[N_chars+e[0], N_chars+e[1]] for e in sent_spans] 148 | N_tokens, my_N_tokens = len(context_tokens[-1]), len(sent_tokens) 149 | 150 | text_context += sent 151 | # Change 2: Add items to the empty list 152 | ## First occurence of when a flattened list is made 153 | context_tokens[-1].extend(sent_tokens) 154 | context_chars[-1].extend(sent_chars) 155 | start_end_facts[-1].append((N_tokens, N_tokens+my_N_tokens, is_sup_fact)) 156 | ## the above context tokens is then used to populate the context_idxs 157 | # end change 158 | offsets.append(sent_spans) 159 | flat_offsets.extend(sent_spans) 160 | 161 | if 'supporting_facts' in article: 162 | sp_set = set(list(map(tuple, article['supporting_facts']))) 163 | else: 164 | sp_set = set() 165 | 166 | to_tokenize = [prepro_sent(article['question'])] 167 | for para in paragraphs: 168 | to_tokenize.extend([para[0]]) 169 | to_tokenize.extend(para[1]) 170 | tokens = word_tokenize(to_tokenize) 171 | ques_tokens = tokens[0] 172 | tokens_id = 1 173 | 174 | for para in paragraphs: 175 | cur_title, cur_para = para[0], para[1] 176 | sent2title_ids.append((cur_title, -1)) 177 | _process(prepro_sent(cur_title), tokens[tokens_id], False, is_title=True) 178 | tokens_id += 1 179 | for sent_id, sent in enumerate(cur_para): 180 | is_sup_fact = (cur_title, sent_id) in sp_set 181 | _process(prepro_sent(sent), tokens[tokens_id], is_sup_fact) 182 | tokens_id += 1 183 | sent2title_ids.append((cur_title, sent_id)) 184 | 185 | if 'answer' in article: 186 | answer = article['answer'].strip() 187 | if answer.lower() == 'yes': 188 | best_indices = [-1, -1] 189 | elif answer.lower() == 'no': 190 | best_indices = [-2, -2] 191 | else: 192 | if article['answer'].strip() not in ''.join(text_context): 193 | # in the fullwiki setting, the answer might not have been retrieved 194 | # use (0, 1) so that we can proceed 195 | best_indices = (0, 1) 196 | else: 197 | _, best_indices, _ = fix_span(text_context, offsets, article['answer']) 198 | answer_span = [] 199 | for idx, span in enumerate(flat_offsets): 200 | if not (best_indices[1] <= span[0] or best_indices[0] >= span[1]): 201 | answer_span.append(idx) 202 | best_indices = (answer_span[0], answer_span[-1]) 203 | else: 204 | # some random stuff 205 | answer = 'random' 206 | best_indices = (0, 1) 207 | 208 | #ques_tokens = word_tokenize(prepro_sent(article['question'])) 209 | ques_chars = [list(token) for token in ques_tokens] 210 | 211 | example = {'context_tokens': context_tokens,'context_chars': context_chars, 'ques_tokens': ques_tokens, 'ques_chars': ques_chars, 'y1s': [best_indices[0]], 'y2s': [best_indices[1]], 'id': article['_id'], 'start_end_facts': start_end_facts} 212 | eval_example = {'context': text_context, 'spans': flat_offsets, 'answer': [answer], 'id': article['_id'], 213 | 'sent2title_ids': sent2title_ids} 214 | return example, eval_example 215 | 216 | def process_file(filename, config, word_counter=None, char_counter=None): 217 | data = json.load(open(filename, 'r')) 218 | 219 | examples = [] 220 | eval_examples = {} 221 | 222 | #outputs = 
Parallel(verbose=10)(delayed(_process_article)(article, config) for article in data) 223 | outputs = [_process_article(article, config) for article in tqdm(data)] 224 | examples = [e[0] for e in outputs] 225 | for _, e in outputs: 226 | if e is not None: 227 | eval_examples[e['id']] = e 228 | 229 | # only count during training 230 | if word_counter is not None and char_counter is not None: 231 | print('Counting words and characters...') 232 | for example in tqdm(examples): 233 | # Change 3: Get all the words of all the paragraphs. 234 | word_counter.update(example['ques_tokens']) 235 | word_counter.update(word for para in example['context_tokens'] for word in para) 236 | for token in example['ques_tokens']: 237 | char_counter.update(token) 238 | for token in (word for para in example['context_tokens'] for word in para): 239 | char_counter.update(token) 240 | 241 | random.shuffle(examples) 242 | print("{} questions in total".format(len(examples))) 243 | 244 | return examples, eval_examples 245 | 246 | def get_embedding(counter, data_type, limit=-1, emb_file=None, size=None, vec_size=None, token2idx_dict=None): 247 | print("Generating {} embedding...".format(data_type)) 248 | embedding_dict = {} 249 | filtered_elements = [k for k, v in counter.items() if v > limit] 250 | if emb_file is not None: 251 | assert size is not None 252 | assert vec_size is not None 253 | with open(emb_file, "r", encoding="utf-8") as fh: 254 | for line in tqdm(fh, total=size): 255 | array = line.split() 256 | word = "".join(array[0:-vec_size]) 257 | vector = list(map(float, array[-vec_size:])) 258 | if word in counter and counter[word] > limit: 259 | embedding_dict[word] = vector 260 | print("{} / {} tokens have corresponding {} embedding vector".format( 261 | len(embedding_dict), len(filtered_elements), data_type)) 262 | else: 263 | assert vec_size is not None 264 | for token in filtered_elements: 265 | embedding_dict[token] = [np.random.normal( 266 | scale=0.01) for _ in range(vec_size)] 267 | print("{} tokens have corresponding embedding vector".format( 268 | len(filtered_elements))) 269 | 270 | NULL = "--NULL--" 271 | OOV = "--OOV--" 272 | token2idx_dict = {token: idx for idx, token in enumerate( 273 | embedding_dict.keys(), 2)} if token2idx_dict is None else token2idx_dict 274 | token2idx_dict[NULL] = 0 275 | token2idx_dict[OOV] = 1 276 | embedding_dict[NULL] = [0. for _ in range(vec_size)] 277 | embedding_dict[OOV] = [0. 
for _ in range(vec_size)] 278 | idx2emb_dict = {idx: embedding_dict[token] 279 | for token, idx in token2idx_dict.items()} 280 | emb_mat = [idx2emb_dict[idx] for idx in range(len(idx2emb_dict))] 281 | 282 | idx2token_dict = {idx: token for token, idx in token2idx_dict.items()} 283 | 284 | return emb_mat, token2idx_dict, idx2token_dict 285 | 286 | 287 | def build_features(config, examples, data_type, out_file, word2idx_dict, char2idx_dict): 288 | if data_type == 'test': 289 | para_limit, ques_limit = 0, 0 290 | for example in tqdm(examples): 291 | para_limit = max(para_limit, len(example['context_tokens'])) 292 | ques_limit = max(ques_limit, len(example['ques_tokens'])) 293 | else: 294 | para_limit = config.para_limit 295 | ques_limit = config.ques_limit 296 | 297 | char_limit = config.char_limit 298 | 299 | def filter_func(example): 300 | return len(example["context_tokens"]) > para_limit or len(example["ques_tokens"]) > ques_limit 301 | 302 | print("Processing {} examples...".format(data_type)) 303 | datapoints = [] 304 | total = 0 305 | total_ = 0 306 | for example in tqdm(examples): 307 | total_ += 1 308 | 309 | if filter_func(example): 310 | continue 311 | 312 | total += 1 313 | 314 | def _get_word(word): 315 | for each in (word, word.lower(), word.capitalize(), word.upper()): 316 | if each in word2idx_dict: 317 | return word2idx_dict[each] 318 | return 1 319 | 320 | def _get_char(char): 321 | if char in char2idx_dict: 322 | return char2idx_dict[char] 323 | return 1 324 | 325 | # Convert text to indices (leave tensorization to the data iterator) 326 | context_idxs = [[_get_word(w) for w in para] for para in example['context_tokens']] 327 | ques_idxs = [_get_word(w) for w in example['ques_tokens']] 328 | 329 | context_char_idxs = [[[_get_char(c) for c in token] for token in para] for para in example['context_chars']] 330 | ques_char_idxs = [[_get_char(c) for c in token] for token in example['ques_chars']] 331 | 332 | start, end = example["y1s"][-1], example["y2s"][-1] 333 | y1, y2 = start, end 334 | 335 | datapoints.append({'context_idxs': context_idxs, 336 | 'context_char_idxs': context_char_idxs, 337 | 'ques_idxs': ques_idxs, 338 | 'ques_char_idxs': ques_char_idxs, 339 | 'y1': y1, 340 | 'y2': y2, 341 | 'id': example['id'], 342 | 'start_end_facts': example['start_end_facts']}) 343 | print("Build {} / {} instances of features in total".format(total, total_)) 344 | # pickle.dump(datapoints, open(out_file, 'wb'), protocol=-1) 345 | torch.save(datapoints, out_file) 346 | 347 | def save(filename, obj, message=None): 348 | if message is not None: 349 | print("Saving {}...".format(message)) 350 | with open(filename, "w") as fh: 351 | json.dump(obj, fh) 352 | 353 | def prepro(config): 354 | random.seed(13) 355 | 356 | if config.data_split == 'train': 357 | word_counter, char_counter = Counter(), Counter() 358 | examples, eval_examples = process_file(config.data_file, config, word_counter, char_counter) 359 | else: 360 | examples, eval_examples = process_file(config.data_file, config) 361 | 362 | word2idx_dict = None 363 | if os.path.isfile(config.word2idx_file): 364 | with open(config.word2idx_file, "r") as fh: 365 | word2idx_dict = json.load(fh) 366 | else: 367 | word_emb_mat, word2idx_dict, idx2word_dict = get_embedding(word_counter, "word", emb_file=config.glove_word_file, 368 | size=config.glove_word_size, vec_size=config.glove_dim, token2idx_dict=word2idx_dict) 369 | 370 | char2idx_dict = None 371 | if os.path.isfile(config.char2idx_file): 372 | with open(config.char2idx_file, "r") as fh: 
373 | char2idx_dict = json.load(fh) 374 | else: 375 | char_emb_mat, char2idx_dict, idx2char_dict = get_embedding( 376 | char_counter, "char", emb_file=None, size=None, vec_size=config.char_dim, token2idx_dict=char2idx_dict) 377 | 378 | if config.data_split == 'train': 379 | record_file = config.train_record_file 380 | eval_file = config.train_eval_file 381 | elif config.data_split == 'dev': 382 | record_file = config.dev_record_file 383 | eval_file = config.dev_eval_file 384 | elif config.data_split == 'test': 385 | record_file = config.test_record_file 386 | eval_file = config.test_eval_file 387 | 388 | build_features(config, examples, config.data_split, record_file, word2idx_dict, char2idx_dict) 389 | save(eval_file, eval_examples, message='{} eval'.format(config.data_split)) 390 | 391 | if not os.path.isfile(config.word2idx_file): 392 | save(config.word_emb_file, word_emb_mat, message="word embedding") 393 | save(config.char_emb_file, char_emb_mat, message="char embedding") 394 | save(config.word2idx_file, word2idx_dict, message="word2idx") 395 | save(config.char2idx_file, char2idx_dict, message="char2idx") 396 | save(config.idx2word_file, idx2word_dict, message='idx2word') 397 | save(config.idx2char_file, idx2char_dict, message='idx2char') 398 | 399 | -------------------------------------------------------------------------------- /BiDAFpp/run.py: -------------------------------------------------------------------------------- 1 | import ujson as json 2 | import numpy as np 3 | from tqdm import tqdm 4 | import os 5 | from torch import optim, nn 6 | from model import Model #, NoCharModel, NoSelfModel 7 | from sp_model import SPModel 8 | # from normal_model import NormalModel, NoSelfModel, NoCharModel, NoSentModel 9 | # from oracle_model import OracleModel, OracleModelV2 10 | # from util import get_record_parser, convert_tokens, evaluate, get_batch_dataset, get_dataset 11 | from util import convert_tokens, evaluate 12 | from util import get_buckets, HotpotDataset, DataIterator, IGNORE_INDEX 13 | import time 14 | import shutil 15 | import random 16 | import torch 17 | from torch.autograd import Variable 18 | import sys 19 | from torch.nn import functional as F 20 | from torch.utils.data import RandomSampler 21 | 22 | def create_exp_dir(path, scripts_to_save=None): 23 | if not os.path.exists(path): 24 | os.mkdir(path) 25 | 26 | print('Experiment dir : {}'.format(path)) 27 | if scripts_to_save is not None: 28 | if not os.path.exists(os.path.join(path, 'scripts')): 29 | os.mkdir(os.path.join(path, 'scripts')) 30 | for script in scripts_to_save: 31 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 32 | shutil.copyfile(script, dst_file) 33 | 34 | nll_sum = nn.CrossEntropyLoss(reduction='sum', ignore_index=IGNORE_INDEX) 35 | nll_average = nn.CrossEntropyLoss(reduction='mean', ignore_index=IGNORE_INDEX) 36 | nll_all = nn.CrossEntropyLoss(reduction='none', ignore_index=IGNORE_INDEX) 37 | 38 | def train(config): 39 | with open(config.word_emb_file, "r") as fh: 40 | word_mat = np.array(json.load(fh), dtype=np.float32) 41 | with open(config.char_emb_file, "r") as fh: 42 | char_mat = np.array(json.load(fh), dtype=np.float32) 43 | with open(config.dev_eval_file, "r") as fh: 44 | dev_eval_file = json.load(fh) 45 | with open(config.idx2word_file, 'r') as fh: 46 | idx2word_dict = json.load(fh) 47 | 48 | random.seed(config.seed) 49 | np.random.seed(config.seed) 50 | torch.manual_seed(config.seed) 51 | if config.cuda: 52 | torch.cuda.manual_seed_all(config.seed) 53 | 54 | config.save = 
'{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S")) 55 | create_exp_dir(config.save, scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py']) 56 | def logging(s, print_=True, log_=True): 57 | if print_: 58 | print(s) 59 | if log_: 60 | with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log: 61 | f_log.write(s + '\n') 62 | 63 | logging('Config') 64 | for k, v in config.__dict__.items(): 65 | logging(' - {} : {}'.format(k, v)) 66 | 67 | logging("Building model...") 68 | train_buckets = get_buckets(config.train_record_file) 69 | dev_buckets = get_buckets(config.dev_record_file) 70 | 71 | def build_train_iterator(): 72 | train_dataset = HotpotDataset(train_buckets) 73 | return DataIterator(train_dataset, config.para_limit, config.ques_limit, config.char_limit, config.sent_limit, batch_size=config.batch_size, sampler=RandomSampler(train_dataset), num_workers=2) 74 | 75 | def build_dev_iterator(): 76 | dev_dataset = HotpotDataset(dev_buckets) 77 | return DataIterator(dev_dataset, config.para_limit, config.ques_limit, config.char_limit, config.sent_limit, batch_size=config.batch_size, num_workers=2) 78 | 79 | if config.sp_lambda > 0: 80 | model = SPModel(config, word_mat, char_mat) 81 | else: 82 | model = Model(config, word_mat, char_mat) 83 | 84 | logging('nparams {}'.format(sum([p.nelement() for p in model.parameters() if p.requires_grad]))) 85 | ori_model = model.cuda() if config.cuda else model 86 | model = nn.DataParallel(ori_model) 87 | 88 | lr = config.init_lr 89 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.init_lr) 90 | cur_patience = 0 91 | total_loss = 0 92 | global_step = 0 93 | best_dev_F1 = None 94 | stop_train = False 95 | start_time = time.time() 96 | eval_start_time = time.time() 97 | model.train() 98 | 99 | train_iterator = build_train_iterator() 100 | dev_iterator = build_dev_iterator() 101 | 102 | for epoch in range(10000): 103 | for data in train_iterator: 104 | if config.cuda: 105 | data = {k:(data[k].cuda() if k != 'ids' else data[k]) for k in data} 106 | context_idxs = data['context_idxs'] 107 | ques_idxs = data['ques_idxs'] 108 | context_char_idxs = data['context_char_idxs'] 109 | ques_char_idxs = data['ques_char_idxs'] 110 | context_lens = data['context_lens'] 111 | y1 = data['y1'] 112 | y2 = data['y2'] 113 | q_type = data['q_type'] 114 | is_support = data['is_support'] 115 | start_mapping = data['start_mapping'] 116 | end_mapping = data['end_mapping'] 117 | all_mapping = data['all_mapping'] 118 | 119 | logit1, logit2, predict_type, predict_support = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, context_lens.sum(1).max().item(), return_yp=False) 120 | loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) + nll_sum(logit2, y2)) / context_idxs.size(0) 121 | loss_2 = nll_average(predict_support.view(-1, 2), is_support.view(-1)) 122 | loss = loss_1 + config.sp_lambda * loss_2 123 | 124 | optimizer.zero_grad() 125 | loss.backward() 126 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm if config.max_grad_norm > 0 else 1e10) 127 | optimizer.step() 128 | 129 | total_loss += loss.item() 130 | global_step += 1 131 | 132 | if global_step % config.period == 0: 133 | cur_loss = total_loss / config.period 134 | elapsed = time.time() - start_time 135 | logging('| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | gradnorm: {:6.3}'.format(epoch, global_step, lr, 
elapsed*1000/config.period, cur_loss, grad_norm)) 136 | total_loss = 0 137 | start_time = time.time() 138 | 139 | if global_step % config.checkpoint == 0: 140 | model.eval() 141 | metrics = evaluate_batch(dev_iterator, model, 0, dev_eval_file, config) 142 | model.train() 143 | 144 | logging('-' * 89) 145 | logging('| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'.format(global_step//config.checkpoint, 146 | epoch, time.time()-eval_start_time, metrics['loss'], metrics['exact_match'], metrics['f1'])) 147 | logging('-' * 89) 148 | 149 | eval_start_time = time.time() 150 | 151 | dev_F1 = metrics['f1'] 152 | if best_dev_F1 is None or dev_F1 > best_dev_F1: 153 | best_dev_F1 = dev_F1 154 | torch.save(ori_model.state_dict(), os.path.join(config.save, 'model.pt')) 155 | cur_patience = 0 156 | else: 157 | cur_patience += 1 158 | if cur_patience >= config.patience: 159 | lr /= 2.0 160 | for param_group in optimizer.param_groups: 161 | param_group['lr'] = lr 162 | if lr < config.init_lr * 1e-2: 163 | stop_train = True 164 | break 165 | cur_patience = 0 166 | if stop_train: break 167 | logging('best_dev_F1 {}'.format(best_dev_F1)) 168 | 169 | def evaluate_batch(data_source, model, max_batches, eval_file, config): 170 | answer_dict = {} 171 | sp_dict = {} 172 | total_loss, step_cnt = 0, 0 173 | iter = data_source 174 | for step, data in enumerate(iter): 175 | if step >= max_batches and max_batches > 0: break 176 | 177 | with torch.no_grad(): 178 | if config.cuda: 179 | data = {k:(data[k].cuda() if k != 'ids' else data[k]) for k in data} 180 | context_idxs = data['context_idxs'] 181 | ques_idxs = data['ques_idxs'] 182 | context_char_idxs = data['context_char_idxs'] 183 | ques_char_idxs = data['ques_char_idxs'] 184 | context_lens = data['context_lens'] 185 | y1 = data['y1'] 186 | y2 = data['y2'] 187 | q_type = data['q_type'] 188 | is_support = data['is_support'] 189 | start_mapping = data['start_mapping'] 190 | end_mapping = data['end_mapping'] 191 | all_mapping = data['all_mapping'] 192 | 193 | logit1, logit2, predict_type, predict_support, yp1, yp2 = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, context_lens.sum(1).max().item(), return_yp=True) 194 | loss = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) + nll_sum(logit2, y2)) / context_idxs.size(0) + config.sp_lambda * nll_average(predict_support.view(-1, 2), is_support.view(-1)) 195 | answer_dict_ = convert_tokens(eval_file, data['ids'], yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), np.argmax(predict_type.data.cpu().numpy(), 1)) 196 | answer_dict.update(answer_dict_) 197 | 198 | total_loss += loss.item() 199 | step_cnt += 1 200 | loss = total_loss / step_cnt 201 | metrics = evaluate(eval_file, answer_dict) 202 | metrics['loss'] = loss 203 | 204 | return metrics 205 | 206 | def predict(data_source, model, eval_file, config, prediction_file): 207 | answer_dict = {} 208 | sp_dict = {} 209 | sp_th = config.sp_threshold 210 | for step, data in enumerate(tqdm(data_source)): 211 | with torch.no_grad(): 212 | if config.cuda: 213 | data = {k:(data[k].cuda() if k != 'ids' else data[k]) for k in data} 214 | context_idxs = data['context_idxs'] 215 | ques_idxs = data['ques_idxs'] 216 | context_char_idxs = data['context_char_idxs'] 217 | ques_char_idxs = data['ques_char_idxs'] 218 | context_lens = data['context_lens'] 219 | start_mapping = data['start_mapping'] 220 | end_mapping = data['end_mapping'] 221 | all_mapping = 
data['all_mapping'] 222 | 223 | logit1, logit2, predict_type, predict_support, yp1, yp2 = model(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, context_lens.sum(1).max().item(), return_yp=True) 224 | answer_dict_ = convert_tokens(eval_file, data['ids'], yp1.data.cpu().numpy().tolist(), yp2.data.cpu().numpy().tolist(), np.argmax(predict_type.data.cpu().numpy(), 1)) 225 | answer_dict.update(answer_dict_) 226 | 227 | predict_support_np = torch.sigmoid(predict_support[:, :, 1] - predict_support[:, :, 0]).data.cpu().numpy() 228 | for i in range(predict_support_np.shape[0]): 229 | cur_sp_pred = [] 230 | cur_id = data['ids'][i] 231 | for j in range(predict_support_np.shape[1]): 232 | if j >= len(eval_file[cur_id]['sent2title_ids']): break 233 | if predict_support_np[i, j] > sp_th: 234 | cur_sp_pred.append(eval_file[cur_id]['sent2title_ids'][j]) 235 | sp_dict.update({cur_id: cur_sp_pred}) 236 | 237 | prediction = {'answer': answer_dict, 'sp': sp_dict} 238 | with open(prediction_file, 'w') as f: 239 | json.dump(prediction, f) 240 | 241 | def test(config): 242 | with open(config.word_emb_file, "r") as fh: 243 | word_mat = np.array(json.load(fh), dtype=np.float32) 244 | with open(config.char_emb_file, "r") as fh: 245 | char_mat = np.array(json.load(fh), dtype=np.float32) 246 | if config.data_split == 'dev': 247 | with open(config.dev_eval_file, "r") as fh: 248 | dev_eval_file = json.load(fh) 249 | else: 250 | with open(config.test_eval_file, 'r') as fh: 251 | dev_eval_file = json.load(fh) 252 | with open(config.idx2word_file, 'r') as fh: 253 | idx2word_dict = json.load(fh) 254 | 255 | random.seed(config.seed) 256 | np.random.seed(config.seed) 257 | torch.manual_seed(config.seed) 258 | if config.cuda: 259 | torch.cuda.manual_seed_all(config.seed) 260 | 261 | def logging(s, print_=True, log_=True): 262 | if print_: 263 | print(s) 264 | if log_: 265 | with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log: 266 | f_log.write(s + '\n') 267 | 268 | if config.data_split == 'dev': 269 | dev_buckets = get_buckets(config.dev_record_file) 270 | para_limit = config.para_limit 271 | ques_limit = config.ques_limit 272 | elif config.data_split == 'test': 273 | para_limit = None 274 | ques_limit = None 275 | dev_buckets = get_buckets(config.test_record_file) 276 | 277 | def build_dev_iterator(): 278 | dev_dataset = HotpotDataset(dev_buckets) 279 | return DataIterator(dev_dataset, config.para_limit, config.ques_limit, config.char_limit, config.sent_limit, batch_size=config.batch_size, num_workers=2) 280 | 281 | if config.sp_lambda > 0: 282 | model = SPModel(config, word_mat, char_mat) 283 | else: 284 | model = Model(config, word_mat, char_mat) 285 | ori_model = model.cuda() if config.cuda else model 286 | ori_model.load_state_dict(torch.load(os.path.join(config.save, 'model.pt'), map_location=lambda storage, loc: storage)) 287 | model = nn.DataParallel(ori_model) 288 | 289 | model.eval() 290 | predict(build_dev_iterator(), model, dev_eval_file, config, config.prediction_file) 291 | 292 | -------------------------------------------------------------------------------- /BiDAFpp/sp_model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.autograd import Variable 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import numpy as np 7 | import math 8 | from torch.nn import init 9 | from torch.nn.utils import rnn 10 | 11 | class SPModel(nn.Module): 12 | def 
__init__(self, config, word_mat, char_mat): 13 | super().__init__() 14 | self.config = config 15 | self.word_dim = config.glove_dim 16 | self.word_emb = nn.Embedding(len(word_mat), len(word_mat[0]), padding_idx=0) 17 | self.word_emb.weight.data.copy_(torch.from_numpy(word_mat)) 18 | self.word_emb.weight.requires_grad = False 19 | self.char_emb = nn.Embedding(len(char_mat), len(char_mat[0]), padding_idx=0) 20 | self.char_emb.weight.data.copy_(torch.from_numpy(char_mat)) 21 | 22 | self.char_cnn = nn.Conv1d(config.char_dim, config.char_hidden, 5) 23 | self.char_hidden = config.char_hidden 24 | self.hidden = config.hidden 25 | 26 | self.dropout = LockedDropout(1-config.keep_prob) 27 | 28 | self.rnn = EncoderRNN(self.word_dim + self.char_hidden + 1, config.hidden, 1, True, True, 1-config.keep_prob, False) 29 | 30 | self.qc_att = BiAttention(config.hidden*2, 1-config.keep_prob) 31 | self.linear_1 = nn.Sequential( 32 | nn.Linear(config.hidden*6, config.hidden*2), 33 | nn.Tanh() 34 | ) 35 | 36 | self.rnn_2 = EncoderRNN(config.hidden * 2, config.hidden, 1, False, True, 1-config.keep_prob, False) 37 | self.self_att = BiAttention(config.hidden*2, 1-config.keep_prob) 38 | self.linear_2 = nn.Sequential( 39 | nn.Linear(config.hidden*6, config.hidden*2), 40 | nn.Tanh() 41 | ) 42 | 43 | self.rnn_sp = EncoderRNN(config.hidden*2, config.hidden, 1, False, True, 1-config.keep_prob, False) 44 | self.linear_sp = nn.Linear(config.hidden*2, 1) 45 | 46 | self.rnn_start = EncoderRNN(config.hidden*4, config.hidden, 1, False, True, 1-config.keep_prob, False) 47 | self.linear_start = nn.Linear(config.hidden*2, 1) 48 | 49 | self.rnn_end = EncoderRNN(config.hidden*4, config.hidden, 1, False, True, 1-config.keep_prob, False) 50 | self.linear_end = nn.Linear(config.hidden*2, 1) 51 | 52 | self.rnn_type = EncoderRNN(config.hidden*4, config.hidden, 1, False, True, 1-config.keep_prob, False) 53 | self.linear_type = nn.Linear(config.hidden*2, 3) 54 | 55 | self.cache_S = 0 56 | 57 | def get_output_mask(self, outer): 58 | S = outer.size(1) 59 | if S <= self.cache_S: 60 | return Variable(self.cache_mask[:S, :S], requires_grad=False) 61 | self.cache_S = S 62 | np_mask = np.tril(np.triu(np.ones((S, S)), 0), 15) 63 | self.cache_mask = outer.data.new(S, S).copy_(torch.from_numpy(np_mask)) 64 | return Variable(self.cache_mask, requires_grad=False) 65 | 66 | def rnn_over_context(self, rnn, x, lens): 67 | batch_size, num_of_paragraphs, para_len, hidden_dim = x.size() 68 | x = self.dropout(x.view(batch_size, num_of_paragraphs * para_len, hidden_dim)) 69 | x = x.view(batch_size * num_of_paragraphs, para_len, hidden_dim) 70 | lens = lens.view(-1) 71 | l1 = torch.max(lens, lens.new_ones(1)) 72 | y = rnn(x, l1) 73 | return y.masked_fill((lens == 0).unsqueeze(1).unsqueeze(2), 0).view(batch_size, num_of_paragraphs, para_len, -1) 74 | 75 | def forward(self, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, max_len, return_yp=False): 76 | # Note:- Dimensions of context_idxs is [10, 10, 40] 77 | # cur_batch size is 10 and each of the batch items is a vector of size [10, 40] 78 | para_size, ques_size, char_size, bsz = context_idxs.size(1), ques_idxs.size(1), context_char_idxs.size(-1), context_idxs.size(0) 79 | 80 | batch_size, num_of_paragraphs, para_len = context_idxs.size() 81 | context_idxs = context_idxs.reshape(-1, para_len) 82 | context_mask = (context_idxs > 0).float() 83 | ques_mask = (ques_idxs > 0).float() 84 | 85 | context_ch = self.char_emb(context_char_idxs) 86 | 
ques_ch = self.char_emb(ques_char_idxs) 87 | # 88 | context_ch = self.char_cnn(context_ch.view(batch_size * num_of_paragraphs * para_len, char_size, -1).permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(batch_size * num_of_paragraphs, para_len, -1) 89 | ques_ch = self.char_cnn(ques_ch.view(batch_size * ques_size, char_size, -1).permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(bsz, ques_size, -1) 90 | 91 | context_word = self.word_emb(context_idxs) 92 | ques_word = self.word_emb(ques_idxs) 93 | 94 | context_output = torch.cat([context_word, context_ch, context_word.new_zeros((context_word.size(0), context_word.size(1), 1))], dim=2).view(batch_size, num_of_paragraphs, para_len, -1) 95 | ques_output = torch.cat([ques_word, ques_ch, ques_word.new_ones((ques_word.size(0), ques_word.size(1), 1))], dim=2) 96 | 97 | context_output = self.rnn_over_context(self.rnn, context_output, context_lens) 98 | ques_output = self.rnn(self.dropout(ques_output)) 99 | 100 | qc_hid = torch.cat([context_output.view(batch_size, num_of_paragraphs * para_len, -1), ques_output], 1) 101 | qc_mask = torch.cat([context_mask.view(batch_size, num_of_paragraphs * para_len), ques_mask], 1) 102 | 103 | #output = self.qc_att(context_output.view(batch_size, num_of_paragraphs * para_len, -1), ques_output, 104 | # context_mask.view(batch_size, num_of_paragraphs * para_len), ques_mask) 105 | output = self.qc_att(qc_hid, qc_hid, qc_mask, qc_mask) 106 | output = self.linear_1(self.dropout(output)) 107 | 108 | c_output = output[:, :num_of_paragraphs * para_len].contiguous() 109 | q_output = output[:, num_of_paragraphs * para_len:].contiguous() 110 | output_t = self.rnn_over_context(self.rnn_2, c_output.view(batch_size, num_of_paragraphs, para_len, -1), context_lens) 111 | ques_output2 = self.rnn_2(self.dropout(q_output)) 112 | 113 | qc_hid2 = torch.cat([output_t.view(batch_size, num_of_paragraphs * para_len, -1), ques_output2], 1) 114 | #output_t = self.self_att(output_t, output_t, context_mask.view(batch_size, num_of_paragraphs * para_len), 115 | # context_mask.view(batch_size, num_of_paragraphs * para_len)) 116 | output_t = self.self_att(qc_hid2, qc_hid2, qc_mask, qc_mask) 117 | output_t = self.linear_2(self.dropout(output_t)) 118 | 119 | output = output + output_t 120 | output = output[:, :num_of_paragraphs * para_len].contiguous() # discard question output 121 | output = output.view(batch_size, num_of_paragraphs, para_len, -1) 122 | 123 | sp_output = self.rnn_over_context(self.rnn_sp, output, context_lens) 124 | sp_output = sp_output.view(batch_size, num_of_paragraphs * para_len, -1) 125 | 126 | #start_output = torch.matmul(start_mapping, sp_output[:,:,self.hidden:]) 127 | #end_output = torch.matmul(end_mapping, sp_output[:,:,:self.hidden]) 128 | #sp_output = torch.cat([start_output, end_output], dim=-1) 129 | sp_output = torch.matmul(all_mapping, sp_output) / (all_mapping.float().sum(-1, keepdim=True) + 1e-6) 130 | sp_output_t = self.linear_sp(self.dropout(sp_output)) 131 | sp_output_aux = sp_output_t.new_zeros(sp_output_t.size(0), sp_output_t.size(1), 1) 132 | #sp_output_aux = (sp_output_t.max(1, keepdim=True)[0] - 6).expand(*sp_output_t.size()) 133 | predict_support = torch.cat([sp_output_aux, sp_output_t], dim=-1).contiguous() 134 | 135 | sp_output = torch.matmul(all_mapping.transpose(1, 2), sp_output) 136 | 137 | output_start = torch.cat([output, sp_output.view(batch_size, num_of_paragraphs, para_len, -1)], dim=-1) 138 | output_start = self.rnn_over_context(self.rnn_start, output_start, context_lens) 139 | output_end = 
torch.cat([output, output_start], dim=-1) 140 | output_end = self.rnn_over_context(self.rnn_end, output_end, context_lens) 141 | output_type = torch.cat([output, output_end], dim=-1) 142 | output_type = self.rnn_over_context(self.rnn_type, output_type, context_lens) 143 | 144 | predict_start = self.linear_start(self.dropout(output_start.view(batch_size, num_of_paragraphs * para_len, -1))).view(batch_size, num_of_paragraphs, para_len) 145 | predict_end = self.linear_end(self.dropout(output_end.view(batch_size, num_of_paragraphs * para_len, -1))).view(batch_size, num_of_paragraphs, para_len) 146 | output_type = output_type.view(batch_size, num_of_paragraphs, para_len, output_type.size(-1)) 147 | 148 | # disect padded sequences of each paragraph and make padded sequence for each example 149 | # as predictions so we don't have to mess with the data format 150 | cumlens = context_lens.sum(1) 151 | 152 | logit1 = [] 153 | logit2 = [] 154 | p0_type = [] 155 | for i in range(context_lens.size(0)): 156 | logit1.append(torch.cat([predict_start[i, j, :context_lens[i][j]] for j in range(context_lens.size(1))] + [predict_start.new_full((max_len-cumlens[i], ), -1e30)], dim=0)) 157 | logit2.append(torch.cat([predict_end[i, j, :context_lens[i][j]] for j in range(context_lens.size(1))] + [predict_end.new_full((max_len-cumlens[i], ), -1e30)], dim=0)) 158 | p0_type.append(torch.cat([output_type[i, j, :context_lens[i][j]] for j in range(context_lens.size(1))] + [predict_end.new_full((max_len-cumlens[i], output_type.size(-1)), -1e30)], dim=0)) 159 | 160 | logit1 = torch.stack(logit1) 161 | logit2 = torch.stack(logit2) 162 | p0_type = torch.stack(p0_type) 163 | 164 | predict_type = self.linear_type(self.dropout(p0_type).max(1)[0]) 165 | 166 | if not return_yp: return logit1, logit2, predict_type, predict_support 167 | 168 | outer = logit1[:,:,None] + logit2[:,None] 169 | outer_mask = self.get_output_mask(outer) 170 | outer = outer - 1e30 * (1 - outer_mask[None].expand_as(outer)) 171 | yp = outer.view(outer.size(0), -1).max(1)[1] 172 | yp1 = yp // outer.size(1) 173 | yp2 = yp % outer.size(1) 174 | #yp1 = outer.max(dim=2)[0].max(dim=1)[1] 175 | #yp2 = outer.max(dim=1)[0].max(dim=1)[1] 176 | return logit1, logit2, predict_type, predict_support, yp1, yp2 177 | 178 | class LockedDropout(nn.Module): 179 | def __init__(self, dropout): 180 | super().__init__() 181 | self.dropout = dropout 182 | 183 | def forward(self, x): 184 | dropout = self.dropout 185 | if not self.training: 186 | return x 187 | m = x.data.new(x.size(0), 1, x.size(2)).bernoulli_(1 - dropout) 188 | mask = Variable(m.div_(1 - dropout), requires_grad=False) 189 | mask = mask.expand_as(x) 190 | return mask * x 191 | 192 | class EncoderRNN(nn.Module): 193 | def __init__(self, input_size, num_units, nlayers, concat, bidir, dropout, return_last): 194 | super().__init__() 195 | self.rnns = nn.ModuleList() 196 | for i in range(nlayers): 197 | if i == 0: 198 | input_size_ = input_size 199 | output_size_ = num_units 200 | else: 201 | input_size_ = num_units if not bidir else num_units * 2 202 | output_size_ = num_units 203 | self.rnns.append(nn.GRU(input_size_, output_size_, 1, bidirectional=bidir, batch_first=True)) 204 | self.init_hidden = nn.ParameterList([nn.Parameter(torch.Tensor(2 if bidir else 1, 1, num_units).zero_()) for _ in range(nlayers)]) 205 | self.dropout = LockedDropout(dropout) 206 | self.concat = concat 207 | self.nlayers = nlayers 208 | self.return_last = return_last 209 | 210 | # self.reset_parameters() 211 | 212 | def 
reset_parameters(self): 213 | for rnn in self.rnns: 214 | for name, p in rnn.named_parameters(): 215 | if 'weight' in name: 216 | p.data.normal_(std=0.1) 217 | else: 218 | p.data.zero_() 219 | 220 | def get_init(self, bsz, i): 221 | return self.init_hidden[i].expand(-1, bsz, -1).contiguous() 222 | 223 | def forward(self, input, input_lengths=None): 224 | bsz, slen = input.size(0), input.size(1) 225 | output = input 226 | outputs = [] 227 | if input_lengths is not None: 228 | lens = input_lengths#.data.cpu().numpy() 229 | for i in range(self.nlayers): 230 | hidden = self.get_init(bsz, i) 231 | if i > 0: 232 | output = self.dropout(output) 233 | if input_lengths is not None: 234 | output = rnn.pack_padded_sequence(output, lens, batch_first=True, enforce_sorted=False) 235 | output, hidden = self.rnns[i](output, hidden) 236 | if input_lengths is not None: 237 | output, _ = rnn.pad_packed_sequence(output, batch_first=True, total_length=input.size(1)) 238 | if output.size(1) < slen: # used for parallel 239 | padding = Variable(output.data.new(1, 1, 1).zero_()) 240 | output = torch.cat([output, padding.expand(output.size(0), slen-output.size(1), output.size(2))], dim=1) 241 | if self.return_last: 242 | outputs.append(hidden.permute(1, 0, 2).contiguous().view(bsz, -1)) 243 | else: 244 | outputs.append(output) 245 | if self.concat: 246 | return torch.cat(outputs, dim=2) 247 | return outputs[-1] 248 | 249 | class BiAttention(nn.Module): 250 | def __init__(self, input_size, dropout): 251 | super().__init__() 252 | self.dropout = LockedDropout(dropout) 253 | self.input_linear = nn.Linear(input_size, 1, bias=False) 254 | self.memory_linear = nn.Linear(input_size, 1, bias=False) 255 | 256 | self.dot_scale = nn.Parameter(torch.Tensor(input_size).uniform_(1.0 / (input_size ** 0.5))) 257 | 258 | def forward(self, input, memory, input_mask, memory_mask): 259 | bsz, input_len, memory_len = input.size(0), input.size(1), memory.size(1) 260 | 261 | input = self.dropout(input) 262 | memory = self.dropout(memory) 263 | 264 | input_dot = self.input_linear(input) 265 | memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len) 266 | cross_dot = torch.bmm(input * self.dot_scale, memory.transpose(1, 2)) 267 | att = input_dot + memory_dot + cross_dot 268 | att = att - 1e30 * (1 - memory_mask[:,None]) - 1e30 * (1 - input_mask[:, :, None]) 269 | 270 | weight_one = F.softmax(att, dim=-1)#.masked_fill(1 - memory_mask[:, None].byte(), 0).masked_fill(1 - input_mask[:, :, None].byte(), 0) 271 | output_one = torch.bmm(weight_one, memory) 272 | #weight_two = F.softmax(att.max(dim=-1)[0], dim=-1).view(bsz, 1, input_len) 273 | #output_two = torch.bmm(weight_two, input) 274 | 275 | #return torch.cat([input, output_one, input*output_one, output_two*output_one], dim=-1) 276 | return torch.cat([input, output_one, input*output_one], dim=-1) 277 | #return input + output_one 278 | 279 | class GateLayer(nn.Module): 280 | def __init__(self, d_input, d_output): 281 | super(GateLayer, self).__init__() 282 | self.linear = nn.Linear(d_input, d_output) 283 | self.gate = nn.Linear(d_input, d_output) 284 | self.sigmoid = nn.Sigmoid() 285 | 286 | def forward(self, input): 287 | return self.linear(input) * self.sigmoid(self.gate(input)) 288 | -------------------------------------------------------------------------------- /BiDAFpp/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import re 4 | from collections import Counter 5 | import string 6 | import 
pickle 7 | import random 8 | from torch.autograd import Variable 9 | import copy 10 | import ujson as json 11 | import traceback 12 | import bisect 13 | 14 | from torch.utils.data import Dataset, DataLoader 15 | 16 | IGNORE_INDEX = -100 17 | 18 | NUM_OF_PARAGRAPHS = 10 19 | MAX_PARAGRAPH_LEN = 400 20 | 21 | RE_D = re.compile('\d') 22 | def has_digit(string): 23 | return RE_D.search(string) 24 | 25 | def prepro(token): 26 | return token if not has_digit(token) else 'N' 27 | 28 | def pad_data(data, sizes, dtype=np.int64, out=None): 29 | res = np.zeros(sizes, dtype=dtype) if out is None else out 30 | if len(sizes) == 1: 31 | res[:min(len(data), sizes[0])] = data[:sizes[0]] 32 | elif len(sizes) == 2: 33 | for i, x in enumerate(data): 34 | if i >= sizes[0]: break 35 | res[i, :min(len(x), sizes[1])] = data[i][:sizes[1]] 36 | elif len(sizes) == 3: 37 | for i, x in enumerate(data): 38 | if i >= sizes[0]: break 39 | for j, y in enumerate(x): 40 | if j >= sizes[1]: break 41 | res[i, j, :min(len(y), sizes[2])] = data[i][j][:sizes[2]] 42 | 43 | return res#torch.from_numpy(res) 44 | 45 | class HotpotDataset(Dataset): 46 | def __init__(self, buckets): 47 | self.buckets = buckets 48 | self.cumlens = [] 49 | for i, b in enumerate(self.buckets): 50 | last = 0 if i == 0 else self.cumlens[-1] 51 | self.cumlens.append(last + len(b)) 52 | 53 | def __len__(self): 54 | return self.cumlens[-1] 55 | 56 | def __getitem__(self, i): 57 | bucket_id = bisect.bisect_right(self.cumlens, i) 58 | offset = 0 if bucket_id == 0 else self.cumlens[bucket_id-1] 59 | return self.buckets[bucket_id][i - offset] 60 | 61 | class DataIterator(DataLoader): 62 | def __init__(self, dataset, para_limit, ques_limit, char_limit, sent_limit, **kwargs): 63 | if kwargs.get('collate_fn', None) is None: 64 | kwargs['collate_fn'] = self._collate_fn 65 | if para_limit is not None and ques_limit is not None: 66 | self.para_limit = para_limit 67 | self.ques_limit = ques_limit 68 | else: 69 | para_limit, ques_limit = 0, 0 70 | for bucket in buckets: 71 | for dp in bucket: 72 | para_limit = max(para_limit, dp['context_idxs'].size(0)) 73 | ques_limit = max(ques_limit, dp['ques_idxs'].size(0)) 74 | self.para_limit, self.ques_limit = para_limit, ques_limit 75 | 76 | self.char_limit = char_limit 77 | self.sent_limit = sent_limit 78 | 79 | super().__init__(dataset, **kwargs) 80 | 81 | def _collate_fn(self, batch_data): 82 | # Change: changing the dimensions of context_idxs 83 | batch_size = len(batch_data) 84 | max_sent_cnt = max(len([y for x in batch_data[i]['start_end_facts'] for y in x]) for i in range(len(batch_data))) 85 | 86 | context_idxs = np.zeros((batch_size, NUM_OF_PARAGRAPHS, MAX_PARAGRAPH_LEN), dtype=np.int64) 87 | ques_idxs = np.zeros((batch_size, self.ques_limit), dtype=np.int64) 88 | context_char_idxs = np.zeros((batch_size, NUM_OF_PARAGRAPHS, MAX_PARAGRAPH_LEN, self.char_limit), dtype=np.int64) 89 | ques_char_idxs = np.zeros((batch_size, self.ques_limit, self.char_limit), dtype=np.int64) 90 | y1 = np.zeros(batch_size, dtype=np.int64) 91 | y2 = np.zeros(batch_size, dtype=np.int64) 92 | q_type = np.zeros(batch_size, dtype=np.int64) 93 | start_mapping = np.zeros((batch_size, max_sent_cnt, NUM_OF_PARAGRAPHS * MAX_PARAGRAPH_LEN), dtype=np.float32) 94 | end_mapping = np.zeros((batch_size, max_sent_cnt, NUM_OF_PARAGRAPHS * MAX_PARAGRAPH_LEN), dtype=np.float32) 95 | all_mapping = np.zeros((batch_size, max_sent_cnt, NUM_OF_PARAGRAPHS * MAX_PARAGRAPH_LEN), dtype=np.float32) 96 | is_support = np.full((batch_size, max_sent_cnt), IGNORE_INDEX, 
dtype=np.int64) 97 | 98 | ids = [x['id'] for x in batch_data] 99 | 100 | max_sent_cnt = 0 101 | 102 | for i in range(len(batch_data)): 103 | pad_data(batch_data[i]['context_idxs'], (NUM_OF_PARAGRAPHS, MAX_PARAGRAPH_LEN), out=context_idxs[i]) 104 | pad_data(batch_data[i]['ques_idxs'], (self.ques_limit,), out=ques_idxs[i]) 105 | pad_data(batch_data[i]['context_char_idxs'], (NUM_OF_PARAGRAPHS, MAX_PARAGRAPH_LEN, self.char_limit), out=context_char_idxs[i]) 106 | pad_data(batch_data[i]['ques_char_idxs'], (self.ques_limit, self.char_limit), out=ques_char_idxs[i]) 107 | if batch_data[i]['y1'] >= 0: 108 | y1[i] = batch_data[i]['y1'] 109 | y2[i] = batch_data[i]['y2'] 110 | q_type[i] = 0 111 | elif batch_data[i]['y1'] == -1: 112 | y1[i] = IGNORE_INDEX 113 | y2[i] = IGNORE_INDEX 114 | q_type[i] = 1 115 | elif batch_data[i]['y1'] == -2: 116 | y1[i] = IGNORE_INDEX 117 | y2[i] = IGNORE_INDEX 118 | q_type[i] = 2 119 | elif batch_data[i]['y1'] == -3: 120 | y1[i] = IGNORE_INDEX 121 | y2[i] = IGNORE_INDEX 122 | q_type[i] = 3 123 | else: 124 | assert False 125 | 126 | for j, (para_id, cur_sp_dp) in enumerate((para_id, s) for para_id, para in enumerate(batch_data[i]['start_end_facts']) for s in para): 127 | if j >= self.sent_limit: break 128 | if len(cur_sp_dp) == 3: 129 | start, end, is_sp_flag = tuple(cur_sp_dp) 130 | else: 131 | start, end, is_sp_flag, is_gold = tuple(cur_sp_dp) 132 | start += para_id * MAX_PARAGRAPH_LEN 133 | end += para_id * MAX_PARAGRAPH_LEN 134 | if start < end: 135 | start_mapping[i, j, start] = 1 136 | end_mapping[i, j, end-1] = 1 137 | all_mapping[i, j, start:end] = 1 138 | is_support[i, j] = int(is_sp_flag) 139 | 140 | input_lengths = (context_idxs > 0).astype(np.int64).sum(2) 141 | max_q_len = int((ques_idxs > 0).astype(np.int64).sum(1).max()) 142 | 143 | context_idxs = torch.from_numpy(context_idxs) 144 | ques_idxs = torch.from_numpy(ques_idxs[:, :max_q_len]) 145 | context_char_idxs = torch.from_numpy(context_char_idxs) 146 | ques_char_idxs = torch.from_numpy(ques_char_idxs[:, :max_q_len]) 147 | input_lengths = torch.from_numpy(input_lengths) 148 | y1 = torch.from_numpy(y1) 149 | y2 = torch.from_numpy(y2) 150 | q_type = torch.from_numpy(q_type) 151 | is_support = torch.from_numpy(is_support) 152 | start_mapping = torch.from_numpy(start_mapping) 153 | end_mapping = torch.from_numpy(end_mapping) 154 | all_mapping = torch.from_numpy(all_mapping) 155 | 156 | return {'context_idxs': context_idxs, 157 | 'ques_idxs': ques_idxs, 158 | 'context_char_idxs': context_char_idxs, 159 | 'ques_char_idxs': ques_char_idxs, 160 | 'context_lens': input_lengths, 161 | 'y1': y1, 162 | 'y2': y2, 163 | 'ids': ids, 164 | 'q_type': q_type, 165 | 'is_support': is_support, 166 | 'start_mapping': start_mapping, 167 | 'end_mapping': end_mapping, 168 | 'all_mapping': all_mapping} 169 | 170 | def get_buckets(record_file): 171 | # datapoints = pickle.load(open(record_file, 'rb')) 172 | datapoints = torch.load(record_file) 173 | return [datapoints] 174 | 175 | def convert_tokens(eval_file, qa_id, pp1, pp2, p_type): 176 | answer_dict = {} 177 | for qid, p1, p2, type in zip(qa_id, pp1, pp2, p_type): 178 | if type == 0: 179 | context = eval_file[str(qid)]["context"] 180 | spans = eval_file[str(qid)]["spans"] 181 | start_idx = spans[p1][0] 182 | end_idx = spans[p2][1] 183 | answer_dict[str(qid)] = context[start_idx: end_idx] 184 | elif type == 1: 185 | answer_dict[str(qid)] = 'yes' 186 | elif type == 2: 187 | answer_dict[str(qid)] = 'no' 188 | elif type == 3: 189 | answer_dict[str(qid)] = 'noanswer' 190 | else: 191 | 
assert False 192 | return answer_dict 193 | 194 | def evaluate(eval_file, answer_dict): 195 | f1 = exact_match = total = 0 196 | for key, value in answer_dict.items(): 197 | total += 1 198 | ground_truths = eval_file[key]["answer"] 199 | prediction = value 200 | assert len(ground_truths) == 1 201 | cur_EM = exact_match_score(prediction, ground_truths[0]) 202 | cur_f1, _, _ = f1_score(prediction, ground_truths[0]) 203 | exact_match += cur_EM 204 | f1 += cur_f1 205 | 206 | exact_match = 100.0 * exact_match / total 207 | f1 = 100.0 * f1 / total 208 | 209 | return {'exact_match': exact_match, 'f1': f1} 210 | 211 | # def evaluate(eval_file, answer_dict, full_stats=False): 212 | # if full_stats: 213 | # with open('qaid2type.json', 'r') as f: 214 | # qaid2type = json.load(f) 215 | # f1_b = exact_match_b = total_b = 0 216 | # f1_4 = exact_match_4 = total_4 = 0 217 | 218 | # qaid2perf = {} 219 | 220 | # f1 = exact_match = total = 0 221 | # for key, value in answer_dict.items(): 222 | # total += 1 223 | # ground_truths = eval_file[key]["answer"] 224 | # prediction = value 225 | # cur_EM = metric_max_over_ground_truths( 226 | # exact_match_score, prediction, ground_truths) 227 | # # cur_f1 = metric_max_over_ground_truths(f1_score, 228 | # # prediction, ground_truths) 229 | # assert len(ground_truths) == 1 230 | # cur_f1, cur_prec, cur_recall = f1_score(prediction, ground_truths[0]) 231 | # exact_match += cur_EM 232 | # f1 += cur_f1 233 | # if full_stats and key in qaid2type: 234 | # if qaid2type[key] == '4': 235 | # f1_4 += cur_f1 236 | # exact_match_4 += cur_EM 237 | # total_4 += 1 238 | # elif qaid2type[key] == 'b': 239 | # f1_b += cur_f1 240 | # exact_match_b += cur_EM 241 | # total_b += 1 242 | # else: 243 | # assert False 244 | 245 | # if full_stats: 246 | # qaid2perf[key] = {'em': cur_EM, 'f1': cur_f1, 'pred': prediction, 247 | # 'prec': cur_prec, 'recall': cur_recall} 248 | 249 | # exact_match = 100.0 * exact_match / total 250 | # f1 = 100.0 * f1 / total 251 | 252 | # ret = {'exact_match': exact_match, 'f1': f1} 253 | # if full_stats: 254 | # if total_b > 0: 255 | # exact_match_b = 100.0 * exact_match_b / total_b 256 | # exact_match_4 = 100.0 * exact_match_4 / total_4 257 | # f1_b = 100.0 * f1_b / total_b 258 | # f1_4 = 100.0 * f1_4 / total_4 259 | # ret.update({'exact_match_b': exact_match_b, 'f1_b': f1_b, 260 | # 'exact_match_4': exact_match_4, 'f1_4': f1_4, 261 | # 'total_b': total_b, 'total_4': total_4, 'total': total}) 262 | 263 | # ret['qaid2perf'] = qaid2perf 264 | 265 | # return ret 266 | 267 | def normalize_answer(s): 268 | 269 | def remove_articles(text): 270 | return re.sub(r'\b(a|an|the)\b', ' ', text) 271 | 272 | def white_space_fix(text): 273 | return ' '.join(text.split()) 274 | 275 | def remove_punc(text): 276 | exclude = set(string.punctuation) 277 | return ''.join(ch for ch in text if ch not in exclude) 278 | 279 | def lower(text): 280 | return text.lower() 281 | 282 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 283 | 284 | 285 | def f1_score(prediction, ground_truth): 286 | normalized_prediction = normalize_answer(prediction) 287 | normalized_ground_truth = normalize_answer(ground_truth) 288 | 289 | ZERO_METRIC = (0, 0, 0) 290 | 291 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 292 | return ZERO_METRIC 293 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 294 | return ZERO_METRIC 295 | 296 | prediction_tokens = 
normalized_prediction.split() 297 | ground_truth_tokens = normalized_ground_truth.split() 298 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 299 | num_same = sum(common.values()) 300 | if num_same == 0: 301 | return ZERO_METRIC 302 | precision = 1.0 * num_same / len(prediction_tokens) 303 | recall = 1.0 * num_same / len(ground_truth_tokens) 304 | f1 = (2 * precision * recall) / (precision + recall) 305 | return f1, precision, recall 306 | 307 | 308 | def exact_match_score(prediction, ground_truth): 309 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 310 | 311 | 312 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 313 | scores_for_ground_truths = [] 314 | for ground_truth in ground_truths: 315 | score = metric_fn(prediction, ground_truth) 316 | scores_for_ground_truths.append(score) 317 | return max(scores_for_ground_truths) 318 | 319 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018 The Board of Trustees of The Leland Stanford Junior University 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GoldEn Retriever Icon GoldEn Retriever 2 | 3 | This repository contains the authors' implementation of the EMNLP-IJCNLP 2019 paper "[Answering Complex Open-domain Questions Through Iterative Query Generation](https://arxiv.org/pdf/1910.07000.pdf)". 4 | 5 | It contains code for GoldEn (Gold Entity) Retriever, an iterative retrieve-and-read system that answers complex open-domain questions. This model answers complex questions that involve multiple steps of reasoning in an open-context open-domain setting (e.g., given the entire Wikipedia). GoldEn Retriever answers these questions by iterating between "reading" the context and generating natural language queries to search for supporting facts to read. It achieves competitive performance on the [HotpotQA leaderboard](https://hotpotqa.github.io/) _without_ using powerful pretrained neural networks such as BERT. Below is an example of how this model answers a complex question by generating natural language queries at each step. 6 | 7 | ![GoldEn Retriever model architecture](fig/golden-retriever.png) 8 | 9 | We also include in the `prepared_data` folder HotpotQA files generated by the GoldEn Retriever model/training procedure that can be used to train and evaluate few-document question answering systems. These QA systems can then be combined with GoldEn Retriever in the open-context open-domain setting. 
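Schematically, the retrieve-and-read loop described above proceeds in two hops. The sketch below is only an illustration and not the code path the pipeline actually uses (that is driven by `scripts/eval_end_to_end.sh` with DrQA-based query generators and the BiDAF++ reader); the `hop1_query_generator`, `hop2_query_generator`, and `qa_model` callables are hypothetical stand-ins for the trained components, while `bulk_text_query` is the Elasticsearch helper from `search/search.py` used throughout this repository.

```python
from search.search import bulk_text_query  # Elasticsearch retrieval helper used in this repo

def golden_retriever_sketch(question, hop1_query_generator, hop2_query_generator, qa_model):
    """Illustrative two-hop retrieve-and-read loop (a sketch, not the actual pipeline code)."""
    # Hop 1: generate a search query from the question alone, then retrieve.
    hop1_query = hop1_query_generator(question)
    hop1_docs = bulk_text_query([hop1_query], topn=5, lazy=False)[0]

    # Hop 2: generate a new query from the question plus what was just read,
    # then retrieve again to pick up the remaining supporting paragraphs.
    hop2_query = hop2_query_generator(question, hop1_docs)
    hop2_docs = bulk_text_query([hop2_query], topn=5, lazy=False)[0]

    # Read everything retrieved so far to produce the answer and supporting facts.
    return qa_model(question, hop1_docs + hop2_docs)
```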
10 | 11 | ## Training Your Own GoldEn Retriever 12 | 13 | 14 | ### Setting up 15 | Checkout the code from our repository using 16 | ```bash 17 | git clone --recursive https://github.com/qipeng/golden-retriever.git 18 | ``` 19 | This will help you set up submodule dependencies (needed for DrQA). (Equivalently, you can do `git submodule update --init --recursive` after `git clone`) 20 | 21 | This repo requires Python 3.6. Please check your shell environment's `python` before proceeding. To use ElasticSearch, make sure you also install Java Development Kit (JDK) version 8. 22 | 23 | The setup script will download all required dependencies (python requirements, 24 | data, etc.) required to run the GoldEn Retriever pipeline end-to-end. Before running this script, make sure you have the Unix utility `wget` (which can be installed through anaconda as well as other common package managers). 25 | Along the way, it will also start running Elasticsearch and index the 26 | wikipedia dataset locally. 27 | 28 | _Note: This might take a while to finish and requires a large amount of disk space, so it is strongly recommended that you run this on a machine with at least 100GB of free disk space._ 29 | 30 | ```bash 31 | bash setup.sh 32 | ``` 33 | 34 | ### Run the model end-to-end 35 | 36 | ```bash 37 | bash scripts/eval_end_to_end.sh 38 | ``` 39 | 40 | By default, this generates predictions on the HotpotQA dev set in a directory named `outdir`. Take a look at the contents of the `eval_end_to_end.sh` script for more details 41 | or to modify inputs/outputs/model/etc. 42 | 43 | ### Training model components 44 | 45 | #### Hop 1 query generator 46 | 47 | 1. Generate oracle queries (labels) for the Hop 1 query generator 48 | 49 | ```bash 50 | python -m scripts.gen_hop1 dev && python -m scripts.gen_hop1 train 51 | ``` 52 | This generates Hop 1 oracle queries under `data/hop1` 53 | 2. Create dataset 54 | 55 | ```bash 56 | mkdir -p tmp 57 | 58 | python -m scripts.preprocess_hop1 --input_path --output_path ./tmp/hotpot_hop1_squad_train.json 59 | python -m scripts.preprocess_hop1 --input_path --output_path ./tmp/hotpot_hop1_squad_dev.json 60 | ``` 61 | 3. Preprocess with DrQA format 62 | 63 | ```bash 64 | # In the DrQA dir 65 | 66 | python scripts/reader/preprocess.py --split hotpot_hop1_squad_train --workers 4 67 | python scripts/reader/preprocess.py --split hotpot_hop1_squad_dev --workers 4 68 | ``` 69 | 4. Sample training code 70 | 71 | ```bash 72 | python scripts/reader/train.py --embedding-file data/embeddings/glove.840B.300d.txt --tune-partial 500 --train-file --dev-file --dev-json hotpot_hop1_squad_dev.json --hidden-size 128 --parallel True --data-workers 10 --batch-size 32 --test-batch-size 128 --learning-rate 0.001 --model-dir --max-len 50 --model-name hop1_model 73 | ``` 74 | 5. Sample prediction code 75 | 76 | ```bash 77 | python scripts/reader/predict.py data/datasets/hotpot_hop1 78 | ``` 79 | 80 | #### Hop 2 query generator 81 | 82 | 1. Generate oracle queries (labels) for the Hop 2 query generator (note this has to be run after Hop 1 oracle queries have been generated) 83 | 84 | ```bash 85 | python -m scripts.gen_hop2 dev && python -m scripts.gen_hop2 train 86 | ``` 87 | This generates Hop 2 oracle queries under `data/hop2` 88 | 2. Create DrQA dataset 89 | 90 | Copy the hop2 label json files into DrQA/data/datasets folder, then 91 | 92 | ```bash 93 | python -m scripts.preprocess_hop2 hotpot_hop2_train.json 94 | python -m scripts.preprocess_hop2 hotpot_hop2_dev.json 95 | ``` 96 | 3. 
Preprocess with DrQA format 97 | 98 | ```bash 99 | # In the DrQA dir 100 | python scripts/reader/preprocess.py data/datasets data/datasets --split SQuAD_hotpot_hop2_dev --workers 4 101 | python scripts/reader/preprocess.py data/datasets data/datasets --split SQuAD_hotpot_hop2_train --workers 4 102 | ``` 103 | 4. Sample training code 104 | 105 | ```bash 106 | python scripts/reader/train.py --embedding-file data/embeddings/glove.840B.300d.txt --tune-partial 1000 --max-len 20 --train-file --dev-file --dev-json --model-dir --model-name hop2_model --expand-dictionary False --num-epochs 40 107 | ``` 108 | 5. Sample prediction code 109 | 110 | ```bash 111 | python scripts/reader/predict.py --model --embedding-file data/embeddings/glove.840B.300d.txt --out-dir data/datasets 112 | ``` 113 | 114 | #### BiDAF++ question answering component 115 | 116 | 1. Generate QA data that is more compatible with the query generators using the oracle queries (note that this needs to be run after Hop 1 and Hop 2 query generation) 117 | 118 | ```bash 119 | python -m scripts.build_qa_data train && python -m scripts.build_qa_data dev-distractor 120 | # Optionally, run "python -m scripts.build_qa_data dev-fullwiki" to generate a dev set from the oracle queries where the gold paragraphs are not guanranteed to be contained 121 | ``` 122 | This will generate training and dev sets that contain retrieved documents from Wikipedia with the oracle query under `data/hotpotqa` with the suffix `_hops.json` 123 | 2. Preprocess the data for the BiDAF++ QA component 124 | 125 | ```bash 126 | # In the BiDAFpp directory 127 | python main.py --mode prepro --data_file ../data/hotpotqa/hotpot_train_hops.json --para_limit 2250 --data_split train && python main.py --mode prepro --data_file ../data/hotpotqa/hotpot_dev_distractor_hops.json --para_limit 2250 --data_split dev 128 | ``` 129 | Note that the training set has to be preprocessed before the dev set. 130 | 3. Train the BiDAF++ QA component 131 | 132 | ```bash 133 | python main.py --mode train --para_limit 2250 --batch_size 64 --init_lr 0.001 --patience 3 --keep_prob .9 --sp_lambda 10.0 --period 20 --max_grad_norm 5 --hidden 128 134 | ``` 135 | 4. Sample code for predicting from the trained QA component 136 | 137 | ```bash 138 | python main.py --mode prepro --data_file --para_limit 2250 --data_split test --fullwiki # preprocess the input data 139 | python main.py --mode test --data_split test --save --prediction_file --sp_threshold .33 --sp_lambda 10.0 --fullwiki --hidden 128 --batch_size 16 140 | ``` 141 | 142 | ## Citation 143 | 144 | If you use GoldEn Retriever in your work, please consider citing our paper 145 | 146 | ``` 147 | @inproceedings{qi2019answering, 148 | author={Qi, Peng and Lin, Xiaowen and Mehr, Leo and Wang, Zijian and Manning, Christopher D.}, 149 | booktitle={2019 Conference on Empirical Methods in Natural Language Processing and 9th International Joint Conference on Natural Language Processing ({EMNLP-IJCNLP})}, 150 | title={Answering Complex Open-domain Questions Through Iterative Query Generation}, 151 | url={https://nlp.stanford.edu/pubs/qi2019answering.pdf}, 152 | year={2019} 153 | } 154 | ``` 155 | 156 | ## License 157 | 158 | All work contained in this package is licensed under the Apache License, Version 2.0. See the included LICENSE file. 
159 | -------------------------------------------------------------------------------- /fig/golden-retriever-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qipeng/golden-retriever/c806574a373a4ee86b7e754f169bb2a54d3ba15f/fig/golden-retriever-icon.png -------------------------------------------------------------------------------- /fig/golden-retriever.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qipeng/golden-retriever/c806574a373a4ee86b7e754f169bb2a54d3ba15f/fig/golden-retriever.png -------------------------------------------------------------------------------- /prepared_data/README.md: -------------------------------------------------------------------------------- 1 | # GoldEn Retriever Prepared Data Files 2 | 3 | **Note**: Please run `bash scripts/download_prepared_data.sh` from the root directory of the repository first to populate this directory. 4 | 5 | This folder contains data files we have prepared with the GoldEn Retriever pipeline, so that researchers can easily train and test their few-document multi-hop QA models in an open-context open-domain setting (or simply improve on their few-document QA models with a different set of "distractor" paragraphs). 6 | 7 | This folder contains four JSON files: 8 | 9 | 1. `hotpot_train_golden.json` is the training set we use to train our final question answering component. It has the same format as the distractor training set, i.e., the two gold paragraphs are always present, along with up to eight distractor paragraphs. The distractor paragraphs come from both stages of information retrieval, and thus are less susceptible to the same bias of in the original distractor setting. 10 | 2. `hotpot_dev_distractor_golden.json` is the corresponding development set, parallel to the original distractor setting. 11 | 3. `hotpot_dev_fullwiki_golden.json` contains 10 paragraphs for each question retrieved by the GoldEn Retriever pipeline with the oracle queries derived from our process of generating supervision signal, and the gold paragraphs are not guaranteed to be present. This provides a better estimate for open-domain QA performance at development time. 12 | 4. `hotpot_test_fullwiki_input_golden.json` is the input file to the final QA component we used for our leaderboard submission. This file could also serve as a drop-in replacement for other QA systems (especially few-document ones) since it's in the same format as the HotpotQA dataset. The documents are retrieved with queries generated by our trained GoldEn Retriever query generation models. 
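As a quick sanity check after running `bash scripts/download_prepared_data.sh`, these files can be inspected like any other HotpotQA-format JSON. The snippet below is a minimal sketch; the field names follow the standard HotpotQA schema (`question`, plus `context` as a list of `[title, sentences]` pairs), which is what these files share with the original dataset.

```python
import json

# Load the GoldEn Retriever dev set in the distractor-style format.
with open("prepared_data/hotpot_dev_distractor_golden.json") as f:
    data = json.load(f)

print(len(data), "questions")
example = data[0]
print(example["question"])
# Each context entry pairs a paragraph title with its sentences.
for title, sentences in example["context"]:
    print(title, "-", len(sentences), "sentences")
```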
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | editdistance==0.5.3 2 | elasticsearch>=6.0.0,<7.0.0 3 | tqdm 4 | stanfordnlp 5 | nltk 6 | ujson 7 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qipeng/golden-retriever/c806574a373a4ee86b7e754f169bb2a54d3ba15f/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/build_qa_data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | def ir_candidates(ir_result, retain=5, force_retain_target=True): 5 | _id = ir_result['_id'] 6 | 7 | target = ir_result['target_para'] 8 | 9 | retained = ir_result['ir_result'][:retain] 10 | 11 | if force_retain_target and target['title'] not in [x['title'] for x in retained]: 12 | retained = retained[:-1] + [target] 13 | 14 | return (_id, [(x['title'], x['text']) for x in retained]) 15 | 16 | def merge_and_shuffle(orig_datum, hop1_candidates, hop2_candidates): 17 | # we always process examples in order so alignment should be preserved for free 18 | assert orig_datum['_id'] == hop1_candidates[0] == hop2_candidates[0] 19 | 20 | all_candidates = [] 21 | all_titles = set() 22 | for doc in hop1_candidates[1] + hop2_candidates[1]: 23 | if doc[0] not in all_titles: 24 | # deduplicate if we can 25 | all_candidates.append(doc) 26 | all_titles.add(doc[0]) 27 | 28 | random.shuffle(all_candidates) 29 | 30 | res = copy.copy(orig_datum) 31 | 32 | res['context'] = all_candidates 33 | 34 | return res 35 | 36 | if __name__ == "__main__": 37 | import argparse 38 | import json 39 | from time import time 40 | from datetime import timedelta 41 | from joblib import Parallel, delayed 42 | 43 | RETAIN_HOP1 = 5 44 | RETAIN_HOP2 = 5 45 | 46 | parser = argparse.ArgumentParser() 47 | 48 | parser.add_argument('split', choices=['train', 'dev-distractor', 'dev-fullwiki']) 49 | 50 | args = parser.parse_args() 51 | 52 | if args.split == 'train': 53 | input_data = 'data/hotpotqa/hotpot_train_v1.1.json' 54 | hop1_ir_result = 'data/hop1/hotpot_hop1_train_ir_result.json' 55 | hop2_ir_result = 'data/hop2/hotpot_hop2_train_ir_result.json' 56 | output_data = 'data/hotpotqa/hotpot_train_hops.json' 57 | force_retain = True 58 | elif args.split == 'dev-distractor': 59 | input_data = 'data/hotpotqa/hotpot_dev_distractor_v1.json' 60 | hop1_ir_result = 'data/hop1/hotpot_hop1_dev_ir_result.json' 61 | hop2_ir_result = 'data/hop2/hotpot_hop2_dev_ir_result.json' 62 | output_data = 'data/hotpotqa/hotpot_dev_distractor_hops.json' 63 | force_retain = True 64 | else: 65 | input_data = 'data/hotpotqa/hotpot_dev_distractor_v1.json' 66 | hop1_ir_result = 'data/hop1/hotpot_hop1_dev_ir_result.json' 67 | hop2_ir_result = 'data/hop2/hotpot_hop2_dev_ir_result.json' 68 | output_data = 'data/hotpotqa/hotpot_dev_fullwiki_hops.json' 69 | force_retain = False 70 | 71 | print('Loading HotpotQA training input... ', end="", flush=True) 72 | t0 = time() 73 | with open(input_data) as f: 74 | orig_data = json.load(f) 75 | print('Done. (took {})'.format(timedelta(seconds=int(time()-t0))), flush=True) 76 | 77 | print('Loading Hop 1 IR result... 
', end="", flush=True) 78 | t0 = time() 79 | with open(hop1_ir_result) as f: 80 | hop1 = json.load(f) 81 | print('Done. (took {})'.format(timedelta(seconds=int(time()-t0))), flush=True) 82 | 83 | hop1_candidates = Parallel(n_jobs=32, verbose=10)(delayed(ir_candidates)(x, retain=RETAIN_HOP1, force_retain_target=force_retain) for x in hop1) 84 | del hop1 85 | 86 | print('Loading Hop 2 IR result... ', end="", flush=True) 87 | t0 = time() 88 | with open(hop2_ir_result) as f: 89 | hop2 = json.load(f) 90 | print('Done. (took {})'.format(timedelta(seconds=int(time()-t0))), flush=True) 91 | 92 | hop2_candidates = Parallel(n_jobs=32, verbose=10)(delayed(ir_candidates)(x, retain=RETAIN_HOP2, force_retain_target=force_retain) for x in hop2) 93 | del hop2 94 | 95 | final = Parallel(n_jobs=32, verbose=10)(delayed(merge_and_shuffle)(x, y, z) for x, y, z in zip(orig_data, hop1_candidates, hop2_candidates)) 96 | 97 | print('Saving data file... ', end="", flush=True) 98 | t0 = time() 99 | with open(output_data, 'w') as f: 100 | json.dump(final, f) 101 | print('Done. (took {})'.format(timedelta(seconds=int(time()-t0))), flush=True) 102 | -------------------------------------------------------------------------------- /scripts/build_single_hop_qa_data.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from copy import copy 3 | import json 4 | from tqdm import tqdm 5 | 6 | from search.search import bulk_text_query 7 | from utils.general import chunks 8 | 9 | def main(): 10 | import argparse 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument('split', choices=['train', 'dev']) 14 | 15 | args = parser.parse_args() 16 | 17 | if args.split == 'train': 18 | filename = 'data/hotpotqa/hotpot_train_v1.1.json' 19 | outputname = 'data/hotpotqa/hotpot_train_single_hop.json' 20 | else: 21 | filename = 'data/hotpotqa/hotpot_dev_fullwiki_v1.json' 22 | outputname = 'data/hotpotqa/hotpot_dev_single_hop.json' 23 | batch_size = 64 24 | 25 | with open(filename) as f: 26 | data = json.load(f) 27 | 28 | outputdata = [] 29 | processed = 0 30 | for batch in tqdm(chunks(data, batch_size), total=(len(data) + batch_size - 1) // batch_size): 31 | queries = [x['question'] for x in batch] 32 | res = bulk_text_query(queries, topn=10, lazy=False) 33 | for r, d in zip(res, batch): 34 | d1 = copy(d) 35 | context = [item['data_object'] for item in r] 36 | context = [(x['title'], x['text']) for x in context] 37 | d1['context'] = context 38 | outputdata.append(d1) 39 | 40 | processed += len(batch) 41 | 42 | with open(outputname, 'w') as f: 43 | json.dump(outputdata, f) 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/download_corenlp.sh: -------------------------------------------------------------------------------- 1 | wget http://nlp.stanford.edu/software/stanford-corenlp-full-2018-10-05.zip 2 | unzip stanford-corenlp-full-2018-10-05.zip 3 | rm stanford-corenlp-full-2018-10-05.zip 4 | export CORENLP_HOME=`pwd`/stanford-corenlp-full-2018-10-05 5 | -------------------------------------------------------------------------------- /scripts/download_elastic_6.7.sh: -------------------------------------------------------------------------------- 1 | # download elasticsearch 2 | wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-6.7.0.tar.gz 3 | # extract 4 | tar -xzvf elasticsearch-6.7.0.tar.gz 5 | # clean up 6 | rm elasticsearch-6.7.0.tar.gz 7 
| -------------------------------------------------------------------------------- /scripts/download_golden_retriever_models.sh: -------------------------------------------------------------------------------- 1 | mkdir -p models 2 | pushd models 3 | 4 | echo "Downloading query generators..." 5 | wget https://nlp.stanford.edu/projects/golden-retriever/hop1.mdl 6 | wget https://nlp.stanford.edu/projects/golden-retriever/hop2.mdl 7 | 8 | echo "Downloading QA model..." 9 | cd ../BiDAFpp 10 | wget https://nlp.stanford.edu/projects/golden-retriever/QAModel.zip 11 | unzip QAModel.zip 12 | rm QAModel.zip 13 | wget https://nlp.stanford.edu/projects/golden-retriever/jsons.zip 14 | unzip jsons.zip 15 | rm jsons.zip 16 | popd 17 | echo "Done!" 18 | -------------------------------------------------------------------------------- /scripts/download_hotpotqa.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data/hotpotqa 2 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json -O data/hotpotqa/hotpot_train_v1.1.json 3 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json -O data/hotpotqa/hotpot_dev_distractor_v1.json 4 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json -O data/hotpotqa/hotpot_dev_fullwiki_v1.json 5 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_test_fullwiki_v1.json -O data/hotpotqa/hotpot_test_fullwiki_v1.json 6 | -------------------------------------------------------------------------------- /scripts/download_prepared_data.sh: -------------------------------------------------------------------------------- 1 | echo "Downloading prepared data..." 2 | 3 | cd prepared_data 4 | wget https://nlp.stanford.edu/projects/golden-retriever/prepared_data.zip 5 | 6 | echo "Extracting files..." 7 | unzip prepared_data.zip 8 | rm prepared_data.zip 9 | echo "Done!" 10 | -------------------------------------------------------------------------------- /scripts/download_processed_wiki.sh: -------------------------------------------------------------------------------- 1 | # download the wiki dump file 2 | mkdir -p data 3 | wget https://nlp.stanford.edu/projects/hotpotqa/enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2 -O data/enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2 4 | # verify that we have the whole thing 5 | unameOut="$(uname -s)" 6 | case "${unameOut}" in 7 | Darwin*) MD5SUM="md5 -r";; 8 | *) MD5SUM=md5sum 9 | esac 10 | if [ `$MD5SUM data/enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2 | awk '{print $1}'` == "01edf64cd120ecc03a2745352779514c" ]; then 11 | echo "Downloaded the processed Wikipedia dump from the HotpotQA website. Everything's looking good, so let's extract it!" 12 | else 13 | echo "The md5 doesn't seem to match what we expected, try again?" 14 | exit 1 15 | fi 16 | cd data 17 | tar -xjvf enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2 18 | # clean up 19 | rm enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2 20 | echo 'Done!' 
21 | -------------------------------------------------------------------------------- /scripts/e_to_e_helpers/merge_hops_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Given: 3 | hop1 and hop2 files with ES results 4 | 5 | Outputs: 6 | single file containing questions, hop1 and hop2 ES results merged 7 | use --num_each to control how many contexts are taken from each file 8 | """ 9 | 10 | import argparse 11 | 12 | from utils.io import load_json_file, write_json_file 13 | 14 | def main(args): 15 | hop1_data = load_json_file(args.hop1_file) 16 | hop2_data = load_json_file(args.hop2_file) 17 | 18 | out_data = [] 19 | for hop1, hop2 in zip(hop1_data, hop2_data): 20 | # We're assuming that the hop1 and hop2 files are sorted in the same 21 | # order. If this doesn't hold, then we would just make a map 22 | # {id -> entry} for one file. 23 | assert hop1['_id'] == hop2['_id'] 24 | 25 | entry = {} 26 | entry['_id'] = hop1['_id'] 27 | entry['question'] = hop1['question'] 28 | if args.include_queries: 29 | entry['hop1_query'] = hop1['query'] 30 | entry['hop2_query'] = hop2['query'] 31 | 32 | entry['context'] = [] 33 | all_titles = set() 34 | for doc in hop1['json_context'][:args.num_each] + hop2['json_context'][:args.num_each]: 35 | if doc[0] not in all_titles: 36 | entry['context'].append(doc) 37 | all_titles.add(doc[0]) 38 | 39 | out_data.append(entry) 40 | 41 | write_json_file(out_data, args.out_file) 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser(description='Merge hop1 and hop2 results.') 45 | parser.add_argument('hop1_file') 46 | parser.add_argument('hop2_file') 47 | parser.add_argument('out_file', help='filename to write data out to') 48 | parser.add_argument('--include_queries', action='store_true') 49 | parser.add_argument('--num_each', default=5, 50 | help='number of contexts to take from each hop', 51 | type=int) 52 | args = parser.parse_args() 53 | main(args) 54 | 55 | -------------------------------------------------------------------------------- /scripts/e_to_e_helpers/merge_with_es.py: -------------------------------------------------------------------------------- 1 | """ 2 | Query ES and merge results with original hotpot data. 3 | 4 | Input: 5 | - query file 6 | - hotpotqa data 7 | - output filename 8 | - whether this is for hop1 or hop2 9 | 10 | Outputs: 11 | - json file containing a list of: 12 | {'context', 'question', '_id', 'query', 'json_context'} 13 | context -- the concatentation of the top n paragraphs for the given query 14 | to ES. 15 | json_context -- same as context, but in json structure same as original 16 | hotpot data. 
17 | question, _id -- identical to those from the original HotPotQA data 18 | """ 19 | 20 | import argparse 21 | from tqdm import tqdm 22 | from search.search import bulk_text_query 23 | from utils.io import load_json_file, write_json_file 24 | from utils.general import chunks, make_context 25 | 26 | def main(query_file, question_file, out_file, top_n): 27 | query_data = load_json_file(query_file) 28 | question_data = load_json_file(question_file) 29 | 30 | out_data = [] 31 | 32 | for chunk in tqdm(list(chunks(question_data, 100))): 33 | queries = [] 34 | for datum in chunk: 35 | _id = datum['_id'] 36 | queries.append(query_data[_id] if isinstance(query_data[_id], str) else query_data[_id][0][0]) 37 | 38 | es_results = bulk_text_query(queries, topn=top_n, lazy=False) 39 | for es_result, datum in zip(es_results, chunk): 40 | _id = datum['_id'] 41 | question = datum['question'] 42 | query = query_data[_id] if isinstance(query_data[_id], str) else query_data[_id][0][0] 43 | context = make_context(question, es_result) 44 | json_context = [ 45 | [p['title'], p['data_object']['text']] 46 | for p in es_result 47 | ] 48 | 49 | out_data.append({ 50 | '_id': _id, 51 | 'question': question, 52 | 'context': context, 53 | 'query': query, 54 | 'json_context': json_context 55 | }) 56 | 57 | write_json_file(out_data, out_file) 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser( 61 | description='Query ES and merge results with original hotpot data.') 62 | parser.add_argument('query_file', help='.preds file containing ES queries ') 63 | parser.add_argument('question_file', help='.json file containing original questions and ids') 64 | parser.add_argument('out_file', help='filename to write data out to') 65 | parser.add_argument('--top_n', default=5, 66 | help='number of docs to return from ES', 67 | type=int) 68 | args = parser.parse_args() 69 | 70 | main(args.query_file, args.question_file, args.out_file, args.top_n) 71 | 72 | -------------------------------------------------------------------------------- /scripts/e_to_e_helpers/squadify_questions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Given a list of questions, produces them in SQuAD format for DrQA. 3 | 4 | Input file should be in json, a list of objects each of which 5 | must have at least a "question" and an "_id". 
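For illustration only (the id and question below are made-up placeholders), a minimal input file could therefore look like:
    [{"_id": "example-id-0", "question": "Example question text?"}]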
6 | """ 7 | 8 | from argparse import ArgumentParser 9 | from tqdm import tqdm 10 | 11 | from utils.io import write_json_file, load_json_file 12 | 13 | def main(question_file, out_file): 14 | data = load_json_file(question_file) 15 | 16 | rows = [] 17 | for entry in data: 18 | assert 'question' in entry, 'every entry must have a question' 19 | assert '_id' in entry, 'every entry must have an _id' 20 | row = { 21 | 'title': '', 22 | 'paragraphs': [{ 23 | 'context': entry['question'], 24 | 'qas': [{ 25 | 'question': entry['question'], 26 | 'id': entry['_id'], 27 | 'answers': [{'answer_start': 0, 'text': ''}] 28 | }] 29 | }] 30 | } 31 | rows.append(row) 32 | 33 | write_json_file({'data': rows}, out_file) 34 | 35 | if __name__ == "__main__": 36 | parser = ArgumentParser() 37 | parser.add_argument('question_file', 38 | help="json file containing a list of questions and IDs") 39 | parser.add_argument('out_file', 40 | help="File to output SQuAD-formatted questions to") 41 | 42 | args = parser.parse_args() 43 | main(args.question_file, args.out_file) 44 | -------------------------------------------------------------------------------- /scripts/eval_drqa.py: -------------------------------------------------------------------------------- 1 | # sample usage: python -m scripts.eval_drqa hotpot_hop1_squad_dev-768-50-v3.preds hotpot_dev_fullwiki_v1.json 2 | 3 | from collections import Counter 4 | import json, time, re, sys 5 | from tqdm import tqdm 6 | import pandas as pd 7 | import numpy as np 8 | from search.search import * 9 | 10 | def main(pred_filename, original_filename): 11 | batch_size = 200 12 | Ns = [1,2,3,4,5,6,7,8,9,10,15,20,25,30,35,40,45,50] 13 | max_n = max(Ns) 14 | 15 | 16 | with open(pred_filename) as f: 17 | data = json.load(f) 18 | 19 | 20 | with open(original_filename) as f: 21 | original_data = json.load(f) 22 | 23 | reconstructed_data = [] 24 | for idx, entry in enumerate(original_data): 25 | id = entry['_id'] 26 | query = data[id][0][0] 27 | gold = set(y[0] for y in entry['supporting_facts']) 28 | reconstructed_data.append((query, gold)) 29 | 30 | batches = [reconstructed_data[b*batch_size:min((b+1)*batch_size, len(data))] 31 | for b in range((len(data) + batch_size - 1) // batch_size)] 32 | 33 | 34 | 35 | 36 | para1 = Counter() 37 | para2 = Counter() 38 | processed = 0 39 | for batch in tqdm(batches): 40 | queries = [x[0] for x in batch] 41 | res = bulk_text_query(queries, topn=max_n, lazy=True) 42 | # set lazy to true because we don't really care about the json object here 43 | for r, d in zip(res,batch): 44 | para1_found = False 45 | para2_found = False 46 | for i, para in enumerate(r): 47 | if para['title'] in d[1]: 48 | if not para1_found: 49 | para1[i] += 1 50 | para1_found = True 51 | else: 52 | assert not para2_found 53 | para2[i] += 1 54 | para2_found = True 55 | 56 | if not para1_found: 57 | para1[max_n] += 1 58 | if not para2_found: 59 | para2[max_n] += 1 60 | 61 | processed += len(batch) 62 | 63 | print(processed) 64 | 65 | for n in Ns: 66 | c1 = sum(para1[k] for k in range(n)) 67 | c2 = sum(para2[k] for k in range(n)) 68 | 69 | print("Hits@{:2d}: {:.2f}\tP1@{:2d}: {:.2f}\tP2@{:2d}: {:.2f}".format( 70 | n, 100 * (c1+c2) / 2 / len(data), n, 100 * c1 / len(data), n, 100 * c2 / len(data))) 71 | 72 | if __name__ == "__main__": 73 | pred_filename = sys.argv[1] 74 | original_filename = sys.argv[2] 75 | main(pred_filename, original_filename) 76 | -------------------------------------------------------------------------------- /scripts/eval_end_to_end.sh: 
-------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Instructions: 4 | # - make sure you have already run `setup.sh` and are using the correct python environment 5 | # - call this script from the root directory of this project 6 | # - please feel free to modify any of the script inputs below 7 | 8 | set -e # stop script if any command fails 9 | 10 | # Script parameters for users (feel free to edit): 11 | OUTDIR="outdir" # suggested convention is "[dataset_name]_eval" 12 | QUESTION_FILE="data/hotpotqa/hotpot_dev_distractor_v1.json" 13 | HOP1_MODEL_NAME="hop1" 14 | HOP2_MODEL_NAME="hop2" 15 | QA_MODEL_NAME="QAModel" 16 | 17 | DRQA_DIR="DrQA" 18 | EMBED_FILE="${DRQA_DIR}/data/embeddings/glove.840B.300d.txt" 19 | RECOMPUTE_ALL=false # change to `true` to force recompute everything 20 | NUM_DRQA_WORKERS=16 21 | BIDAFPP_DIR="BiDAFpp" 22 | 23 | # Toggle these settings to experiment with oracle queries in the pipeline 24 | USE_HOP1_ORACLE=false 25 | USE_HOP2_ORACLE=false 26 | 27 | # set -x # UNCOMMENT this line for debugging output 28 | 29 | # Change code below this line at your own risk! 30 | ########################################################################################### 31 | 32 | realpath() { 33 | [[ $1 = /* ]] && echo "$1" || echo "$PWD/${1#./}" 34 | } 35 | 36 | # use DrQA's version of corenlp 37 | export CLASSPATH="`realpath ${DRQA_DIR}`/data/corenlp/*:$CLASSPATH:." 38 | export CORENLP_HOME=`realpath stanford-corenlp-full-2018-10-05` 39 | 40 | HOP1_MODEL_FILE="models/$HOP1_MODEL_NAME.mdl" 41 | HOP2_MODEL_FILE="models/$HOP2_MODEL_NAME.mdl" 42 | HOP1_LABEL="data/hop1/hotpot_hop1_dev.json" 43 | HOP2_LABEL="data/hop2/hotpot_hop2_dev.json" 44 | 45 | CLASSPATH=${DRQA_DIR}/data/corenlp/* 46 | 47 | if [ ! -d ${DRQA_DIR} ] 48 | then 49 | echo "Make sure you've cloned the DrQA repo in ${DRQA_DIR}" 50 | exit 1 51 | fi 52 | 53 | if [ ! -f $EMBED_FILE ] 54 | then 55 | echo "Download the Glove embeddings and place them: $EMBED_FILE" 56 | echo "http://nlp.stanford.edu/data/wordvecs/glove.840B.300d.zip" 57 | exit 1 58 | fi 59 | 60 | if [ ! -f $HOP1_MODEL_FILE ] 61 | then 62 | echo "Make sure your HOP1 model file exists: $HOP1_MODEL_FILE" 63 | exit 1 64 | fi 65 | 66 | if [ ! -f $HOP2_MODEL_FILE ] 67 | then 68 | echo "Make sure your HOP2 model file exists: $HOP2_MODEL_FILE" 69 | exit 1 70 | fi 71 | 72 | if [ ! -f $QUESTION_FILE ] 73 | then 74 | echo "Make sure your Question file exists: $QUESTION_FILE" 75 | exit 1 76 | fi 77 | 78 | echo "Placing temporary evaluation files in: $OUTDIR" 79 | mkdir -p $OUTDIR 80 | 81 | if $RECOMPUTE_ALL || [ ! -f $OUTDIR/hop1_squadified.json ] 82 | then 83 | python -m scripts.e_to_e_helpers.squadify_questions $QUESTION_FILE $OUTDIR/hop1_squadified.json 84 | fi 85 | 86 | HOP1_PREDICTIONS="$OUTDIR/hop1_squadified-$HOP1_MODEL_NAME.preds" 87 | if $RECOMPUTE_ALL || [ ! -f $HOP1_PREDICTIONS ] 88 | then 89 | echo "Generating hop1 predictions..." 90 | if $USE_HOP1_ORACLE; then 91 | python scripts/query_labels_to_pred.py $HOP1_LABEL $HOP1_PREDICTIONS 92 | else 93 | python ${DRQA_DIR}/scripts/reader/predict.py $OUTDIR/hop1_squadified.json --out-dir $OUTDIR --num-workers $NUM_DRQA_WORKERS --embedding-file $EMBED_FILE --model $HOP1_MODEL_FILE 94 | fi 95 | fi 96 | 97 | echo "Hop1 predicted labels:" 98 | ls -la $HOP1_PREDICTIONS 99 | 100 | echo "Trying to connect to ES at localhost:9200..." 101 | if ! 
curl -s -I localhost:9200 > /dev/null; 102 | then 103 | echo 'running "sh scripts/launch_elasticsearch_6.7.sh"' 104 | sh scripts/launch_elasticsearch_6.7.sh 105 | while ! curl -I localhost:9200; 106 | do 107 | sleep 2; 108 | done 109 | fi 110 | echo "ES is up and running" 111 | 112 | if [ ! -f $OUTDIR/hop2_input.json ] \ 113 | || [ ! -f $OUTDIR/SQuAD_hop2_input.json ] \ 114 | || $RECOMPUTE_ALL 115 | then 116 | echo "Creating input for hop2 query prediction" 117 | python -m scripts.e_to_e_helpers.merge_with_es \ 118 | $HOP1_PREDICTIONS \ 119 | $QUESTION_FILE \ 120 | $OUTDIR/hop2_input.json 121 | 122 | python -m scripts.preprocess_hop2 $OUTDIR hop2_input.json 123 | echo "Created Hop2 SQuAD-formatted input:" 124 | else 125 | echo 'Using existing Hop2 SQuAD-formatted input file:' 126 | fi 127 | ls -la $OUTDIR/SQuAD_hop2_input.json 128 | 129 | HOP2_PREDICTIONS="$OUTDIR/SQuAD_hop2_input-$HOP2_MODEL_NAME.preds" 130 | if $RECOMPUTE_ALL || [ ! -f $HOP2_PREDICTIONS ] 131 | then 132 | if $USE_HOP2_ORACLE; then 133 | python scripts/query_labels_to_pred.py $HOP2_LABEL $HOP2_PREDICTIONS 134 | else 135 | python ${DRQA_DIR}/scripts/reader/predict.py $OUTDIR/SQuAD_hop2_input.json --out-dir $OUTDIR --num-workers $NUM_DRQA_WORKERS --embedding-file $EMBED_FILE --model $HOP2_MODEL_FILE 136 | fi 137 | fi 138 | 139 | echo "Hop2 predictions:" 140 | ls -la $HOP2_PREDICTIONS 141 | 142 | if $RECOMPUTE_ALL || [ ! -f $OUTDIR/hop2_output.json ] 143 | then 144 | echo "Querying ES with hop2 predictions" 145 | python -m scripts.e_to_e_helpers.merge_with_es \ 146 | $HOP2_PREDICTIONS \ 147 | $QUESTION_FILE \ 148 | $OUTDIR/hop2_output.json 149 | echo "Created Hop2 output:" 150 | else 151 | echo "Using existing Hop2 output:" 152 | fi 153 | ls -la $OUTDIR/hop2_output.json 154 | 155 | if $RECOMPUTE_ALL || [ ! -f $OUTDIR/qa_input.json ]; then 156 | python -m scripts.e_to_e_helpers.merge_hops_results \ 157 | $OUTDIR/hop2_input.json \ 158 | $OUTDIR/hop2_output.json \ 159 | $OUTDIR/qa_input.json \ 160 | --include_queries \ 161 | --num_each 5 162 | echo "Created QA output:" 163 | else 164 | echo "Using existing QA output:" 165 | fi 166 | ls -la $OUTDIR/qa_input.json 167 | 168 | WD=`pwd` 169 | if $RECOMPUTE_ALL || [ ! -f $WD/$OUTDIR/golden.json ]; then 170 | pushd $BIDAFPP_DIR 171 | python main.py --mode prepro --data_file $WD/$OUTDIR/qa_input.json --para_limit 2250 --data_split test --fullwiki 172 | python main.py --mode test --data_split test --save ${QA_MODEL_NAME} --prediction_file $WD/$OUTDIR/golden.json --sp_threshold .33 --sp_lambda 10.0 --fullwiki --hidden 128 --batch_size 16 173 | popd 174 | fi 175 | ls $WD/$OUTDIR/golden.json 176 | 177 | if [ -f $WD/${QUESTION_FILE} ]; then 178 | cd $BIDAFPP_DIR 179 | python hotpot_evaluate_v1.py $WD/$OUTDIR/golden.json $WD/${QUESTION_FILE} | python $WD/scripts/format_result.py 180 | fi 181 | cd $WD 182 | 183 | echo "Done! 
Final results in: $OUTDIR/golden.json" 184 | 185 | -------------------------------------------------------------------------------- /scripts/eval_hits.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import json 3 | from tqdm import tqdm 4 | 5 | from search.search import bulk_text_query 6 | 7 | def main(): 8 | import argparse 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('split', choices=['train', 'dev']) 12 | 13 | args = parser.parse_args() 14 | 15 | if args.split == 'train': 16 | filename = 'data/hotpotqa/hotpot_train_v1.1.json' 17 | else: 18 | filename = 'data/hotpotqa/hotpot_dev_fullwiki_v1.json' 19 | batch_size = 64 20 | Ns = [1,2,3,4,5,6,7,8,9,10,15,20,25,30,35,40,45,50] 21 | max_n = max(Ns) 22 | 23 | with open(filename) as f: 24 | data = json.load(f) 25 | 26 | batches = [[(x['question'], set(y[0] for y in x['supporting_facts'])) 27 | for x in data[b*batch_size:min((b+1)*batch_size, len(data))]] 28 | for b in range((len(data) + batch_size - 1) // batch_size)] 29 | 30 | para1 = Counter() 31 | para2 = Counter() 32 | processed = 0 33 | for batch in tqdm(batches): 34 | queries = [x[0] for x in batch] 35 | res = bulk_text_query(queries, topn=max_n, lazy=True) 36 | # set lazy to true because we don't really care about the json object here 37 | for r, d in zip(res, batch): 38 | para1_found = False 39 | para2_found = False 40 | for i, para in enumerate(r): 41 | if para['title'] in d[1]: 42 | if not para1_found: 43 | para1[i] += 1 44 | para1_found = True 45 | else: 46 | assert not para2_found 47 | para2[i] += 1 48 | para2_found = True 49 | 50 | if not para1_found: 51 | para1[max_n] += 1 52 | if not para2_found: 53 | para2[max_n] += 1 54 | 55 | processed += len(batch) 56 | 57 | for n in Ns: 58 | c1 = sum(para1[k] for k in range(n)) 59 | c2 = sum(para2[k] for k in range(n)) 60 | 61 | print("Hits@{:2d}: {:.2f}\tP1@{:2d}: {:.2f}\tP2@{:2d}: {:.2f}".format( 62 | n, 100 * (c1+c2) / 2 / processed, n, 100 * c1 / processed, n, 100 * c2 / processed)) 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /scripts/eval_model2.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import json, time, re, sys 3 | from tqdm import tqdm 4 | import pandas as pd 5 | import numpy as np 6 | from search.search import * 7 | import argparse 8 | 9 | def evaluate(pred_filename, original_filename): 10 | batch_size = 200 11 | Ns = [1,2,3,4,5,6,7,8,9,10,15,20,25,30,35,40,45,50] 12 | max_n = max(Ns) 13 | 14 | with open(pred_filename) as f: 15 | data = json.load(f) 16 | 17 | with open(original_filename) as f: 18 | original_data = json.load(f) 19 | 20 | if len(original_data) != len(data): 21 | print("Warning: Data length mismatch") 22 | print("Label file:", len(original_data)) 23 | print("Preds file:", len(data)) 24 | 25 | 26 | reconstructed_data = [] 27 | for idx, entry in enumerate(original_data): 28 | id = entry['_id'] 29 | query = data[id][0][0] 30 | title1 = entry['title1'] 31 | title2 = entry['title2'] 32 | reconstructed_data.append((query, title1, title2)) 33 | 34 | batches = [reconstructed_data[b*batch_size:min((b+1)*batch_size, len(data))] 35 | for b in range((len(original_data) + batch_size - 1) // batch_size)] 36 | 37 | 38 | para1 = Counter() 39 | para2 = Counter() 40 | for batch in tqdm(batches): 41 | queries = [x[0] for x in batch] 42 | res = bulk_text_query(queries, 
topn=max_n, lazy=True) 43 | # set lazy to true because we don't really care about the json object here 44 | for r, d in zip(res, batch): 45 | query, title1, title2 = d 46 | para1_found = False 47 | para2_found = False 48 | # enumerate search results for current query 49 | for i, para in enumerate(r): 50 | if para['title'] == title1 and not para1_found: 51 | para1[i] += 1 52 | para1_found = True 53 | elif para['title'] == title2 and not para2_found: 54 | para2[i] += 1 55 | para2_found = True 56 | 57 | # Print stats 58 | for n in Ns: 59 | c1 = sum(para1[k] for k in range(n)) 60 | c2 = sum(para2[k] for k in range(n)) 61 | 62 | print("Hits@{:2d}: {:.2f}\tP1@{:2d}: {:.2f}\tP2@{:2d}: {:.2f}".format( 63 | n, 64 | 100 * (c1+c2) / 2 / len(reconstructed_data), # Hits@n 65 | n, 66 | 100 * c1 / len(reconstructed_data), # P1@n 67 | n, 68 | 100 * c2 / len(reconstructed_data) # P2@n 69 | )) 70 | 71 | 72 | if __name__ == "__main__": 73 | # Example usage: python -m scripts.eval_model2 /u/scr/veralin/DrQA/data/datasets/SQuAD_hotpot_hop2_dev_v4-hop2_v2_30e.preds /u/scr/veralin/deep-retriever/data/hop2/hotpot_hop2_dev_v4.json 74 | parser = argparse.ArgumentParser(description='IR evaluation for model 2 predictions.') 75 | parser.add_argument('pred_filename', help='The prediction json files ') 76 | parser.add_argument('original_filename', help='The label json file that contains title1 and title2') 77 | 78 | args = parser.parse_args() 79 | 80 | # pred_filename = sys.argv[1] 81 | # original_filename = sys.argv[2] 82 | 83 | evaluate(args.pred_filename, args.original_filename) 84 | 85 | -------------------------------------------------------------------------------- /scripts/eval_model2_emf1.py: -------------------------------------------------------------------------------- 1 | import ujson as json 2 | import re 3 | import string 4 | from collections import Counter 5 | import argparse 6 | 7 | 8 | def normalize_answer(s): 9 | 10 | def remove_articles(text): 11 | return re.sub(r'\b(a|an|the)\b', ' ', text) 12 | 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | 26 | def f1_score(prediction, ground_truth): 27 | normalized_prediction = normalize_answer(prediction) 28 | normalized_ground_truth = normalize_answer(ground_truth) 29 | 30 | ZERO_METRIC = (0, 0, 0) 31 | 32 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 33 | return ZERO_METRIC 34 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 35 | return ZERO_METRIC 36 | 37 | prediction_tokens = normalized_prediction.split() 38 | ground_truth_tokens = normalized_ground_truth.split() 39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 40 | num_same = sum(common.values()) 41 | if num_same == 0: 42 | return ZERO_METRIC 43 | precision = 1.0 * num_same / len(prediction_tokens) 44 | recall = 1.0 * num_same / len(ground_truth_tokens) 45 | f1 = (2 * precision * recall) / (precision + recall) 46 | return f1, precision, recall 47 | 48 | 49 | def exact_match_score(prediction, ground_truth): 50 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 51 | 52 | def update_answer(metrics, prediction, gold): 53 | em = exact_match_score(prediction, 
gold) 54 | f1, prec, recall = f1_score(prediction, gold) 55 | metrics['em'] += float(em) 56 | metrics['f1'] += f1 57 | metrics['prec'] += prec 58 | metrics['recall'] += recall 59 | return em, prec, recall 60 | 61 | def evaluate(prediction_file, gold_file): 62 | with open(prediction_file) as f: 63 | prediction = json.load(f) 64 | with open(gold_file) as f: 65 | gold = json.load(f)['data'] 66 | 67 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0} 68 | for dp in gold: 69 | cur_id = dp['paragraphs'][0]['qas'][0]['id'] 70 | if cur_id not in prediction: 71 | print('missing answer {}'.format(cur_id)) 72 | else: 73 | cur_pred = prediction[cur_id][0][0] 74 | cur_label = dp['paragraphs'][0]['qas'][0]['answers'][0]['text'] 75 | em, prec, recall = update_answer(metrics, cur_pred, cur_label) 76 | N = len(gold) 77 | for k in metrics.keys(): 78 | metrics[k] /= N 79 | 80 | print(metrics) 81 | 82 | 83 | if __name__ == "__main__": 84 | parser = argparse.ArgumentParser(description='Evaluate model 2') 85 | parser.add_argument('prediction_file', help='The prediction file') 86 | parser.add_argument('gold_file', help='The gold file') 87 | 88 | args = parser.parse_args() 89 | 90 | evaluate(args.prediction_file, args.gold_file) 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /scripts/eval_single_hop.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Instructions: 4 | # 1. clone the DrQA repo 5 | # - run ./download.sh in that repo and ./install_corenlp.sh 6 | # - Download Glove embeddings http://nlp.stanford.edu/data/wordvecs/glove.840B.300d.zip 7 | # 2. call this script from the deep-retriever directory 8 | # 3. conda activate hotpot 9 | set -e # stop script if any command fails 10 | 11 | # Script parameters for users (feel free to edit): 12 | OUTDIR="hotpotqa_dev_singlehop_eval" # suggested convention is "[dataset_name]_eval" 13 | QUESTION_FILE="data/hotpotqa/hotpot_dev_distractor_v1.json" 14 | #QA_MODEL_NAME="HOTPOT-20190514-073647" # old model trained with official hotpot data 15 | QA_MODEL_NAME="HOTPOT-20190520-075452" 16 | 17 | DRQA_DIR="DrQA" 18 | EMBED_FILE="${DRQA_DIR}/data/embeddings/glove.840B.300d.txt" 19 | EVAL_FILE="data/hotpotqa/hotpot_dev_distractor_v1.json" 20 | RECOMPUTE_ALL=false # change to `true` to force recompute everything 21 | 22 | # set -x # UNCOMMENT this line for debugging output 23 | 24 | # Change code below this line at your own risk! 25 | ########################################################################################### 26 | 27 | echo "Placing temporary evaluation files in: $OUTDIR" 28 | mkdir -p $OUTDIR 29 | 30 | echo "Trying to connect to ES at localhost:9200..." 31 | if ! curl -s -I localhost:9200 > /dev/null; 32 | then 33 | echo 'running "sh scripts/launch_elasticsearch_6.7.sh"' 34 | sh scripts/launch_elasticsearch_6.7.sh 35 | while ! curl -I localhost:9200; 36 | do 37 | sleep 2; 38 | done 39 | fi 40 | echo "ES is up and running" 41 | 42 | if $RECOMPUTE_ALL || [ ! -f data/hotpotqa/hotpot_dev_single_hop.json ]; then 43 | python -m scripts.build_single_hop_qa_data dev 44 | fi 45 | 46 | WD=`pwd` 47 | cd /u/scr/pengqi/HotPotQA 48 | if $RECOMPUTE_ALL || [ ! 
-f $WD/$OUTDIR/golden.json ]; then 49 | python main.py --mode prepro --data_file $WD/data/hotpotqa/hotpot_dev_single_hop.json --para_limit 2250 --data_split test --fullwiki 50 | python main.py --mode test --data_split test --save ${QA_MODEL_NAME} --prediction_file $WD/$OUTDIR/golden.json --sp_lambda 10.0 --fullwiki --hidden 128 --batch_size 16 51 | else 52 | echo 'Using existing prediction file:' 53 | fi 54 | ls $WD/$OUTDIR/golden.json 55 | 56 | if [ -f $WD/${EVAL_FILE} ]; then 57 | python hotpot_evaluate_v1.py $WD/$OUTDIR/golden.json $WD/${EVAL_FILE} | python $WD/scripts/format_result.py 58 | fi 59 | 60 | echo "Done! Final results in: $OUTDIR/golden.json" 61 | 62 | -------------------------------------------------------------------------------- /scripts/format_result.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | for line in sys.stdin: 5 | d = json.loads(line.strip().replace("'", '"')) 6 | break 7 | 8 | print(d) 9 | 10 | print('''Ans EM: {:.2%} 11 | Ans F1: {:.2%} 12 | Sup EM: {:.2%} 13 | Sup F1: {:.2%} 14 | Joint EM: {:.2%} 15 | Joint F1: {:.2%}'''.format(*tuple(d[k] for k in ['em', 'f1', 'sp_em', 'sp_f1', 'joint_em', 'joint_f1']))) 16 | -------------------------------------------------------------------------------- /scripts/gen_hop1.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from editdistance import eval as editdistance 3 | import json 4 | from multiprocessing import Pool 5 | from nltk.corpus import stopwords 6 | from nltk.stem.porter import PorterStemmer 7 | import numpy as np 8 | import os 9 | from pprint import pprint 10 | import re 11 | from tqdm import tqdm 12 | 13 | from search.search import bulk_text_query, core_title_filter 14 | from utils.constant import SEP 15 | from utils.corenlp import bulk_tokenize 16 | from utils.lcs import LCSubStr, LCS 17 | 18 | STOP_WORDS = set(stopwords.words('english')) 19 | STOP_WORDS2 = set(stopwords.words('english') + [',', '.', ';', '?', '"', '\'', '(', ')', '&', '?']) 20 | 21 | EXPANSION = {r'(movie|film)s?': 'film', 22 | r'novels?': 'novel', 23 | r'bands?': 'band', 24 | r'books?': 'book', 25 | r'magazines?': 'magazine', 26 | r'albums?': 'album', 27 | r'operas?': 'opera', 28 | r'episodes?': 'episode', 29 | r'series': 'series', 30 | r'board\s+games?': 'board game', 31 | r'director': 'film TV series', 32 | r'publish': 'book novel magazine', 33 | r'(cocktail|drink)s': 'cocktail alcohol'} 34 | 35 | def expand_query(question, query): 36 | return query 37 | 38 | for pattern in EXPANSION: 39 | if re.search(pattern, question): 40 | query += " {}".format(EXPANSION[pattern]) 41 | 42 | return query 43 | 44 | stemmer = PorterStemmer() 45 | 46 | def _filter_stopwords(text): 47 | res = [(x, i) for i, x in enumerate(text) if not x in STOP_WORDS] 48 | if len(res) == 0: 49 | return [], [] 50 | #res = [(x, i) for i, x in enumerate(text)] 51 | return map(list, zip(*res)) 52 | 53 | def _filter_stopwords2(text): 54 | res = [(x, i) for i, x in enumerate(text) if not x in STOP_WORDS2] 55 | return map(list, zip(*res)) 56 | 57 | def CompositeLCS(context_orig, context, title, para, ctx_offsets, TIE_BREAKER=0): 58 | """ 59 | inputs: 60 | - context_orig: str, raw text of the "context" that our query generator would see at this stage 61 | - context: list, list of tokens in the context 62 | - title: title of the target gold paragraph 63 | - para: paragraph text of the gold paragraph 64 | - ctx_offsets: the character 
offsets of the words in the context, as returned by corenlp 65 | - TIE_BREAKER: deprecated, used to break ties between matches between LCS and LCSubStr matches 66 | 67 | returns: 68 | - a list of [question, (start_char_offset, end_char_offset, start_token_offset, end_token_offset)] 69 | of candidate queries to evaluate against elasticsearch 70 | """ 71 | 72 | 73 | q0, q, t1, c1 = context_orig, context, title, para 74 | q_, qidx = _filter_stopwords2(q) 75 | t1_, _ = _filter_stopwords(t1) 76 | c1_, c1idx = _filter_stopwords(c1) 77 | 78 | q_ = [x.lower() for x in q_] 79 | t1_ = [x.lower() for x in t1_] 80 | c1_ = [x.lower() for x in c1_] 81 | 82 | leading_whitespaces = len(q0) - len(q0.lstrip()) 83 | 84 | def map_indices(idx1): 85 | """ 86 | map matched token offsets back to character offsets and original token offsets 87 | """ 88 | if idx1[0] < 0 or idx1[0] == idx1[1]: 89 | idx1 = 0, len(qidx) 90 | return (ctx_offsets[qidx[idx1[0]]][0] + leading_whitespaces, 91 | ctx_offsets[qidx[idx1[1]-1]][1] + leading_whitespaces, 92 | qidx[idx1[0]], 93 | qidx[idx1[1]-1]+1) 94 | 95 | def map_indices0(idx1): 96 | """ 97 | map matched token offsets back to character offsets and original token offsets 98 | """ 99 | if idx1[0] < 0 or idx1[0] == idx1[1]: 100 | idx1 = 0, len(q) 101 | return (ctx_offsets[idx1[0]][0] + leading_whitespaces, 102 | ctx_offsets[idx1[1]-1][1] + leading_whitespaces, 103 | idx1[0], 104 | idx1[1]) 105 | 106 | l1, _, idx1 = LCS(q_, c1_) 107 | l1t, _, idx1t = LCS(q_, t1_) 108 | l1_substr, _, idx1_substr = LCSubStr(q_, c1_) 109 | l1t_substr, _, idx1t_substr = LCSubStr(q_, t1_) 110 | 111 | q_0 = [x.lower() for x in q] 112 | t1_0 = [x.lower() for x in t1] 113 | l1t_substr0, _, idx1t_substr0 = LCSubStr(q_0, t1_0) 114 | 115 | def typo_aware_in(x, tgt_set, min_len, tolerance): 116 | for c in tgt_set: 117 | if len(c) >= min_len and editdistance(x, c) <= tolerance: 118 | return True 119 | 120 | return False 121 | 122 | def find_overlap(q, c, replacement, qidx): 123 | cset = set(c) 124 | overlap = [] 125 | overlapping = False 126 | last = -1 127 | for i, x in enumerate(q): 128 | if x in cset or typo_aware_in(x, cset, 3, 1): 129 | if x in ['"']: 130 | if last >= 0: 131 | overlap.append((last, i)) 132 | last = -1 133 | overlap.append((i, i+1)) 134 | overlapping = False 135 | continue 136 | if not overlapping: 137 | last = i 138 | overlapping = True 139 | else: 140 | if last >= 0: 141 | overlap.append((last, i)) 142 | last = -1 143 | overlapping = False 144 | 145 | if overlapping: 146 | overlap.append((last, len(q))) 147 | 148 | target_count = 4 149 | if len(overlap) > 0: 150 | cands = [(-sum(x[1]-x[0] for x in overlap[i:j+1])/(overlap[j][1]-overlap[i][0]), overlap[i][0], overlap[j][1]) if '' not in q[overlap[i][0]:overlap[j][1]] and '' not in q[overlap[i][0]:overlap[j][1]] else (1e10, overlap[i][0], overlap[j][1]) for i in range(len(overlap)) for j in range(i, len(overlap))] 151 | cands = [(x[1], x[2]) for x in list(sorted(cands))[:target_count]] 152 | cands += [replacement] * (target_count - len(cands)) 153 | return cands 154 | else: 155 | return [replacement] * target_count 156 | 157 | idx2t = find_overlap(q_, t1_, idx1t, qidx) 158 | 159 | idx2 = find_overlap(q_, c1_, idx1, qidx) 160 | 161 | cand_offsets = [map_indices(idx) for idx in [idx1, idx1t, idx1_substr, idx1t_substr] + idx2t + idx2] 162 | 163 | cand_offsets += [map_indices0(idx) for idx in [idx1t_substr0]] 164 | 165 | return [(q0[st:en], (st, en, st2, en2)) for st, en, st2, en2 in cand_offsets] 166 | 167 | def 
generate_single_hop1_query(data, TIE_BREAKER=0): 168 | context = dict(data['context']) 169 | supporting = sorted(set([x[0] for x in data['supporting_facts']])) 170 | 171 | q0 = data['question'].strip() 172 | c1 = [supporting[0], ''.join(context[supporting[0]])] 173 | c2 = [supporting[1], ''.join(context[supporting[1]])] 174 | 175 | to_tok = [q0] + c1 + c2 176 | tokenized, offsets = bulk_tokenize(to_tok, return_offsets=True) 177 | q = tokenized[0] 178 | t1, c1 = tokenized[1:3] 179 | t2, c2 = tokenized[3:5] 180 | 181 | cands1 = CompositeLCS(data['question'], q, t1, c1, offsets[0], TIE_BREAKER=TIE_BREAKER) 182 | cands2 = CompositeLCS(data['question'], q, t2, c2, offsets[0], TIE_BREAKER=TIE_BREAKER) 183 | 184 | return cands1, cands2 185 | 186 | def generate_hop1_queries(data, TIE_BREAKER=0): 187 | return [generate_single_hop1_query(datum, TIE_BREAKER=TIE_BREAKER) for datum in data] 188 | 189 | def deduped_bulk_query(queries1, topn=10, lazy=True): 190 | # consolidate queries to remove redundancy 191 | queries2 = [] 192 | queries2_dict = dict() 193 | mapped_idx = [] 194 | for q in queries1: 195 | if q not in queries2_dict: 196 | queries2_dict[q] = len(queries2) 197 | queries2.append(q) 198 | mapped_idx.append(queries2_dict[q]) 199 | 200 | res1 = bulk_text_query(queries2, topn=topn, lazy=lazy) 201 | 202 | # map queries back 203 | res = [res1[idx] for idx in mapped_idx] 204 | 205 | return res 206 | 207 | def main(): 208 | import argparse 209 | 210 | IR_RESULTS_TO_RETAIN = 10 211 | 212 | sanitycheck = dict() 213 | if os.path.exists('data/hop1/sanitycheck.tsv'): 214 | with open('data/hop1/sanitycheck.tsv') as f: 215 | for line in f: 216 | line = line.rstrip().split('\t') 217 | sanitycheck[line[0]] = line[1] 218 | 219 | parser = argparse.ArgumentParser() 220 | 221 | parser.add_argument('split', choices=['train', 'dev']) 222 | parser.add_argument('--analysis', action='store_true') 223 | 224 | args = parser.parse_args() 225 | 226 | if args.split == 'train': 227 | filename = 'data/hotpotqa/hotpot_train_v1.1.json' 228 | labels_file = 'data/hop1/hotpot_hop1_train.json' 229 | ir_result_file = 'data/hop1/hotpot_hop1_train_ir_result.json' 230 | else: 231 | filename = 'data/hotpotqa/hotpot_dev_distractor_v1.json' 232 | labels_file = 'data/hop1/hotpot_hop1_dev.json' 233 | ir_result_file = 'data/hop1/hotpot_hop1_dev_ir_result.json' 234 | 235 | batch_size = 64 236 | Ns = [1,2,3,4,5,6,7,8,9,10,15,20,25,30,35,40,45,50] 237 | max_n = max(Ns) 238 | 239 | with open(filename) as f: 240 | data = json.load(f) 241 | 242 | batches = [data[b*batch_size:min((b+1)*batch_size, len(data))] 243 | for b in range((len(data) + batch_size - 1) // batch_size)] 244 | 245 | para1 = Counter() 246 | para2 = Counter() 247 | processed = 0 248 | 249 | ir_result = [] 250 | hop1_labels = [] 251 | 252 | f_analysis = open('hop1_analysis_ge5.tsv', 'w') if args.analysis else None 253 | candidates_per_paragraph = 0 254 | 255 | pool = Pool(8) 256 | all_queries = list(tqdm(pool.imap(generate_hop1_queries, batches), total=len(batches))) 257 | for batch, queries in tqdm(zip(batches, all_queries), total=len(batches)): 258 | #queries = generate_hop1_queries(batch) 259 | if candidates_per_paragraph == 0: 260 | candidates_per_paragraph = len(queries[0][0]) 261 | print('Candidates per paragraph evaluated: {}'.format(candidates_per_paragraph)) 262 | assert all([len(y) == candidates_per_paragraph for x in queries for y in x]) 263 | 264 | queries1 = [expand_query(d['question'], z[0]) for x, d in zip(queries, batch) for y in x for z in y] 265 | 266 | res = 
deduped_bulk_query(queries1, topn=max_n, lazy=True) 267 | 268 | for j, d, q in zip(range(len(batch)), batch, queries): 269 | supporting = sorted(set(x[0] for x in d['supporting_facts'])) # paragraph titles 270 | supporting1 = sorted(set(core_title_filter(x[0]) for x in d['supporting_facts'])) # paragraph titles 271 | ctx = dict(d['context']) 272 | 273 | def process_result_item(query_offsets, item, orig_target, item_idx): 274 | rank = min([i for i, para in enumerate(item) if para['title'] in supporting] + [max_n]) 275 | target_para = json.loads(item[rank]['data_object']) if rank < max_n else None 276 | if target_para is None: 277 | target_para = {'title': orig_target, 'text': ctx[orig_target]} 278 | query, offsets = query_offsets 279 | splitted = query.split() 280 | token_len = len(splitted) 281 | ques_len = len(d['question'].split()) 282 | upper_case_len = sum((not x[0].islower()) or x in ['in', 'the', 'of', 'by', 'a', 'an', 'on', 'to', 'is'] for x in splitted) if len(splitted) <= 5 and (not splitted[0][0].islower()) and (not splitted[-1][0].islower()) and splitted[-1].lower() not in STOP_WORDS else sum(not x[0].islower() for x in splitted) 283 | return max(4, rank), max(token_len, min(10, ques_len * .6)) + offsets[2] * .1 + rank + max(1, sum(title in query for title in supporting1)), offsets[:2], rank, item_idx, query, target_para, token_len 284 | 285 | res1 = [process_result_item(q1, r1, supporting[0], idx) for idx, q1, r1 in zip(range(len(q[0])), q[0], res[j*2*candidates_per_paragraph:(j*2+1)*candidates_per_paragraph])] 286 | res1 += [process_result_item(q1, r1, supporting[1], idx+len(q[0])) for idx, q1, r1 in zip(range(len(q[1])), q[1], res[(j*2+1)*candidates_per_paragraph:(j*2+2)*candidates_per_paragraph])] 287 | 288 | _, _, offsets, rank, res_idx, query, target_para, token_len = list(sorted(res1))[0] 289 | r = res[j*2*candidates_per_paragraph:(j*2+2)*candidates_per_paragraph][res_idx] 290 | 291 | ir_result.append({ 292 | '_id': d['_id'], 293 | 'query': query, 294 | 'target_para': target_para, 295 | 'target_rank': rank, 296 | 'ir_result': [json.loads(x['data_object']) for x in r[:IR_RESULTS_TO_RETAIN]] 297 | }) 298 | 299 | hop1_labels.append({ 300 | '_id': d['_id'], 301 | 'question': d['question'], 302 | 'context': d['question'], 303 | 'label': query, 304 | 'label_offsets': offsets, 305 | 'target_para': target_para, 306 | 'target_rank': rank, 307 | }) 308 | 309 | para1_found = False 310 | para2_found = False 311 | for i, para in enumerate(r): 312 | if para['title'] in supporting: 313 | if not para1_found: 314 | para1[i] += 1 315 | para1_found = True 316 | else: 317 | assert not para2_found 318 | para2[i] += 1 319 | para2_found = True 320 | 321 | if not para1_found: 322 | para1[max_n] += 1 323 | if not para2_found: 324 | para2[max_n] += 1 325 | 326 | processed += len(batch) 327 | 328 | if f_analysis is not None: 329 | f_analysis.close() 330 | 331 | print('Dumping IR result to file... ', end="", flush=True) 332 | with open(ir_result_file, 'w') as f: 333 | json.dump(ir_result, f) 334 | print('Done.', flush=True) 335 | 336 | print('Dumping Hop 1 labels to file... 
', end="", flush=True) 337 | with open(labels_file, 'w') as f: 338 | json.dump(hop1_labels, f) 339 | print('Done.', flush=True) 340 | 341 | for n in Ns: 342 | c1 = sum(para1[k] for k in range(n)) 343 | c2 = sum(para2[k] for k in range(n)) 344 | 345 | print("Hits@{:2d}: {:.2f}\tP1@{:2d}: {:.2f}\tP2@{:2d}: {:.2f}".format( 346 | n, 100 * (c1+c2) / 2 / processed, n, 100 * c1 / processed, n, 100 * c2 / processed)) 347 | 348 | if __name__ == '__main__': 349 | main() 350 | -------------------------------------------------------------------------------- /scripts/gen_hop2.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | from collections import Counter 4 | import sys 5 | 6 | from tqdm import tqdm 7 | 8 | from search.search import bulk_text_query 9 | from utils.lcs import LCSubStr 10 | from utils.io import load_json_file 11 | from utils.general import chunks, make_context 12 | from utils.corenlp import bulk_tokenize 13 | 14 | from scripts.gen_hop1 import CompositeLCS, deduped_bulk_query, STOP_WORDS 15 | 16 | def analyze(hop2_results): 17 | batch_size = 128 18 | Ns = [1,2,3,4,5,6,7,8,9,10,15,20,25,30,35,40,45,50] 19 | max_n = max(Ns) 20 | p1_hits = Counter() 21 | p2_hits = Counter() 22 | processed = 0 23 | 24 | for chunk in tqdm(chunks(hop2_results, batch_size)): 25 | 26 | label2s = [x['label'] for x in chunk] 27 | es_bulk_results = bulk_text_query(label2s, topn=max_n, lazy=False) 28 | 29 | for i, (entry, es_results) in enumerate(zip(chunk, es_bulk_results)): 30 | q = entry['question'] 31 | l2 = entry['label'] 32 | t1 = entry['title1'] 33 | p1 = entry['para1'] 34 | t2 = entry['title2'] 35 | p2 = entry['para2'] 36 | 37 | # find rank of t1 in es_results 38 | found_t1 = False 39 | found_t2 = False 40 | t2_rank = max_n 41 | for i, es_entry in enumerate(es_results): 42 | if es_entry['title'] == t1: 43 | p1_hits[i] += 1 44 | found_t1 = True 45 | if es_entry['title'] == t2: 46 | p2_hits[i] += 1 47 | t2_rank = i 48 | found_t2 = True 49 | if not found_t1: 50 | p1_hits[max_n] += 1 51 | if not found_t2: 52 | p2_hits[max_n] += 1 53 | 54 | print_cols = [q, l2, t1, p1, t2, p2, str(t2_rank + 1)] 55 | #print('\t'.join(print_cols)) 56 | processed += 1 57 | 58 | for n in Ns: 59 | c1 = sum(p1_hits[k] for k in range(n)) 60 | c2 = sum(p2_hits[k] for k in range(n)) 61 | 62 | print("Hits@{:2d}: {:.2f}\tP1@{:2d}: {:.2f}\tP2@{:2d}: {:.2f}".format( 63 | n, 100 * (c1+c2) / 2 / processed, n, 100 * c1 / processed, n, 100 * c2 / processed)) 64 | 65 | def main(): 66 | import argparse 67 | 68 | HOP1_TO_KEEP = 5 69 | IR_RESULTS_TO_RETAIN = 10 70 | max_n = 50 71 | batch_size = 64 72 | 73 | parser = argparse.ArgumentParser() 74 | 75 | parser.add_argument('split', choices=['train', 'dev']) 76 | parser.add_argument('--analyze', action='store_true') 77 | 78 | args = parser.parse_args() 79 | 80 | if args.split == 'train': 81 | data_file = 'data/hotpotqa/hotpot_train_v1.1.json' 82 | labels_file = 'data/hop1/hotpot_hop1_train.json' 83 | ir_file = 'data/hop1/hotpot_hop1_train_ir_result.json' 84 | output_file = 'data/hop2/hotpot_hop2_train.json' 85 | output_ir_file = 'data/hop2/hotpot_hop2_train_ir_result.json' 86 | else: 87 | data_file = 'data/hotpotqa/hotpot_dev_distractor_v1.json' 88 | labels_file = 'data/hop1/hotpot_hop1_dev.json' 89 | ir_file = 'data/hop1/hotpot_hop1_dev_ir_result.json' 90 | output_file = 'data/hop2/hotpot_hop2_dev.json' 91 | output_ir_file = 'data/hop2/hotpot_hop2_dev_ir_result.json' 92 | 93 | # make a map from id to each entry in the data so that we 94 | # can 
join with the generated label files 95 | id_to_datum = {} 96 | data = load_json_file(data_file) 97 | for datum in data: 98 | id_to_datum[datum['_id']] = datum 99 | 100 | # same, map from id to ir entry 101 | id_to_ir_entry = {} 102 | ir_data = load_json_file(ir_file) 103 | for entry in ir_data: 104 | id_to_ir_entry[entry['_id']] = entry 105 | 106 | hop1_labels= load_json_file(labels_file) 107 | 108 | hop2_results = [] 109 | hop2_ir_results = [] 110 | candidates_per_example = 0 111 | for batch in tqdm(chunks(hop1_labels, batch_size), total=(len(hop1_labels) + batch_size - 1)//batch_size): 112 | queries = [] 113 | processed_batch = [] 114 | for entry in batch: 115 | _id = entry['_id'] 116 | target_para = entry['target_para'] 117 | 118 | assert target_para is not None 119 | 120 | title1 = target_para['title'] 121 | para1 = ''.join(target_para['text']) 122 | question = entry['question'] 123 | 124 | orig_datum = id_to_datum[_id] 125 | supp_facts = set([f[0] for f in orig_datum['supporting_facts']]) 126 | assert len(supp_facts) == 2, supp_facts 127 | assert title1 in supp_facts 128 | supp_facts.remove(title1) 129 | title2 = supp_facts.pop() 130 | 131 | para2_matches = [ 132 | para for title, para in orig_datum['context'] 133 | if title == title2 134 | ] 135 | assert len(para2_matches) == 1, orig_datum 136 | para2 = ''.join(para2_matches[0]) 137 | para2_list = para2_matches[0] 138 | 139 | # join in hop1 ir results 140 | ir_entry = id_to_ir_entry[_id] 141 | 142 | if title1 in [x['title'] for x in ir_entry['ir_result'][:HOP1_TO_KEEP]]: 143 | ir_context = ir_entry['ir_result'][:HOP1_TO_KEEP] 144 | else: 145 | ir_context = ir_entry['ir_result'][:(HOP1_TO_KEEP-1)] + [{'title': title1, 'text': target_para['text']}] 146 | 147 | hop1_context = make_context(question, ir_context) 148 | 149 | tokenized, offsets = bulk_tokenize( 150 | [hop1_context, title2, para2], 151 | return_offsets=True 152 | ) 153 | token_hop1_context = tokenized[0] 154 | token_title2 = tokenized[1] 155 | token_para2 = tokenized[2] 156 | 157 | candidates = CompositeLCS( 158 | hop1_context, 159 | token_hop1_context, 160 | token_title2, 161 | token_para2, 162 | offsets[0], 163 | ) 164 | 165 | if candidates_per_example == 0: 166 | candidates_per_example = len(candidates) 167 | assert len(candidates) == candidates_per_example 168 | queries.extend([x[0] for x in candidates]) 169 | 170 | processed_batch.append([_id, entry, question, candidates, target_para, title1, para1, title2, para2, para2_list, hop1_context]) 171 | 172 | res = deduped_bulk_query(queries, topn=max_n, lazy=False) 173 | 174 | for i, (_id, entry, question, candidates, target_para, title1, para1, title2, para2, para2_list, hop1_context) in enumerate(processed_batch): 175 | def process_result_item(query_offsets, item, item_idx): 176 | rank = min([i for i, para in enumerate(item) if para['title'] == title2] + [max_n]) 177 | target_para = item[rank]['data_object'] if rank < max_n else None 178 | if target_para is None: 179 | target_para = {'title': title2, 'text': para2_list} 180 | query, offsets = query_offsets 181 | splitted = [x for x in query.split() if len(x)] 182 | token_len = len(splitted) 183 | if len(splitted) == 0: 184 | upper_case_len = 0 185 | else: 186 | upper_case_len = sum((not x[0].islower()) or x in ['in', 'the', 'of', 'by', 'a', 'an', 'on', 'to', 'is'] for x in splitted) if len(splitted) <= 5 and (not splitted[0][0].islower()) and (not splitted[-1][0].islower()) and splitted[-1].lower() not in STOP_WORDS else sum(not x[0].islower() for x in splitted) 187 | 
return max(4, rank), max(token_len, 10) + rank, offsets[:2], rank, item_idx, query, target_para, token_len 188 | 189 | res1 = [process_result_item(q1, r1, idx) for idx, q1, r1 in zip(range(len(candidates)), candidates, res[i*candidates_per_example:(i+1)*candidates_per_example])] 190 | 191 | _, _, offsets, rank, res_idx, query, target_para, token_len = list(sorted(res1))[0] 192 | 193 | hop2_ir_results.append({ 194 | '_id': _id, 195 | 'query': query, 196 | 'target_para': target_para, 197 | 'target_rank': rank, 198 | 'ir_result': [x['data_object'] for x in res[i*candidates_per_example+res_idx][:IR_RESULTS_TO_RETAIN]] 199 | }) 200 | 201 | hop2_results.append({ 202 | '_id': _id, 203 | 'question': question, 204 | 'label': query, 205 | 'context': hop1_context, 206 | 'label_offsets': offsets, 207 | 'hop1_label': entry['label'], 208 | 'hop1_offsets': entry['label_offsets'], 209 | 'title1': title1, 210 | 'para1': para1, 211 | 'title2': title2, 212 | 'para2': para2, 213 | }) 214 | 215 | print('Dumping Hop 2 labels to file... ', end="", flush=True) 216 | with open(output_file, 'w') as f: 217 | json.dump(hop2_results, f) 218 | print('Done.', flush=True) 219 | 220 | print('Dumping IR result to file... ', end="", flush=True) 221 | with open(output_ir_file, 'w') as f: 222 | json.dump(hop2_ir_results, f) 223 | print('Done.', flush=True) 224 | 225 | if args.analyze: 226 | analyze(hop2_results) 227 | 228 | print('Done!', file=sys.stderr) 229 | 230 | if __name__ == "__main__": 231 | main() 232 | -------------------------------------------------------------------------------- /scripts/index_processed_wiki.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import bz2 3 | from collections import Counter, defaultdict 4 | from elasticsearch import Elasticsearch 5 | from glob import glob 6 | import html 7 | import json 8 | from multiprocessing import Pool 9 | import numpy as np 10 | import os 11 | import pickle 12 | import re 13 | from tqdm import tqdm 14 | from urllib.parse import unquote 15 | 16 | from utils.constant import WIKIPEDIA_INDEX_NAME 17 | from utils.general import chunks 18 | 19 | def process_line(line): 20 | data = json.loads(line) 21 | item = {'id': data['id'], 22 | 'url': data['url'], 23 | 'title': data['title'], 24 | 'title_unescape': html.unescape(data['title']), 25 | 'text': ''.join(data['text']), 26 | 'title_bigram': html.unescape(data['title']), 27 | 'title_unescape_bigram': html.unescape(data['title']), 28 | 'text_bigram': ''.join(data['text']), 29 | 'original_json': line 30 | } 31 | # tell elasticsearch we're indexing documents 32 | return "{}\n{}".format(json.dumps({ 'index': { '_id': 'wiki-{}'.format(data['id']) } }), json.dumps(item)) 33 | 34 | def generate_indexing_queries_from_bz2(bz2file, dry=False): 35 | if dry: 36 | return 37 | 38 | with bz2.open(bz2file, 'rt') as f: 39 | body = [process_line(line) for line in f] 40 | 41 | return '\n'.join(body) 42 | 43 | es = Elasticsearch(timeout=100) 44 | def index_chunk(chunk): 45 | res = es.bulk(index=WIKIPEDIA_INDEX_NAME, doc_type='doc', body='\n'.join(chunk), timeout='100s') 46 | assert not res['errors'], res 47 | 48 | def main(args): 49 | # make index 50 | if not args.dry: 51 | if es.indices.exists(index=WIKIPEDIA_INDEX_NAME) and args.reindex: 52 | es.indices.delete(index=WIKIPEDIA_INDEX_NAME, ignore=[400,403]) 53 | if not es.indices.exists(index=WIKIPEDIA_INDEX_NAME): 54 | es.indices.create(index=WIKIPEDIA_INDEX_NAME, ignore=400, 55 | body=json.dumps({ 56 | 
"mappings":{"doc":{"properties": { 57 | "id": { "type": "keyword" }, 58 | "url": { "type": "keyword" }, 59 | "title": { "type": "text", "analyzer": "simple", "copy_to": "title_all"}, 60 | "title_unescape": { "type": "text", "analyzer": "simple", "copy_to": "title_all"}, 61 | "text": { "type": "text", "analyzer": "my_english_analyzer"}, 62 | "anchortext": { "type": "text", "analyzer": "my_english_analyzer"}, 63 | "title_bigram": { "type": "text", "analyzer": "simple_bigram_analyzer", "copy_to": "title_all_bigram"}, 64 | "title_unescape_bigram": { "type": "text", "analyzer": "simple_bigram_analyzer", "copy_to": "title_all_bigram"}, 65 | "text_bigram": { "type": "text", "analyzer": "bigram_analyzer"}, 66 | "anchortext_bigram": { "type": "text", "analyzer": "bigram_analyzer"}, 67 | "original_json": { "type": "string" }, 68 | }} 69 | }, 70 | "settings": { 71 | "analysis": { 72 | "my_english_analyzer": { 73 | "type": "standard", 74 | "stopwords": "_english_", 75 | }, 76 | "simple_bigram_analyzer": { 77 | "tokenizer": "standard", 78 | "filter": [ 79 | "lowercase", "shingle", "asciifolding" 80 | ] 81 | }, 82 | "bigram_analyzer": { 83 | "tokenizer": "standard", 84 | "filter": [ 85 | "lowercase", "stop", "shingle", "asciifolding" 86 | ] 87 | } 88 | }, 89 | } 90 | })) 91 | 92 | filelist = glob('data/enwiki-20171001-pages-meta-current-withlinks-abstracts/*/wiki_*.bz2') 93 | 94 | print('Making indexing queries...') 95 | pool = Pool() 96 | all_queries = list(tqdm(pool.imap(generate_indexing_queries_from_bz2, filelist), total=len(filelist))) 97 | 98 | count = sum(len(queries.split('\n')) for queries in all_queries) // 2 99 | 100 | if not args.dry: 101 | print('Indexing...') 102 | chunksize = 50 103 | for chunk in tqdm(chunks(all_queries, chunksize), total=(len(all_queries) + chunksize - 1) // chunksize): 104 | res = es.bulk(index=WIKIPEDIA_INDEX_NAME, doc_type='doc', body='\n'.join(chunk), timeout='100s') 105 | assert not res['errors'], res 106 | 107 | print(f"{count} documents indexed in total") 108 | 109 | if __name__ == '__main__': 110 | parser = ArgumentParser() 111 | 112 | parser.add_argument('--reindex', action='store_true', help="Reindex everything") 113 | parser.add_argument('--dry', action='store_true', help="Dry run") 114 | 115 | args = parser.parse_args() 116 | 117 | main(args) 118 | -------------------------------------------------------------------------------- /scripts/launch_elasticsearch_6.7.sh: -------------------------------------------------------------------------------- 1 | cd elasticsearch-6.7.0 2 | bin/elasticsearch 2>&1 >/dev/null & 3 | while ! 
curl -I localhost:9200 2>/dev/null; 4 | do 5 | sleep 2; 6 | done 7 | -------------------------------------------------------------------------------- /scripts/offline_ir_eval.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import json 3 | import sys 4 | from tqdm import tqdm 5 | 6 | with open(sys.argv[1]) as f: 7 | qa_input = json.load(f) 8 | 9 | with open(sys.argv[2]) as f: 10 | eval_file = json.load(f) 11 | 12 | recall = 0 13 | total = 0 14 | foundall = 0 15 | foundall_total = 0 16 | foundone = 0 17 | foundone_total = 0 18 | 19 | bridge_recall = 0 20 | bridge_total = 0 21 | bridge_foundall = 0 22 | bridge_foundall_total = 0 23 | bridge_foundone = 0 24 | bridge_foundone_total = 0 25 | for d1, d2 in zip(tqdm(qa_input), eval_file): 26 | assert d1['_id'] == d2['_id'] 27 | 28 | c1 = set([x[0] for x in d1['context']]) 29 | c2 = set([x[0] for x in d2['supporting_facts']]) 30 | 31 | assert len(c2) == 2 32 | 33 | found = len(c1 & c2) 34 | recall += found 35 | total += len(c2) 36 | 37 | foundall += (found == len(c2)) 38 | foundall_total += 1 39 | 40 | foundone += (found >= 1) 41 | foundone_total += 1 42 | 43 | if d2['type'] != 'comparison': 44 | found = len(c1 & c2) 45 | bridge_recall += found 46 | bridge_total += len(c2) 47 | 48 | bridge_foundall += (found == len(c2)) 49 | bridge_foundall_total += 1 50 | 51 | bridge_foundone += (found >= 1) 52 | bridge_foundone_total += 1 53 | 54 | 55 | print('Recall: {:5.2%} ({:5d} / {:5d})'.format(recall / total, recall, total)) 56 | print('Found all: {:5.2%} ({:5d} / {:5d})'.format(foundall / foundall_total, foundall, foundall_total)) 57 | print('Found one: {:5.2%} ({:5d} / {:5d})'.format(foundone / foundone_total, foundone, foundone_total)) 58 | print() 59 | print('Bridge-only recall: {:5.2%} ({:5d} / {:5d})'.format(bridge_recall / bridge_total, bridge_recall, bridge_total)) 60 | print('Bridge-only found all: {:5.2%} ({:5d} / {:5d})'.format(bridge_foundall / bridge_foundall_total, bridge_foundall, bridge_foundall_total)) 61 | print('Bridge-only found one: {:5.2%} ({:5d} / {:5d})'.format(bridge_foundone / bridge_foundone_total, bridge_foundone, bridge_foundone_total)) 62 | print() 63 | comparison_recall = recall - bridge_recall 64 | comparison_total = total - bridge_total 65 | comparison_foundall = foundall - bridge_foundall 66 | comparison_foundall_total = foundall_total - bridge_foundall_total 67 | comparison_foundone = foundone - bridge_foundone 68 | comparison_foundone_total = foundone_total - bridge_foundone_total 69 | print('Comparison-only recall: {:5.2%} ({:5d} / {:5d})'.format(comparison_recall / comparison_total, comparison_recall, comparison_total)) 70 | print('Comparison-only found all: {:5.2%} ({:5d} / {:5d})'.format(comparison_foundall / comparison_foundall_total, comparison_foundall, comparison_foundall_total)) 71 | print('Comparison-only found one: {:5.2%} ({:5d} / {:5d})'.format(comparison_foundone / comparison_foundone_total, comparison_foundone, comparison_foundone_total)) 72 | -------------------------------------------------------------------------------- /scripts/preprocess_hop1.py: -------------------------------------------------------------------------------- 1 | import json, os, sys 2 | from tqdm import tqdm 3 | 4 | from argparse import ArgumentParser 5 | 6 | 7 | def parse_data(input_path, output_path): 8 | 9 | if os.path.exists(output_path): 10 | print(f"File already exists, skipping generation: {output_path}") 11 | return 12 | 13 | with open(input_path) as 
infile: 14 | hop1 = json.load(infile) 15 | 16 | converted_json = {} 17 | converted_json['version'] = '0' 18 | 19 | converted_data = [] 20 | for entry in tqdm(hop1): 21 | paragraphs = [{'context': entry['context'], 22 | 'qas':[{'answers': [{'answer_start': entry['label_offsets'][0], 'text': entry['label']} for _ in range(3)], 23 | 'question': entry['question'], 24 | 'id': entry['_id']}]}] 25 | data = {"title": "", "paragraphs": paragraphs} 26 | converted_data.append(data) 27 | 28 | converted_json['data'] = converted_data 29 | 30 | with open(output_path, 'w') as outfile: 31 | json.dump(converted_json, outfile) 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = ArgumentParser() 36 | parser.add_argument('--input_path', required=True, help="hotpot_train_v1.1.json") 37 | parser.add_argument('--output_path', required=True, help="hotpot_train_hop1.json") 38 | 39 | args = parser.parse_args() 40 | parse_data(args.input_path, args.output_path) 41 | 42 | 43 | -------------------------------------------------------------------------------- /scripts/preprocess_hop2.py: -------------------------------------------------------------------------------- 1 | """Preprocess the hop2 input to convert it to the SQuAD format.""" 2 | 3 | import json as json 4 | import os.path 5 | import argparse 6 | import glob 7 | 8 | 9 | def parse_data(INPUT_DIR, INPUT_FILE, query1_file=None): 10 | DATA_PATH = INPUT_DIR + '/' + INPUT_FILE 11 | print ("Processing", DATA_PATH) 12 | data = json.load(open(DATA_PATH, 'r')) 13 | if query1_file: 14 | # dictionary of key value pairs, example: 15 | # '5abf04ae5542993fe9a41dbf': [['Ndebele music', 0.6314478516578674]] 16 | hop1 = json.load(open(INPUT_DIR + '/' + query1_file, 'r')) 17 | 18 | rows = [] 19 | SHUFFLE = False 20 | for d in data: 21 | row = {} 22 | row['title'] = '' 23 | if query1_file: 24 | query1 = hop1[d['_id']][0][0] 25 | row['query1'] = query1 #TODO 26 | paragraph = {} 27 | paragraph['context'] = d['context'] 28 | qas = {} 29 | qas['question'] = d['question'] 30 | 31 | # For test set evaluation, we don't have labels 32 | # Instead we just use (0, "") 33 | if 'label_offsets' in d: 34 | start = d['label_offsets'][0] 35 | span = d['context'][d['label_offsets'][0]:d['label_offsets'][1]] 36 | else: 37 | start = 0 38 | span = '' 39 | 40 | qas['answers'] = [{'answer_start': start, 'text': span}] 41 | qas['id'] = d['_id'] 42 | paragraph['qas'] = [qas] 43 | row['paragraphs'] = [paragraph] 44 | rows.append(row) 45 | 46 | if query1_file: 47 | OUTPUT_FILE = '/SQuAD_query1_' + INPUT_FILE 48 | else: 49 | OUTPUT_FILE = '/SQuAD_' + INPUT_FILE 50 | 51 | with open(INPUT_DIR + OUTPUT_FILE, 'w') as outfile: 52 | json.dump({'data': rows}, outfile) 53 | 54 | print ("Done processing. 
Output to", INPUT_DIR + OUTPUT_FILE) 55 | 56 | if __name__ == "__main__": 57 | # Example usage: python -m scripts.SQuADify_label2 /u/scr/veralin/DrQA/data/datasets hotpot_hop2_dev_v6.json 58 | # Example usage with query1 file: python -m scripts.SQuADify_label2 /u/scr/veralin/DrQA/data/datasets hotpot_hop2_dev_v7.json --query1_file hotpot_hop1_squad_dev-768-50-v8.preds 59 | parser = argparse.ArgumentParser(description='Convert Label2_v4 to SQuAD format') 60 | parser.add_argument('input_dir', help='The input directory ') 61 | parser.add_argument('input_file', help='The input json file') 62 | parser.add_argument('--query1_file', help='Include query1 in the question field.') 63 | 64 | args = parser.parse_args() 65 | 66 | parse_data(args.input_dir, args.input_file, query1_file=args.query1_file) 67 | -------------------------------------------------------------------------------- /scripts/query_generator_study.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | dev_file = "data/hotpotqa/hotpot_dev_distractor_v1.json" 4 | hop1_file = "data/hop1/hotpot_hop1_dev_v8.json" 5 | hop2_file = "data/hop2/hotpot_hop2_dev_v7.json" 6 | hop1_pred_file = "hotpotqa_dev_eval/hop1_squadified-32-128-500-v8.preds" 7 | hop2_pred_file = "hotpotqa_dev_eval/SQuAD_hop2_input-hop2_v3_gpu_30e.preds" 8 | 9 | with open(dev_file) as f: 10 | dev_data = json.load(f) 11 | 12 | with open(hop1_file) as f: 13 | hop1_data = json.load(f) 14 | 15 | with open(hop2_file) as f: 16 | hop2_data = json.load(f) 17 | 18 | with open(hop1_pred_file) as f: 19 | hop1_pred = json.load(f) 20 | 21 | with open(hop2_pred_file) as f: 22 | hop2_pred = json.load(f) 23 | 24 | for d, h1, h2 in zip(dev_data, hop1_data, hop2_data): 25 | assert d['_id'] == h1['_id'] == h2['_id'] 26 | id = d['_id'] 27 | 28 | support = list(set([x[0] for x in d['supporting_facts']])) 29 | print('\t'.join([d['question']] + support + [h1['label'], hop1_pred[id][0][0], h2['label'], hop2_pred[id][0][0]]).replace('"', '"')) 30 | -------------------------------------------------------------------------------- /scripts/query_labels_to_pred.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | with open(sys.argv[1]) as f: 5 | data = json.load(f) 6 | 7 | with open(sys.argv[2], 'w') as f: 8 | json.dump({x['_id']: [[x['label'], 1]] for x in data}, f) 9 | -------------------------------------------------------------------------------- /search/jvm.options: -------------------------------------------------------------------------------- 1 | ## JVM configuration 2 | 3 | ################################################################ 4 | ## IMPORTANT: JVM heap size 5 | ################################################################ 6 | ## 7 | ## You should always set the min and max JVM heap 8 | ## size to the same value. 
For example, to set 9 | ## the heap to 4 GB, set: 10 | ## 11 | ## -Xms4g 12 | ## -Xmx4g 13 | ## 14 | ## See https://www.elastic.co/guide/en/elasticsearch/reference/current/heap-size.html 15 | ## for more information 16 | ## 17 | ################################################################ 18 | 19 | # Xms represents the initial size of total heap space 20 | # Xmx represents the maximum size of total heap space 21 | 22 | -Xms4g 23 | -Xmx4g 24 | 25 | ################################################################ 26 | ## Expert settings 27 | ################################################################ 28 | ## 29 | ## All settings below this section are considered 30 | ## expert settings. Don't tamper with them unless 31 | ## you understand what you are doing 32 | ## 33 | ################################################################ 34 | 35 | ## GC configuration 36 | -XX:+UseConcMarkSweepGC 37 | -XX:CMSInitiatingOccupancyFraction=75 38 | -XX:+UseCMSInitiatingOccupancyOnly 39 | 40 | ## G1GC Configuration 41 | # NOTE: G1GC is only supported on JDK version 10 or later. 42 | # To use G1GC uncomment the lines below. 43 | # 10-:-XX:-UseConcMarkSweepGC 44 | # 10-:-XX:-UseCMSInitiatingOccupancyOnly 45 | # 10-:-XX:+UseG1GC 46 | # 10-:-XX:InitiatingHeapOccupancyPercent=75 47 | 48 | ## DNS cache policy 49 | # cache ttl in seconds for positive DNS lookups noting that this overrides the 50 | # JDK security property networkaddress.cache.ttl; set to -1 to cache forever 51 | -Des.networkaddress.cache.ttl=60 52 | # cache ttl in seconds for negative DNS lookups noting that this overrides the 53 | # JDK security property networkaddress.cache.negative ttl; set to -1 to cache 54 | # forever 55 | -Des.networkaddress.cache.negative.ttl=10 56 | 57 | ## optimizations 58 | 59 | # pre-touch memory pages used by the JVM during initialization 60 | -XX:+AlwaysPreTouch 61 | 62 | ## basic 63 | 64 | # explicitly set the stack size 65 | -Xss1m 66 | 67 | # set to headless, just in case 68 | -Djava.awt.headless=true 69 | 70 | # ensure UTF-8 encoding by default (e.g. 
filenames) 71 | -Dfile.encoding=UTF-8 72 | 73 | # use our provided JNA always versus the system one 74 | -Djna.nosys=true 75 | 76 | # turn off a JDK optimization that throws away stack traces for common 77 | # exceptions because stack traces are important for debugging 78 | -XX:-OmitStackTraceInFastThrow 79 | 80 | # flags to configure Netty 81 | -Dio.netty.noUnsafe=true 82 | -Dio.netty.noKeySetOptimization=true 83 | -Dio.netty.recycler.maxCapacityPerThread=0 84 | 85 | # log4j 2 86 | -Dlog4j.shutdownHookEnabled=false 87 | -Dlog4j2.disable.jmx=true 88 | 89 | -Djava.io.tmpdir=${ES_TMPDIR} 90 | 91 | ## heap dumps 92 | 93 | # generate a heap dump when an allocation from the Java heap fails 94 | # heap dumps are created in the working directory of the JVM 95 | -XX:+HeapDumpOnOutOfMemoryError 96 | 97 | # specify an alternative path for heap dumps; ensure the directory exists and 98 | # has sufficient space 99 | -XX:HeapDumpPath=data 100 | 101 | # specify an alternative path for JVM fatal error logs 102 | -XX:ErrorFile=logs/hs_err_pid%p.log 103 | 104 | ## JDK 8 GC logging 105 | 106 | 8:-XX:+PrintGCDetails 107 | 8:-XX:+PrintGCDateStamps 108 | 8:-XX:+PrintTenuringDistribution 109 | 8:-XX:+PrintGCApplicationStoppedTime 110 | 8:-Xloggc:logs/gc.log 111 | 8:-XX:+UseGCLogFileRotation 112 | 8:-XX:NumberOfGCLogFiles=32 113 | 8:-XX:GCLogFileSize=64m 114 | 115 | # JDK 9+ GC logging 116 | 9-:-Xlog:gc*,gc+age=trace,safepoint:file=logs/gc.log:utctime,pid,tags:filecount=32,filesize=64m 117 | # due to internationalization enhancements in JDK 9 Elasticsearch need to set the provider to COMPAT otherwise 118 | # time/date parsing will break in an incompatible way for some date patterns and locals 119 | 9-:-Djava.locale.providers=COMPAT 120 | 121 | # temporary workaround for C2 bug with JDK 10 on hardware with AVX-512 122 | 10-:-XX:UseAVX=2 123 | -------------------------------------------------------------------------------- /search/search.py: -------------------------------------------------------------------------------- 1 | from elasticsearch import Elasticsearch 2 | import json 3 | import re 4 | 5 | from utils.constant import WIKIPEDIA_INDEX_NAME 6 | 7 | es = Elasticsearch(timeout=300) 8 | 9 | core_title_matcher = re.compile('([^()]+[^\s()])(?:\s*\(.+\))?') 10 | core_title_filter = lambda x: core_title_matcher.match(x).group(1) if core_title_matcher.match(x) else x 11 | 12 | def _extract_one(item, lazy=False): 13 | res = {k: item['_source'][k] for k in ['id', 'url', 'title', 'text', 'title_unescape']} 14 | res['_score'] = item['_score'] 15 | res['data_object'] = item['_source']['original_json'] if lazy else json.loads(item['_source']['original_json']) 16 | 17 | return res 18 | 19 | def _single_query_constructor(query, topn=50): 20 | return { 21 | "query": { 22 | "multi_match": { 23 | "query": query, 24 | "fields": ["title^1.25", "title_unescape^1.25", "text", "title_bigram^1.25", "title_unescape_bigram^1.25", "text_bigram"] 25 | } 26 | }, 27 | "size": topn 28 | } 29 | 30 | def single_text_query(query, topn=10, lazy=False, rerank_topn=50): 31 | body = _single_query_constructor(query, topn=max(topn, rerank_topn)) 32 | res = es.search(index=WIKIPEDIA_INDEX_NAME, doc_type='doc', body=json.dumps(body)) 33 | 34 | res = [_extract_one(x, lazy=lazy) for x in res['hits']['hits']] 35 | res = rerank_with_query(query, res)[:topn] 36 | 37 | return res 38 | 39 | def bulk_text_query(queries, topn=10, lazy=False, rerank_topn=50): 40 | body = ["{}\n" + json.dumps(_single_query_constructor(query, topn=max(topn, rerank_topn))) for 
query in queries] 41 | res = es.msearch(index=WIKIPEDIA_INDEX_NAME, doc_type='doc', body='\n'.join(body)) 42 | 43 | res = [[_extract_one(x, lazy=lazy) for x in r['hits']['hits']] for r in res['responses']] 44 | res = [rerank_with_query(query, results)[:topn] for query, results in zip(queries, res)] 45 | 46 | return res 47 | 48 | def rerank_with_query(query, results): 49 | def score_boost(item, query): 50 | score = item['_score'] 51 | core_title = core_title_filter(item['title_unescape']) 52 | if query.startswith('The ') or query.startswith('the '): 53 | query1 = query[4:] 54 | else: 55 | query1 = query 56 | if query == item['title_unescape'] or query1 == item['title_unescape']: 57 | score *= 1.5 58 | elif query.lower() == item['title_unescape'].lower() or query1.lower() == item['title_unescape'].lower(): 59 | score *= 1.2 60 | elif item['title'].lower() in query: 61 | score *= 1.1 62 | elif query == core_title or query1 == core_title: 63 | score *= 1.2 64 | elif query.lower() == core_title.lower() or query1.lower() == core_title.lower(): 65 | score *= 1.1 66 | elif core_title.lower() in query.lower(): 67 | score *= 1.05 68 | 69 | item['_score'] = score 70 | return item 71 | 72 | return list(sorted([score_boost(item, query) for item in results], key=lambda item: -item['_score'])) 73 | 74 | if __name__ == "__main__": 75 | print([x['title'] for x in single_text_query("In which city did Mark Zuckerberg go to college?")]) 76 | print([[y['title'] for y in x] for x in bulk_text_query(["In which city did Mark Zuckerberg go to college?"])]) 77 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | python -c 'import sys; print(sys.version_info[:])' 4 | echo "Please make sure you are running python version 3.6.X" 5 | 6 | echo "Installing required Python packages..." 7 | pip install -r requirements.txt 8 | 9 | echo "Setting up DrQA..." 10 | pushd DrQA 11 | pip install -r requirements.txt 12 | python setup.py develop 13 | ./install_corenlp.sh 14 | popd 15 | 16 | wget http://nlp.stanford.edu/data/wordvecs/glove.840B.300d.zip 17 | mkdir DrQA/data/embeddings 18 | unzip glove.840B.300d.zip 19 | rm glove.840B.300d.zip 20 | mv glove.840B.300d.txt DrQA/data/embeddings/glove.840B.300d.txt 21 | 22 | echo "Downloading models..." 23 | bash scripts/download_golden_retriever_models.sh 24 | 25 | echo "Getting HotpotQA dataset..." 26 | bash scripts/download_hotpotqa.sh 27 | 28 | echo "Downloading Elasticsearch..." 29 | bash scripts/download_elastic_6.7.sh 30 | 31 | echo "NOTE: we set jvm options -Xms and -Xmx for Elasticsearch to be 4GB" 32 | echo "We suggest you set them as large as possible in: elasticsearch-6.7.0/config/jvm.options" 33 | cp search/jvm.options elasticsearch-6.7.0/config/jvm.options 34 | 35 | echo "Downloading wikipedia source documents..." 36 | bash scripts/download_processed_wiki.sh 37 | 38 | echo "Running Elasticsearch and indexing Wikipedia documents..." 39 | bash scripts/launch_elasticsearch_6.7.sh 40 | python -m scripts.index_processed_wiki 41 | 42 | echo "Download CoreNLP..." 43 | bash scripts/download_corenlp.sh 44 | 45 | echo "Setup BiDAF++..." 
46 | pip install spacy 47 | pushd BiDAFpp 48 | ./download.sh 49 | popd 50 | 51 | -------------------------------------------------------------------------------- /utils/constant.py: -------------------------------------------------------------------------------- 1 | WIKIPEDIA_INDEX_NAME='wikipedia' 2 | 3 | QSEP = '' 4 | SEP = '' 5 | -------------------------------------------------------------------------------- /utils/corenlp.py: -------------------------------------------------------------------------------- 1 | from stanfordnlp.server import CoreNLPClient 2 | 3 | tokenizer_client = CoreNLPClient(annotators=['tokenize', 'ssplit'], timeout=30000, memory='16G', properties={'tokenize.ptb3Escaping': False, 'ssplit.eolonly': True, 'tokenize.options': "splitHyphenated=true"}, server_id='pengqi') 4 | def bulk_tokenize(text, return_offsets=False): 5 | ann = tokenizer_client.annotate('\n'.join(text)) 6 | 7 | if return_offsets: 8 | return [[token.originalText for token in sentence.token] for sentence in ann.sentence], [[(token.beginChar, token.endChar) for token in sentence.token] for sentence in ann.sentence] 9 | else: 10 | return [[token.originalText for token in sentence.token] for sentence in ann.sentence] 11 | 12 | if __name__ == "__main__": 13 | print(bulk_tokenize(['This is a test sentence.', ' This is another.'])) 14 | -------------------------------------------------------------------------------- /utils/general.py: -------------------------------------------------------------------------------- 1 | 2 | def chunks(l, n): 3 | """Yield successive n-sized chunks from l.""" 4 | for i in range(0, len(l), n): 5 | yield l[i:i + n] 6 | 7 | def make_context(question, ir_results): 8 | """ 9 | Creates a single context string from the question and retrieved paragraphs. 10 | 11 | :question: the question string 12 | :ir_results: list of dicts, each of which 13 | should have 'title' and 'text' keys 14 | (e.g. each entry of the result from bulk_text_query) 15 | """ 16 | return question + ' ' + concat_paragraphs(ir_results) 17 | 18 | def concat_paragraphs(ir_results): 19 | return ' '.join([f" {p['title']} {''.join(p['text'])}" for p in ir_results]) 20 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic I/O utilities 3 | """ 4 | 5 | import json 6 | 7 | def load_json_file(filename): 8 | with open(filename, 'r') as f: 9 | return json.load(f) 10 | 11 | def write_json_file(data, filename): 12 | with open(filename, 'w') as f: 13 | json.dump(data, f) 14 | -------------------------------------------------------------------------------- /utils/lcs.py: -------------------------------------------------------------------------------- 1 | def LCSubStr(X, Y): 2 | 3 | # Create a table to store lengths of 4 | # longest common suffixes of substrings. 5 | # Note that LCSuff[i][j] contains the 6 | # length of longest common suffix of 7 | # X[0...i-1] and Y[0...j-1]. The first 8 | # row and first column entries have no 9 | # logical meaning; they are used only 10 | # for simplicity of the program.
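# As a small worked example of what this function returns (illustrative values worked out by hand,
# not part of the original comment): for X = "abab" and Y = "bab", the longest common substring is
# "bab", so LCSubStr("abab", "bab") gives (3, "bab", (1, 4)), where (1, 4) is the [start, end)
# span of the match within X.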
11 | 12 | # LCSuff is the table with zero 13 | # value initially in each cell 14 | m = len(X) 15 | n = len(Y) 16 | LCSuff = [[0] * (n+1) for _ in range(m+1)] 17 | 18 | # To store the length of 19 | # the longest common substring 20 | result = 0 21 | max_str = "" 22 | # The following loop builds 23 | # LCSuff[m+1][n+1] in bottom-up fashion 24 | xidx = (0, 0) 25 | for i in range(m + 1): 26 | for j in range(n + 1): 27 | if (i == 0 or j == 0): 28 | LCSuff[i][j] = 0 29 | elif (X[i-1] == Y[j-1]): 30 | LCSuff[i][j] = LCSuff[i-1][j-1] + 1 31 | if LCSuff[i][j] > result: 32 | result = LCSuff[i][j] 33 | max_str = X[i - result:i] 34 | xidx = (i-result, i) 35 | else: 36 | LCSuff[i][j] = 0 37 | return result, max_str, xidx 38 | 39 | def LCS(a, b): 40 | # generate the matrix of longest common subsequence lengths for prefixes of a and b 41 | lengths = [[0] * (len(b)+1) for _ in range(len(a)+1)] 42 | for i, x in enumerate(a): 43 | for j, y in enumerate(b): 44 | if x == y: 45 | lengths[i+1][j+1] = lengths[i][j] + 1 46 | else: 47 | lengths[i+1][j+1] = max(lengths[i+1][j], lengths[i][j+1]) 48 | 49 | # scan the last column of the matrix: whenever the LCS length increases at row i, take a[i-1]; this recovers the LCS length, the taken characters of a, and their [start, end) span (xst, xen) in a 50 | result = [] 51 | j = len(b) 52 | xst = -1 53 | xen = 0 54 | for i in range(1, len(a)+1): 55 | if lengths[i][j] != lengths[i-1][j]: 56 | result.append(a[i-1]) 57 | if xst < 0: 58 | xst = i-1 59 | xen = i 60 | 61 | return len(result), result, (xst, xen) 62 | --------------------------------------------------------------------------------