├── .gitignore ├── LICENSE ├── README.md ├── docqa ├── __init__.py ├── config.py ├── configurable.py ├── data_analysis │ ├── __init__.py │ ├── find_noisy_paragraph.py │ ├── prepare_squad_question_csv.py │ ├── show_rank_errors.py │ ├── show_sampling.py │ ├── show_squad.py │ ├── show_unk.py │ ├── squad_upper_bound.py │ ├── triviaqa_anwer_paragraph.py │ ├── triviaqa_stats.py │ └── visualize_confidence.py ├── data_processing │ ├── __init__.py │ ├── document_splitter.py │ ├── multi_paragraph_qa.py │ ├── preprocessed_corpus.py │ ├── qa_training_data.py │ ├── span_data.py │ ├── text_features.py │ ├── text_utils.py │ ├── wiki.py │ └── word_vectors.py ├── dataset.py ├── doc_qa_models.py ├── elmo │ ├── README.md │ ├── __init__.py │ ├── ablate_elmo_model.py │ ├── data.py │ ├── elmo.py │ ├── eval_elmo_minimal.py │ ├── lm_model.py │ ├── lm_qa_models.py │ ├── run_on_user_text.py │ └── show_weights.py ├── encoder.py ├── eval │ ├── __init__.py │ ├── eval_squad_minimal.py │ ├── ranked_scores.py │ ├── squad_eval.py │ ├── squad_full_document_eval.py │ └── triviaqa_full_document_eval.py ├── evaluator.py ├── model.py ├── model_dir.py ├── nn │ ├── __init__.py │ ├── attention.py │ ├── embedder.py │ ├── layers.py │ ├── ops.py │ ├── recurrent_layers.py │ ├── similarity_layers.py │ ├── span_prediction.py │ └── span_prediction_ops.py ├── scripts │ ├── ablate_squad.py │ ├── ablate_triviaqa.py │ ├── ablate_triviaqa_unfiltered.py │ ├── ablate_triviaqa_wiki.py │ ├── build_pruned_voc.py │ ├── continue.py │ ├── convert_to_cpu.py │ ├── run_on_user_documents.py │ ├── show_parameters.py │ └── train_bidaf.py ├── server │ ├── README.md │ ├── __init__.py │ ├── boilerpipe.jar │ ├── qa_system.py │ ├── requirements.txt │ ├── server.py │ ├── static │ │ ├── about.html │ │ └── index.html │ ├── web_searcher.py │ └── wiki.py ├── squad │ ├── __init__.py │ ├── build_squad_dataset.py │ ├── document_rd_corpus.py │ ├── squad_data.py │ ├── squad_document_qa.py │ └── squad_official_evaluation.py ├── test │ ├── __init__.py │ ├── test_batching.py │ ├── test_embedder.py │ ├── test_evaluator.py │ ├── test_lstm.py │ ├── test_span_prediction.py │ ├── test_splitter.py │ ├── test_ut_coordinates.py │ └── test_word_features.py ├── text_preprocessor.py ├── trainer.py ├── triviaqa │ ├── __init__.py │ ├── answer_detection.py │ ├── build_complete_vocab.py │ ├── build_span_corpus.py │ ├── evidence_corpus.py │ ├── read_data.py │ ├── training_data.py │ └── trivia_qa_eval.py └── utils.py ├── requirements-exact.txt └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | data 4 | out 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Document QA 2 | This repo contains code for our paper [Simple and Effective Multi-Paragraph Reading Comprehension](https://arxiv.org/abs/1710.10723). 3 | It can be used to train neural question answering models in tensorflow, 4 | and in particular for the case when we want to run the model over multiple paragraphs for 5 | each question. Code is included to train on the [TriviaQA](http://nlp.cs.washington.edu/triviaqa/) 6 | and [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) datasets. 7 | 8 | A demo of this work can be found at [documentqa.allenai.org](https://documentqa.allenai.org) 9 | 10 | A small forewarning: this is still much more of a research codebase than a library.
11 | We anticipate porting this work to [allennlp](https://github.com/allenai/allennlp), where it will 12 | enjoy a cleaner implementation and more stable support. 13 | 14 | 15 | ## Setup 16 | ### Dependencies 17 | We require python >= 3.5, tensorflow 1.3, and a handful of other supporting libraries. 18 | Tensorflow should be installed separately following the docs. To install the other dependencies use 19 | 20 | `pip install -r requirements.txt` 21 | 22 | The stopword corpus and punkt sentence tokenizer for nltk are needed and can be fetched with: 23 | 24 | `python -m nltk.downloader punkt stopwords` 25 | 26 | The easiest way to run this code is to add the repository root to your PYTHONPATH: 27 | 28 | ``export PYTHONPATH=${PYTHONPATH}:`pwd` `` 29 | 30 | ### Data 31 | By default, we expect source data to be stored in "\~/data" and preprocessed data to be 32 | stored in "./data". The expected file locations can be changed by altering config.py. 33 | 34 | 35 | #### Word Vectors 36 | The models we train use the common crawl 840 billion token GloVe word vectors from [here](https://nlp.stanford.edu/projects/glove/). 37 | They are expected to exist in "\~/data/glove/glove.840B.300d.txt" or "\~/data/glove/glove.840B.300d.txt.gz". 38 | 39 | For example: 40 | 41 | ``` 42 | mkdir -p ~/data 43 | mkdir -p ~/data/glove 44 | cd ~/data/glove 45 | wget http://nlp.stanford.edu/data/glove.840B.300d.zip 46 | unzip glove.840B.300d.zip 47 | rm glove.840B.300d.zip 48 | ``` 49 | 50 | #### SQuAD Data 51 | Training or testing on SQuAD requires downloading the SQuAD train/dev files into ~/data/squad. 52 | This can be done as follows: 53 | 54 | ``` 55 | mkdir -p ~/data/squad 56 | cd ~/data/squad 57 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json 58 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json 59 | ``` 60 | 61 | then running: 62 | 63 | ``python docqa/squad/build_squad_dataset.py`` 64 | 65 | This builds pkl files of the tokenized data in "./data/squad". 66 | 67 | #### TriviaQA Data 68 | The raw TriviaQA data is expected to be unzipped in "\~/data/triviaqa". Training 69 | or testing in the unfiltered setting requires the unfiltered data to be 70 | downloaded to "\~/data/triviaqa-unfiltered". 71 | 72 | ``` 73 | mkdir -p ~/data/triviaqa 74 | cd ~/data/triviaqa 75 | wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz 76 | tar xf triviaqa-rc.tar.gz 77 | rm triviaqa-rc.tar.gz 78 | 79 | cd ~/data 80 | wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz 81 | tar xf triviaqa-unfiltered.tar.gz 82 | rm triviaqa-unfiltered.tar.gz 83 | ``` 84 | 85 | To use TriviaQA we need to tokenize the evidence documents, which can be done by 86 | 87 | `python docqa/triviaqa/evidence_corpus.py` 88 | 89 | This can be slow, so we support multi-processing: 90 | 91 | `python docqa/triviaqa/evidence_corpus.py --n_processes 8` 92 | 93 | This builds evidence files in "./data/triviaqa/evidence" that are split into 94 | paragraphs, sentences, and tokens. Then we need to tokenize the questions and locate the relevant 95 | answer spans in each document. Run 96 | 97 | `python docqa/triviaqa/build_span_corpus.py {web|wiki|open} --n_processes 8` 98 | 99 | to build the desired set. This builds pkl files in "./data/triviaqa/{web|wiki|open}". 100 | 101 | 102 | ## Training 103 | Once the data is in place, our models can be trained by 104 | 105 | `python docqa/scripts/ablate_{triviaqa|squad|triviaqa_wiki|triviaqa_unfiltered}.py` 106 | 107 | 108 | See the help menu for these scripts for more details.
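For example, an invocation might look like the following (the positional arguments shown here are illustrative, not exact; each script defines its own modes and options, which are listed via `--help`):

`python docqa/scripts/ablate_triviaqa.py shared-norm /path/to/model/directory`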
Note that since we use the Cudnn RNN implementations, 109 | these models can only be trained on a GPU. We do provide a script for converting 110 | the (trained) models to CPU versions: 111 | 112 | `python docqa/scripts/convert_to_cpu.py` 113 | 114 | Modifying the hyper-parameters beyond the ablations requires building your own train script. 115 | 116 | ## Testing 117 | ### SQuAD 118 | Use "docqa/eval/squad_eval.py" to evaluate on paragraph-level (i.e., standard) SQuAD. For example: 119 | 120 | `python docqa/eval/squad_eval.py -o output.json -c dev /path/to/model/directory` 121 | 122 | "output.json" can be used with the official evaluation script, for example: 123 | 124 | `python docqa/squad/squad_official_evaluation.py ~/data/squad/dev-v1.1.json output.json` 125 | 126 | Use "docqa/eval/squad_full_document_eval.py" to evaluate at the document level. For example: 127 | 128 | `python docqa/eval/squad_full_document_eval.py -c dev /path/to/model/directory output.csv` 129 | 130 | This will store the per-paragraph results in output.csv; we can then run: 131 | 132 | `python docqa/eval/ranked_scores.py output.csv` 133 | 134 | to get ranked scores as more paragraphs are used. 135 | 136 | 137 | ### TriviaQA 138 | Use "docqa/eval/triviaqa_full_document_eval.py" to evaluate on TriviaQA datasets, like: 139 | 140 | `python docqa/eval/triviaqa_full_document_eval.py --n_processes 8 -c web-dev --tokens 800 -o question-output.json -p paragraph-output.csv /path/to/model/directory` 141 | 142 | The "question-output.json" file can then be used with the standard TriviaQA evaluation [script](https://github.com/mandarjoshi90/triviaqa). 143 | The "paragraph-output.csv" file contains per-paragraph output; we can run 144 | 145 | `python docqa/eval/ranked_scores.py paragraph-output.csv` 146 | 147 | to get ranked scores as more paragraphs are used for each question, or 148 | 149 | `python docqa/eval/ranked_scores.py --per_doc paragraph-output.csv` 150 | 151 | to get ranked scores as more paragraphs are used for each (question, document) pair, 152 | as should be done for TriviaQA web. 153 | 154 | 155 | ### User Input 156 | "docqa/scripts/run_on_user_documents.py" serves as a heavily commented example of how to run our models 157 | and pre-processing pipeline on other kinds of text. For example: 158 | 159 | `python docqa/scripts/run_on_user_documents.py /path/to/model/directory 160 | "Who wrote the satirical essay 'A Modest Proposal'?" 161 | ~/data/triviaqa/evidence/wikipedia/A_Modest_Proposal.txt 162 | ~/data/triviaqa/evidence/wikipedia/Jonathan_Swift.txt` 163 | 164 | ## Pre-Trained Models 165 | We have four pre-trained models: 166 | 167 | 1. "squad" Our model trained on the standard SQuAD dataset. This model is listed on the SQuAD leaderboard 168 | as BiDAF + Self Attention 169 | 170 | 2. "squad-shared-norm" Our model trained on document-level SQuAD using the shared-norm approach. 171 | 172 | 3. "triviaqa-web-shared-norm" Our model trained on TriviaQA web with the shared-norm approach. This 173 | is the model we used to submit scores to the TriviaQA leaderboard. 174 | 175 | 4. "triviaqa-unfiltered-shared-norm" Our model trained on TriviaQA unfiltered with the shared-norm approach. 176 | This is the model that powers our demo. 177 | 178 | The models can be downloaded [here](https://drive.google.com/open?id=1Hj9WBQHVa__bqoD5RIOPu2qDpvfJQwjR). 179 | 180 | The models use the cuDNN implementation of GRUs by default, which means they can only be run on 181 | the GPU.
We also have slower, but CPU compatible, versions [here](https://drive.google.com/open?id=1NRmb2YilnZOfyKULUnL7gu3HE5nT0sMy). -------------------------------------------------------------------------------- /docqa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/__init__.py -------------------------------------------------------------------------------- /docqa/config.py: -------------------------------------------------------------------------------- 1 | from os.path import join, expanduser, dirname 2 | 3 | """ 4 | Global config options 5 | """ 6 | 7 | VEC_DIR = join(expanduser("~"), "data", "glove") 8 | SQUAD_SOURCE_DIR = join(expanduser("~"), "data", "squad") 9 | SQUAD_TRAIN = join(SQUAD_SOURCE_DIR, "train-v1.1.json") 10 | SQUAD_DEV = join(SQUAD_SOURCE_DIR, "dev-v1.1.json") 11 | 12 | 13 | TRIVIA_QA = join(expanduser("~"), "data", "triviaqa") 14 | TRIVIA_QA_UNFILTERED = join(expanduser("~"), "data", "triviaqa-unfiltered") 15 | LM_DIR = join(expanduser("~"), "data", "lm") 16 | DOCUMENT_READER_DB = join(expanduser("~"), "data", "doc-rd", "docs.db") 17 | 18 | 19 | CORPUS_DIR = join(dirname(dirname(__file__)), "data") 20 | -------------------------------------------------------------------------------- /docqa/configurable.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | from inspect import signature 4 | from warnings import warn 5 | 6 | import numpy as np 7 | from sklearn.base import BaseEstimator 8 | 9 | 10 | class Configuration(object): 11 | def __init__(self, name, version, params): 12 | if not isinstance(name, str): 13 | raise ValueError() 14 | if not isinstance(params, dict): 15 | raise ValueError() 16 | self.name = name 17 | self.version = version 18 | self.params = params 19 | 20 | def __str__(self): 21 | if len(self.params) == 0: 22 | return "%s-v%s" % (self.name, self.version) 23 | json_params = config_to_json(self.params) 24 | if len(json_params) < 200: 25 | return "%s-v%s: %s" % (self.name, self.version, json_params) 26 | else: 27 | return "%s-v%s {...}" % (self.name, self.version) 28 | 29 | def __eq__(self, other): 30 | return isinstance(other, Configuration) and \ 31 | self.name == other.name and \ 32 | self.version == other.version and \ 33 | self.params == other.params 34 | 35 | 36 | class Configurable(object): 37 | """ 38 | Configurable classes have names, versions, and a set of parameters that are either "simple" (i.e., JSON-serializable) 39 | types or other Configurable objects. Configurable objects should also be serializable via pickle. 40 | Configurable classes are defined mainly to give us a human-readable way of reading off the `parameters` 41 | set for different objects and to attach version numbers to them. 42 | 43 | By default we follow the format sklearn uses for its `BaseEstimator` class, where parameters are automatically 44 | derived based on the constructor parameters.
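For example, a minimal illustrative sketch (`TruncateText` is a hypothetical class, not one defined in this codebase):

    class TruncateText(Configurable):
        def __init__(self, max_tokens=400):
            self.max_tokens = max_tokens

    TruncateText(200).get_config()
    # -> Configuration("TruncateText", 0, {"max_tokens": 200})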
45 | """ 46 | 47 | @classmethod 48 | def _get_param_names(cls): 49 | # fetch the constructor, or the original constructor before deprecation wrapping if any 50 | init = cls.__init__ 51 | if init is object.__init__: 52 | # No explicit constructor to introspect 53 | return [] 54 | 55 | init_signature = signature(init) 56 | parameters = [p for p in init_signature.parameters.values() 57 | if p.name != 'self'] 58 | if any(p.kind == p.VAR_POSITIONAL for p in parameters): 59 | raise RuntimeError() 60 | return sorted([p.name for p in parameters]) 61 | 62 | @property 63 | def name(self): 64 | return self.__class__.__name__ 65 | 66 | @property 67 | def version(self): 68 | return 0 69 | 70 | def get_params(self): 71 | out = {} 72 | for key in self._get_param_names(): 73 | v = getattr(self, key, None) 74 | if isinstance(v, Configurable): 75 | out[key] = v.get_config() 76 | elif hasattr(v, "get_config"): # for keras objects 77 | out[key] = {"name": v.__class__.__name__, "config": v.get_config()} 78 | else: 79 | out[key] = v 80 | return out 81 | 82 | def get_config(self) -> Configuration: 83 | params = {k: describe(v) for k,v in self.get_params().items()} 84 | return Configuration(self.name, self.version, params) 85 | 86 | def __getstate__(self): 87 | state = dict(self.__dict__) 88 | if "version" in state: 89 | if state["version"] != self.version: 90 | raise RuntimeError() 91 | else: 92 | state["version"] = self.version 93 | return state 94 | 95 | def __setstate__(self, state): 96 | if "version" not in state: 97 | raise RuntimeError("Version should be in state (%s)" % self.__class__.__name__) 98 | if state["version"] != self.version: 99 | warn(("%s loaded with version %s, but class " + 100 | "version is %s") % (self.__class__.__name__, state["version"], self.version)) 101 | 102 | if "state" in state: 103 | self.__dict__ = state["state"] 104 | else: 105 | del state["version"] 106 | self.__dict__ = state 107 | 108 | 109 | def describe(obj): 110 | if isinstance(obj, Configurable): 111 | return obj.get_config() 112 | else: 113 | obj_type = type(obj) 114 | 115 | if obj_type in (list, set, frozenset, tuple): 116 | return obj_type([describe(e) for e in obj]) 117 | elif isinstance(obj, tuple): 118 | # Named tuple; convert to a plain tuple 119 | return tuple(describe(e) for e in obj) 120 | elif obj_type in (dict, OrderedDict): 121 | output = OrderedDict() 122 | for k, v in obj.items(): 123 | if isinstance(k, Configurable): 124 | raise ValueError() 125 | output[k] = describe(v) 126 | return output 127 | else: 128 | return obj 129 | 130 | 131 | class EncodeDescription(json.JSONEncoder): 132 | """ Json encoder that encodes 'Configurable' objects as dictionaries and handles 133 | some numpy types. Note that decoding this output will not reproduce the original input 134 | for these types; it is only intended to produce human-readable output.
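For example (an illustrative call; `config_to_json` is defined at the bottom of this file):

    config_to_json(Configuration("Adam", 0, {"lr": 0.001}))
    # -> '{"name": "Adam", "lr": 0.001}'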
135 | """ 136 | def default(self, obj): 137 | if isinstance(obj, np.integer): 138 | return int(obj) 139 | elif isinstance(obj, np.dtype): 140 | return str(obj) 141 | elif isinstance(obj, np.floating): 142 | return float(obj) 143 | elif isinstance(obj, np.bool_): 144 | return bool(obj) 145 | elif isinstance(obj, np.ndarray): 146 | return obj.tolist() 147 | elif isinstance(obj, BaseEstimator): # handle sklearn estimators 148 | return Configuration(obj.__class__.__name__, 0, obj.get_params()) 149 | elif isinstance(obj, Configuration): 150 | if "version" in obj.params or "name" in obj.params: 151 | raise ValueError() 152 | out = OrderedDict() 153 | out["name"] = obj.name 154 | if obj.version != 0: 155 | out["version"] = obj.version 156 | out.update(obj.params) 157 | return out 158 | elif isinstance(obj, Configurable): 159 | return obj.get_config() 160 | elif isinstance(obj, set): 161 | return sorted(obj) # Ensure deterministic order 162 | else: 163 | try: 164 | return super().default(obj) 165 | except TypeError: 166 | return str(obj) 167 | 168 | 169 | def config_to_json(data, indent=None): 170 | return json.dumps(data, sort_keys=False, cls=EncodeDescription, indent=indent) 171 | -------------------------------------------------------------------------------- /docqa/data_analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/data_analysis/__init__.py -------------------------------------------------------------------------------- /docqa/data_analysis/find_noisy_paragraph.py: -------------------------------------------------------------------------------- 1 | from docqa.data_processing.document_splitter import TopTfIdf, MergeParagraphs 2 | from docqa.data_processing.text_utils import NltkPlusStopWords 3 | from docqa.triviaqa.build_span_corpus import TriviaQaWebDataset 4 | from docqa.utils import flatten_iterable 5 | 6 | 7 | def main(): 8 | data = TriviaQaWebDataset() 9 | 10 | stop = NltkPlusStopWords() 11 | splitter = MergeParagraphs(400) 12 | selector = TopTfIdf(stop, 4) 13 | 14 | print("Loading data..") 15 | train = data.get_train() 16 | print("Start") 17 | for q in train: 18 | for doc in q.all_docs: 19 | if len(doc.answer_spans) > 3: 20 | text = splitter.split_annotated(data.evidence.get_document(doc.doc_id), doc.answer_spans) 21 | text = selector.prune(q.question, text) 22 | for para in text: 23 | if len(para.answer_spans) > 3: 24 | print(q.question) 25 | text = flatten_iterable(para.text) 26 | for s,e in para.answer_spans: 27 | text[s] = "{{{" + text[s] 28 | text[e] = text[e] + "}}}" 29 | print(" ".join(text)) 30 | input() 31 | 32 | if __name__ == "__main__": 33 | main() -------------------------------------------------------------------------------- /docqa/data_analysis/prepare_squad_question_csv.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import namedtuple 3 | from os.path import join 4 | 5 | import numpy as np 6 | 7 | from docqa.config import SQUAD_SOURCE_DIR 8 | 9 | ParagraphAndQuestion = namedtuple('ParagraphAndQuestion', ['article_title', 'paragraph', 'question', 'answers', 'question_id']) 10 | 11 | 12 | def init_google_csv(seed, n_samples, output_file, source): 13 | with open(source, 'r') as f: 14 | source_data = json.load(f) 15 | questions = [] 16 | for article_ix, article in enumerate(source_data['data']): 17 | title = article["title"] 18 | for para_ix, para
in enumerate(article['paragraphs']): 19 | text = para["context"] 20 | for question_ix, question in enumerate(para['qas']): 21 | q_id = question['id'] 22 | questions.append(ParagraphAndQuestion(title, text, question["question"], 23 | question['answers'], q_id)) 24 | 25 | questions = sorted(questions, key=lambda x: x.question_id) 26 | np.random.RandomState(seed).shuffle(questions) 27 | 28 | with open(output_file, 'w') as f: 29 | f.write("question_id\tquestion\tanswer\tarticle_title\tcontext\n") 30 | for q in questions[:n_samples]: 31 | 32 | marked = q.paragraph 33 | for ans in q.answers[::-1]: 34 | start = ans["answer_start"] 35 | end = start + len(ans["text"]) 36 | marked = marked[:end] + "}}}" + marked[end:] 37 | marked = marked[:start] + "{{{" + marked[start:] 38 | 39 | f.write(q.question_id) 40 | f.write("\t") 41 | f.write("\"" + q.question + "\"") 42 | f.write("\t") 43 | f.write("\"" + q.answers[0]["text"] + "\"") 44 | f.write("\t") 45 | f.write("\"" + q.article_title + "\"") 46 | f.write("\t") 47 | f.write(marked) 48 | f.write("\n") 49 | 50 | if __name__ == "__main__": 51 | init_google_csv(0, 500, "/tmp/annotations.tsv", join(SQUAD_SOURCE_DIR, "train-v1.1.json")) 52 | -------------------------------------------------------------------------------- /docqa/data_analysis/show_sampling.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import numpy as np 4 | from sklearn.feature_extraction.text import strip_accents_unicode 5 | from tqdm import tqdm 6 | 7 | from docqa.data_processing.document_splitter import MergeParagraphs, TopTfIdf, ShallowOpenWebRanker 8 | from docqa.data_processing.text_utils import NltkPlusStopWords 9 | from docqa.triviaqa.build_span_corpus import TriviaQaWebDataset, TriviaQaOpenDataset 10 | from docqa.utils import flatten_iterable 11 | 12 | 13 | class bcolors: 14 | CORRECT = '\033[94m' 15 | ERROR = '\033[91m' 16 | CYAN = "\033[96m" 17 | ENDC = '\033[0m' 18 | 19 | 20 | def show_stats(): 21 | splitter = MergeParagraphs(400) 22 | stop = NltkPlusStopWords(True) 23 | ranker = TopTfIdf(stop, 6) 24 | 25 | corpus = TriviaQaWebDataset() 26 | train = corpus.get_train() 27 | points = flatten_iterable([(q, d) for d in q.all_docs] for q in train) 28 | np.random.shuffle(points) 29 | 30 | counts = np.zeros(6) 31 | answers = np.zeros(6) 32 | n_answers = [] 33 | 34 | points = points[:1000] 35 | for q, d in tqdm(points): 36 | doc = corpus.evidence.get_document(d.doc_id) 37 | doc = splitter.split_annotated(doc, d.answer_spans) 38 | ranked = ranker.prune(q.question, doc) 39 | counts[:len(ranked)] += 1 40 | for i, para in enumerate(ranked): 41 | if len(para.answer_spans) > 0: 42 | answers[i] += 1 43 | n_answers.append(tuple(i for i, x in enumerate(ranked) if len(x.answer_spans) > 0)) 44 | 45 | print(answers/counts) 46 | c = Counter() 47 | other = 0 48 | for tup in n_answers: 49 | if len(tup) <= 2: 50 | c[tup] += 1 51 | else: 52 | other += 1 53 | 54 | for p in sorted(c.keys()): 55 | print(p, c.get(p)/len(points)) 56 | print(other/len(points)) 57 | 58 | 59 | def show_web_paragraphs(): 60 | splitter = MergeParagraphs(400) 61 | stop = NltkPlusStopWords(True) 62 | ranker = TopTfIdf(stop, 6) 63 | stop_words = stop.words 64 | 65 | corpus = TriviaQaWebDataset() 66 | train = corpus.get_train() 67 | points = flatten_iterable([(q, d) for d in q.all_docs] for q in train) 68 | np.random.shuffle(points) 69 | 70 | for q, d in points: 71 | q_words = {strip_accents_unicode(w.lower()) for w in q.question} 72 | q_words = {x for x 
in q_words if x not in stop_words} 73 | 74 | doc = corpus.evidence.get_document(d.doc_id) 75 | doc = splitter.split_annotated(doc, d.answer_spans) 76 | ranked = ranker.dists(q.question, doc) 77 | if len(ranked) < 2 or len(ranked[1][0].answer_spans) == 0: 78 | continue 79 | print(" ".join(q.question)) 80 | print(q.answer.all_answers) 81 | for i, (para, dist) in enumerate(ranked[0:2]): 82 | text = flatten_iterable(para.text) 83 | print("Start=%d, Rank=%d, Dist=%.4f" % (para.start, i, dist)) 84 | if len(para.answer_spans) == 0: 85 | continue 86 | for s, e in para.answer_spans: 87 | text[s] = bcolors.CYAN + text[s] 88 | text[e] = text[e] + bcolors.ENDC 89 | for i, w in enumerate(text): 90 | if strip_accents_unicode(w.lower()) in q_words: 91 | text[i] = bcolors.ERROR + text[i] + bcolors.ENDC 92 | print(" ".join(text)) 93 | input() 94 | 95 | 96 | def show_open_paragraphs(start: int, end: int): 97 | splitter = MergeParagraphs(400) 98 | stop = NltkPlusStopWords(True) 99 | ranker = ShallowOpenWebRanker(6) 100 | stop_words = stop.words 101 | 102 | print("Loading dev") 103 | corpus = TriviaQaOpenDataset() 104 | train = corpus.get_dev() 105 | np.random.shuffle(train) 106 | 107 | for q in train: 108 | q_words = {strip_accents_unicode(w.lower()) for w in q.question} 109 | q_words = {x for x in q_words if x not in stop_words} 110 | 111 | para = [] 112 | for d in q.all_docs: 113 | doc = corpus.evidence.get_document(d.doc_id) 114 | para += splitter.split_annotated(doc, d.answer_spans) 115 | 116 | ranked = ranker.prune(q.question, para) 117 | if len(ranked) < start: 118 | continue 119 | ranked = ranked[start:end] 120 | 121 | print(" ".join(q.question)) 122 | print(q.answer.all_answers) 123 | for i in range(len(ranked)):  # iterate over the slice itself, not [start, end), to avoid an IndexError when few paragraphs remain 124 | para = ranked[i] 125 | text = flatten_iterable(para.text) 126 | print("Start=%d, Rank=%d" % (para.start, start + i)) 127 | if len(para.answer_spans) == 0: 128 | # print("No Answer!") 129 | continue 130 | for s, e in para.answer_spans: 131 | text[s] = bcolors.CYAN + text[s] 132 | text[e] = text[e] + bcolors.ENDC 133 | for i, w in enumerate(text): 134 | if strip_accents_unicode(w.lower()) in q_words: 135 | text[i] = bcolors.ERROR + text[i] + bcolors.ENDC 136 | print(" ".join(text)) 137 | input() 138 | 139 | 140 | if __name__ == "__main__": 141 | show_open_paragraphs(0, 4) -------------------------------------------------------------------------------- /docqa/data_analysis/show_squad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from docqa.squad.squad_data import SquadCorpus, split_docs 3 | 4 | 5 | def main(): 6 | data = split_docs(SquadCorpus().get_train()) 7 | np.random.shuffle(data) 8 | for point in data: 9 | print(" ".join(point.question)) 10 | 11 | 12 | if __name__ == "__main__": 13 | main() -------------------------------------------------------------------------------- /docqa/data_analysis/show_unk.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import Counter, defaultdict 3 | 4 | import numpy as np 5 | 6 | from docqa.data_processing.text_features import BasicWordFeatures 7 | from docqa.squad.squad_data import SquadCorpus 8 | 9 | 10 | def show_unk(corpus: SquadCorpus, vec_name: str, 11 | context: bool=True, question: bool=True): 12 | vecs = corpus.get_pruned_word_vecs(vec_name) 13 | docs = corpus.get_train() 14 | 15 | lower_unk = Counter() 16 | unk = Counter() 17 | 18 | for doc in docs: 19 | for para in doc.paragraphs: 20 | if context: 21 | for sent in
para.text: 22 | for word in sent: 23 | if word not in vecs: 24 | unk[word] += 1 25 | word = word.lower() 26 | if word not in vecs: 27 | lower_unk[word] += 1 28 | if question: 29 | for question in para.questions: 30 | for word in question.words: 31 | if word not in vecs: 32 | unk[word] += 1 33 | word = word.lower() 34 | if word not in vecs: 35 | lower_unk[word] += 1 36 | 37 | print("\n".join("%s: %d" % (k,v) for k,v in lower_unk.most_common())) 38 | 39 | 40 | def show_features(corpus: SquadCorpus, vec_name): 41 | print("Loading train docs") 42 | data = corpus.get_train() 43 | np.random.shuffle(data) 44 | data = data[:100] 45 | 46 | print("Loading vectors") 47 | vecs = corpus.get_pruned_word_vecs(vec_name) 48 | fe = BasicWordFeatures() 49 | 50 | grouped_by_features = defaultdict(Counter) 51 | 52 | print("start") 53 | 54 | for doc in data: 55 | paragraphs = list(doc.paragraphs) 56 | np.random.shuffle(paragraphs) 57 | for para in paragraphs: 58 | sentences = list(para.text) + [x.words for x in para.questions] 59 | np.random.shuffle(sentences) 60 | for words in sentences: 61 | for i, word in enumerate(words): 62 | if word.lower() not in vecs: 63 | x = fe.get_word_features(word) 64 | for i, val in enumerate(x): 65 | if val > 0: 66 | grouped_by_features[i][word] += 1 67 | 68 | for i in sorted(grouped_by_features.keys()): 69 | name = BasicWordFeatures.features_names[i] 70 | if name in ["Len"]: 71 | continue 72 | vals = grouped_by_features[i] 73 | print() 74 | print("*"*30) 75 | print("%s-%d %d (%d)" % (name, i, len(vals), sum(vals.values()))) 76 | for k,v in vals.most_common(30): 77 | print("%s: %d" % (k, v)) 78 | 79 | 80 | def show_nums(corpus: SquadCorpus): 81 | n_regex = re.compile(".*[0-9].*") 82 | data = corpus.get_train() 83 | np.random.shuffle(data) 84 | 85 | for doc in data: 86 | paragraphs = list(doc.paragraphs) 87 | np.random.shuffle(paragraphs) 88 | for para in paragraphs: 89 | sentences = list(para.text) + [x.words for x in para.questions] 90 | np.random.shuffle(sentences) 91 | for words in sentences: 92 | for i, word in enumerate(words): 93 | if n_regex.match(word) is not None: 94 | print(word) 95 | 96 | 97 | def show_in_context_unks(corpus: SquadCorpus, vec_name): 98 | data = corpus.get_train() 99 | np.random.shuffle(data) 100 | vecs = corpus.get_pruned_word_vecs(vec_name) 101 | 102 | for doc in data: 103 | paragraphs = list(doc.paragraphs) 104 | np.random.shuffle(paragraphs) 105 | for para in paragraphs: 106 | sentences = list(para.text) + [x.words for x in para.questions] 107 | np.random.shuffle(sentences) 108 | for words in sentences: 109 | for i, word in enumerate(words): 110 | if word.lower() not in vecs: 111 | words[i] = "{{{" + word + "}}}" 112 | print(" ".join(words[max(0,i-10):min(len(words),i+10)])) 113 | words[i] = word 114 | 115 | 116 | def main(): 117 | show_unk(SquadCorpus(), "glove.840B.300d") 118 | # show_unk(SpanCorpus("squad"), "glove.6B.100d") 119 | # show_nums(SpanCorpus("squad")) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() -------------------------------------------------------------------------------- /docqa/data_analysis/squad_upper_bound.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | from docqa.squad.squad_data import SquadCorpus 4 | from docqa.squad.squad_official_evaluation import f1_score 5 | from docqa.utils import flatten_iterable 6 | 7 | """ 8 | Explore errors caused by our tokenization and/or token cleaning 9 | """ 10 | 11 | 12 | def main(): 13 | data =
SquadCorpus() 14 | 15 | string_f1 = 0 16 | mapped_string_f1 = 0 17 | 18 | docs = data.get_train() 19 | n_questions = 0 20 | 21 | for doc in tqdm(docs): 22 | for para in doc.paragraphs: 23 | words = flatten_iterable(para.text) 24 | for question in para.questions: 25 | n_questions += 1 26 | span_answer = question.answer[0] 27 | span_str = " ".join(words[span_answer.para_word_start:span_answer.para_word_end+1]) 28 | raw_answer = span_answer.text 29 | mapped_str = para.get_original_text(span_answer.para_word_start, span_answer.para_word_end) 30 | 31 | string_f1 += f1_score(raw_answer, span_str) 32 | mapped_string_f1 += f1_score(raw_answer, mapped_str) 33 | 34 | print(string_f1 / n_questions) 35 | print(mapped_string_f1 / n_questions) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() -------------------------------------------------------------------------------- /docqa/data_analysis/triviaqa_anwer_paragraph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from docqa.data_processing.document_splitter import MergeParagraphs, TopTfIdf 4 | from docqa.data_processing.preprocessed_corpus import preprocess_par 5 | from docqa.data_processing.text_utils import NltkPlusStopWords 6 | from docqa.triviaqa.build_span_corpus import TriviaQaOpenDataset 7 | from docqa.triviaqa.training_data import ExtractMultiParagraphsPerQuestion 8 | 9 | 10 | def main(): 11 | data = TriviaQaOpenDataset() 12 | # data = TriviaQaWebDataset() 13 | print("Loading...") 14 | all_questions = data.get_dev() 15 | 16 | questions = [q for q in all_questions if any(len(x.answer_spans) > 0 for x in q.all_docs)] 17 | print("%d/%d (%.4f) have an answer" % (len(questions), len(all_questions), len(questions)/len(all_questions))) 18 | 19 | np.random.shuffle(questions) 20 | 21 | pre = ExtractMultiParagraphsPerQuestion(MergeParagraphs(400), 22 | TopTfIdf(NltkPlusStopWords(), 20), 23 | require_an_answer=False) 24 | print("Done") 25 | 26 | out = preprocess_par(questions[:2000], data.evidence, pre, 2, 1000) 27 | 28 | n_counts = np.zeros(20) 29 | n_any = np.zeros(20) 30 | n_any_all = np.zeros(20) 31 | 32 | for q in out.data: 33 | for i, p in enumerate(q.paragraphs): 34 | n_counts[i] += 1 35 | n_any[i] += len(p.answer_spans) > 0 36 | 37 | for i, p in enumerate(q.paragraphs): 38 | if len(p.answer_spans) > 0: 39 | n_any_all[i:] += 1 40 | break 41 | 42 | print(n_any_all / out.true_len) 43 | print(n_any/n_counts) 44 | print(n_counts) 45 | 46 | 47 | 48 | if __name__ == "__main__": 49 | main() -------------------------------------------------------------------------------- /docqa/data_analysis/triviaqa_stats.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from docqa.data_processing.document_splitter import MergeParagraphs, ContainsQuestionWord, DocumentSplitter, \ 7 | ExtractedParagraphWithAnswers 8 | from docqa.data_processing.text_utils import NltkPlusStopWords 9 | from docqa.triviaqa.build_span_corpus import TriviaQaWebDataset 10 | from docqa.triviaqa.read_data import TriviaQaQuestion 11 | from docqa.utils import flatten_iterable 12 | 13 | 14 | def basic_stats(corpus): 15 | train = corpus.get_train() 16 | n_docs = sum(len(q.all_docs) for q in train) 17 | n_with_answer = sum(sum(len(doc.answer_spans) > 0 for doc in q.all_docs) for q in train) 18 | print(n_docs) 19 | print(n_with_answer) 20 | 21 | 22 | def paragraph_stats(corpus, splitter: 
DocumentSplitter, sample): 23 | stop = NltkPlusStopWords(punctuation=True).words 24 | 25 | data = corpus.get_dev() 26 | pairs = flatten_iterable([(q, doc) for doc in q.all_docs] for q in data) 27 | data = [pairs[i] for i in np.random.choice(np.arange(0, len(pairs)), sample, replace=False)] 28 | 29 | word_matches = Counter() 30 | n_para = [] 31 | n_answers = [] 32 | n_question_words = [] 33 | 34 | for q,doc in data: 35 | if len(doc.answer_spans) == 0: 36 | continue 37 | q_words = set(x.lower() for x in q.question) 38 | q_words -= stop 39 | # q_words = set(norm.normalize(w) for w in q_words) 40 | 41 | text = corpus.evidence.get_document(doc.doc_id) 42 | para = splitter.split_annotated(text, doc.answer_spans) 43 | n_para.append(len(para)) 44 | n_answers += [len(x.answer_spans) for x in para] 45 | 46 | for x in para: 47 | match_set = set() 48 | n_matches = 0 49 | text = flatten_iterable(x.text) 50 | for word in text: 51 | word = word.lower() 52 | if word in q_words: 53 | n_matches += 1 54 | match_set.add(word) 55 | if len(match_set) == 0 and len(x.answer_spans) > 0: 56 | print_paragraph(q, x) 57 | input() 58 | word_matches.update(match_set) 59 | n_question_words.append(n_matches) 60 | 61 | n_answers = np.array(n_answers) 62 | n_question_words = np.array(n_question_words) 63 | any_answers = n_answers > 0 64 | any_question_word = n_question_words > 0 65 | 66 | total_para = len(any_answers) 67 | total_q = len(n_para) 68 | 69 | no_question_and_answer = any_answers[np.logical_not(any_question_word)] 70 | 71 | print("%d/%d (%.4f) pairs have an answer" % (total_q, len(data), total_q/len(data))) 72 | print("%d para in %d questions (av %.4f)" % (sum(n_para), total_q, sum(n_para)/total_q)) 73 | print("%d/%d (%.4f) paragraphs have answers" % (any_answers.sum(), total_para, any_answers.mean())) 74 | print("%d/%d (%.4f) paragraphs have question word" % (any_question_word.sum(), total_para, any_question_word.mean())) 75 | print("%d/%d (%.4f) paragraphs with no question word have answers" % (no_question_and_answer.sum(), 76 | len(no_question_and_answer), 77 | no_question_and_answer.mean())) 78 | # for k,v in word_matches.most_common(100): 79 | # print("%s: %d" % (k, v)) 80 | 81 | 82 | def print_paragraph(question: TriviaQaQuestion, para: ExtractedParagraphWithAnswers): 83 | print(" ".join(question.question)) 84 | print(question.answer.all_answers) 85 | context = flatten_iterable(para.text) 86 | for s,e in para.answer_spans: 87 | context[s] = "{{{" + context[s] 88 | context[e] = context[e] + "}}}" 89 | print(" ".join(context)) 90 | 91 | 92 | def print_questions(question, answers, context, answer_span): 93 | print(" ".join(question)) 94 | print(answers) 95 | context = flatten_iterable(context) 96 | for s,e in answer_span: 97 | context[s] = "{{{" + context[s] 98 | context[e] = context[e] + "}}}" 99 | print(" ".join(context)) 100 | 101 | 102 | def contains_question_word(): 103 | data = TriviaQaWebDataset() 104 | stop = NltkPlusStopWords(punctuation=True).words 105 | doc_filter = ContainsQuestionWord(NltkPlusStopWords(punctuation=True)) 106 | splits = MergeParagraphs(400) 107 | # splits = Truncate(400) 108 | questions = data.get_dev() 109 | pairs = flatten_iterable([(q, doc) for doc in q.all_docs] for q in questions) 110 | pairs.sort(key=lambda x: (x[0].question_id, x[1].doc_id)) 111 | np.random.RandomState(0).shuffle(questions) 112 | has_token = 0 113 | total = 0 114 | used = Counter() 115 | 116 | for q, doc in tqdm(pairs[:1000]): 117 | text = data.evidence.get_document(doc.doc_id, splits.reads_first_n) 118 | q_tokens =
set(x.lower() for x in q.question) 119 | q_tokens -= stop 120 | for para in splits.split_annotated(text, doc.answer_spans): 121 | # if para.start == 0: 122 | # continue 123 | if len(para.answer_spans) == 0: 124 | continue 125 | if any(x.lower() in q_tokens for x in flatten_iterable(para.text)): 126 | has_token += 1 127 | for x in flatten_iterable(para.text): 128 | if x in q_tokens: 129 | used[x] += 1 130 | # else: 131 | # print_questions(q.question, q.answer.all_answers, para.text, para.answer_spans) 132 | # input() 133 | total += 1 134 | for k,v in used.most_common(200): 135 | print("%s: %d" % (k, v)) 136 | print(has_token/total) 137 | 138 | 139 | if __name__ == "__main__": 140 | paragraph_stats(TriviaQaWebDataset(), MergeParagraphs(400), 1000) 141 | -------------------------------------------------------------------------------- /docqa/data_analysis/visualize_confidence.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os.path import basename 3 | 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("answer_files", nargs="+") 11 | args = parser.parse_args() 12 | 13 | dfs = {} 14 | for x in args.answer_files: 15 | name = basename(x) 16 | name = name[:name.rfind(".")] 17 | dfs[name] = pd.read_csv(x) 18 | 19 | for k, df in dfs.items(): 20 | df = df[df["n_answers"] > 0] 21 | plt.hist(df["predicted_score"] - df["predicted_score"].mean(), 50, label=k, alpha=0.5) 22 | 23 | plt.legend() 24 | plt.show() 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /docqa/data_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/data_processing/__init__.py -------------------------------------------------------------------------------- /docqa/data_processing/text_features.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | from nltk import Counter, WordNetLemmatizer 5 | 6 | from docqa.configurable import Configurable 7 | 8 | """ 9 | Classic/shallow text features. I have only done shallow experiments with these 10 | and have not found them to be of much use. 11 | """ 12 | 13 | 14 | any_num_regex = re.compile("^.*[\d].*$") 15 | int_prefixes = "s|st|th|nd|rd" 16 | all_prefixes = "km|m|v|K|b|bn|billion|k|million|th\+" 17 | careful_num_regex = re.compile("^\+?"
18 | "(\d{1,3}(,\d{3})*|\d+|(?=\.))" 19 | "(?:(\.\d+)?(?P<p1>%s)?|(?P<p2>%s)?)\+?$" % (all_prefixes, int_prefixes)) 20 | 21 | 22 | def is_number(token): 23 | match = careful_num_regex.fullmatch(token) 24 | if match is None: 25 | return None 26 | p1 = match.group("p1") 27 | p2 = match.group("p2") 28 | if p1 is not None: 29 | return p1 30 | elif p2 is not None: 31 | return p2 32 | else: 33 | return "" 34 | 35 | 36 | class QaTextFeautrizer(Configurable): 37 | 38 | def n_context_features(self): 39 | raise NotImplementedError() 40 | 41 | def n_question_features(self): 42 | raise NotImplementedError() 43 | 44 | def get_features(self, question, context): 45 | """ 46 | return arrays of shape (n_question_words, feature_dim) and (n_context_words, feature_dim) 47 | """ 48 | raise NotImplementedError() 49 | 50 | 51 | class BasicWordFeatures(QaTextFeautrizer): 52 | features_names = ["Num", "NumPrefix", "NumExp", "AnyNum", "Punct", 53 | "Cap", "Upper", "Alpha", "NonEng", "Len"] 54 | 55 | def __init__(self): 56 | self.any_num_regex = re.compile("^.*\d.*$") 57 | self.num_exp = re.compile("^[\d+x\-/\\\=\u2013,:\W]*$") 58 | self.punc_regex = re.compile("^\W+$") 59 | self.alpha = re.compile("^[a-z]+$") 60 | self.any_non_english = re.compile(".*[^a-zA-Z0-9\W].*") 61 | self.non_english = re.compile("^[^a-zA-Z0-9\W]+$") 62 | self._feature_cache = {} 63 | 64 | def get_word_features(self, word): 65 | if word not in self._feature_cache: 66 | num_prefix = is_number(word) 67 | non_eng = self.non_english.match(word) is not None 68 | punc = self.punc_regex.match(word) is not None 69 | features = np.array([ 70 | num_prefix is not None, 71 | num_prefix is not None and num_prefix != "", 72 | self.num_exp.match(word) is not None and num_prefix is None and not punc, 73 | self.any_num_regex.match(word) is not None and not punc, 74 | punc, 75 | word[0].isupper() and word[1:].islower() and not non_eng, 76 | word.isupper() and not non_eng, 77 | self.alpha.match(word) is not None, 78 | non_eng, 79 | np.log(len(word)) 80 | ]) 81 | self._feature_cache[word] = features 82 | return features 83 | return self._feature_cache[word] 84 | 85 | @property 86 | def n_features(self): 87 | return 10 88 | 89 | def n_context_features(self): 90 | return self.n_features 91 | 92 | def n_question_features(self): 93 | return self.n_features 94 | 95 | def get_sentence_features(self, sent): 96 | features = np.zeros((len(sent), self.n_features)) 97 | for i, word in enumerate(sent): 98 | features[i, :self.n_features] = self.get_word_features(word) 99 | return features 100 | 101 | def get_features(self, question, context): 102 | return self.get_sentence_features(question), self.get_sentence_features(context) 103 | 104 | 105 | def extract_year(token): 106 | ends_with_s = False 107 | if token[-1] == "s": 108 | token = token[:-1] 109 | ends_with_s = True 110 | try: 111 | val = int(token) 112 | if val < 100 and val % 10 == 0 and ends_with_s: 113 | return 1900 + val 114 | if 1000 <= val <= 2017: 115 | return val 116 | return None 117 | except ValueError: 118 | return None 119 | 120 | 121 | class MatchWordFeatures(QaTextFeautrizer): 122 | def __init__(self, require_unique_match, lemmatizer="word_net", 123 | empty_question_features=False, stop_words=None): 124 | self.lemmatizer = lemmatizer 125 | self.stop_words = stop_words 126 | self.empty_question_features = empty_question_features 127 | if lemmatizer == "word_net": 128 | self._lemmatizer = WordNetLemmatizer() 129 | else: 130 | raise ValueError() 131 | self._cache = {} 132 | self.require_unique_match =
require_unique_match 133 | 134 | def n_context_features(self): 135 | return 3 136 | 137 | def n_question_features(self): 138 | return 3 if self.empty_question_features else 0 139 | 140 | def lemmatize_word(self, word): 141 | cur = self._cache.get(word) 142 | if cur is None: 143 | cur = self._lemmatizer.lemmatize(word) 144 | self._cache[word] = cur 145 | return cur 146 | 147 | def get_features(self, question, context): 148 | stop = set() if self.stop_words is None else self.stop_words.words 149 | context_features = np.zeros((len(context), 3)) 150 | 151 | if not self.require_unique_match: 152 | question_words = set(x for x in question if x.lower() not in stop) 153 | question_words_lower = set(x.lower() for x in question) 154 | question_words_stem = set(self.lemmatize_word(x) for x in question_words_lower) 155 | else: 156 | question_words = set(k for k,v in Counter(question).items() if v == 1) 157 | question_words_lower = set(k for k,v in Counter(x.lower() for x in question_words).items() if v == 1) 158 | question_words_stem = set(k for k, v in Counter(self.lemmatize_word(x) for x 159 | in question_words_lower).items() if v == 1) 160 | 161 | for i, word in enumerate(context): 162 | if word in question_words: 163 | context_features[i][:3] = 1 164 | elif word.lower() in question_words_lower: 165 | context_features[i][:2] = 1 166 | elif self.lemmatize_word(word) in question_words_stem: 167 | context_features[i][2] = 1 168 | 169 | if self.empty_question_features: 170 | return np.zeros((len(question), 3)), context_features 171 | else: 172 | return np.zeros((len(question), 0)), context_features 173 | 174 | def __setstate__(self, state): 175 | self.__init__(**state) 176 | 177 | def __getstate__(self): 178 | state = dict(self.__dict__) 179 | del state["_cache"] 180 | del state["_lemmatizer"] 181 | return state 182 | 183 | -------------------------------------------------------------------------------- /docqa/data_processing/word_vectors.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import pickle 3 | from os.path import join, exists 4 | from typing import Iterable, Optional 5 | 6 | import numpy as np 7 | 8 | from docqa.config import VEC_DIR 9 | 10 | 11 | """ Loading word vectors """ 12 | 13 | 14 | def load_word_vectors(vec_name: str, vocab: Optional[Iterable[str]]=None, is_path=False): 15 | if not is_path: 16 | vec_path = join(VEC_DIR, vec_name) 17 | else: 18 | vec_path = vec_name 19 | if exists(vec_path + ".txt"): 20 | vec_path = vec_path + ".txt" 21 | elif exists(vec_path + ".txt.gz"): 22 | vec_path = vec_path + ".txt.gz" 23 | elif exists(vec_path + ".pkl"): 24 | vec_path = vec_path + ".pkl" 25 | else: 26 | raise ValueError("No file found for vectors %s" % vec_name) 27 | return load_word_vector_file(vec_path, vocab) 28 | 29 | 30 | def load_word_vector_file(vec_path: str, vocab: Optional[Iterable[str]] = None): 31 | if vocab is not None: 32 | vocab = set(x.lower() for x in vocab) 33 | 34 | # note: some of the large vec files produce utf-8 errors for some words; just skip them 35 | if vec_path.endswith(".pkl"): 36 | with open(vec_path, "rb") as f: 37 | return pickle.load(f) 38 | elif vec_path.endswith(".txt.gz"): 39 | handle = lambda x: gzip.open(x, 'rt', encoding='utf-8', errors='ignore')  # 'rt': gzip files must be opened in text mode to use an encoding 40 | else: 41 | handle = lambda x: open(x, 'r', encoding='utf-8', errors='ignore') 42 | 43 | pruned_dict = {} 44 | with handle(vec_path) as fh: 45 | for line in fh: 46 | word_ix = line.find(" ") 47 | word = line[:word_ix] 48 | if (vocab is None) or
(word.lower() in vocab): 49 | pruned_dict[word] = np.array([float(x) for x in line[word_ix + 1:-1].split(" ")], dtype=np.float32) 50 | return pruned_dict 51 | -------------------------------------------------------------------------------- /docqa/elmo/README.md: -------------------------------------------------------------------------------- 1 | ## ELMo 2 | This contains the (pretty rough) code for running our SQuAD model with ELMo weights. 3 | 4 | To train or test the model you need the pre-trained ELMo model. It can be downloaded 5 | [here](https://docs.google.com/uc?export=download&id=1vXsiRHxJqsj3HLesUIet0x4Yrjw0S54D). 6 | Then unzip it and store it in ~/data/lm (or change config.py to alter its expected location). For example: 7 | 8 | ``` 9 | mkdir -p ~/data/lm 10 | cd ~/data/lm 11 | mv ~/Downloads/squad-context-concat-skip.tar.gz . 12 | tar -xzf squad-context-concat-skip.tar.gz 13 | rm squad-context-concat-skip.tar.gz 14 | ``` 15 | 16 | ### Training 17 | Now the model can be trained using: 18 | 19 | `python docqa/elmo/ablate_elmo_model.py` 20 | 21 | ### Testing 22 | The model can be tested on the dev set using our standard evaluation script: 23 | 24 | `python docqa/eval/squad_eval.py -o output.json -c dev /path/to/model/directory` 25 | 26 | Note that by default the language model will use word vectors 27 | that were pre-computed for the SQuAD corpus. Running it on 28 | other kinds of data takes a bit more work; 29 | see "docqa/elmo/run_on_user_text.py" for an example with comments. 30 | Using the script we can run the model on user-defined input, for example: 31 | 32 | `python docqa/elmo/run_on_user_text.py /path/to/model/directory "What color are apples" "Apples are blue"` 33 | 34 | 35 | ### Pre-Trained Model 36 | The pre-trained model we used for SQuAD can be downloaded [here](https://drive.google.com/open?id=1GuKh2TJFF6FIhiFpoFslJ1WPlGAxAISt) 37 | 38 | 39 | ### Codalab 40 | The CodaLab worksheet we use to get our SQuAD test scores is 41 | [here](https://worksheets.codalab.org/worksheets/0xc7fd7c36337146838b9b064a327e59fd/).
42 | -------------------------------------------------------------------------------- /docqa/elmo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/elmo/__init__.py -------------------------------------------------------------------------------- /docqa/elmo/ablate_elmo_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | 4 | from tensorflow.contrib.keras.python.keras.initializers import TruncatedNormal 5 | 6 | from docqa import trainer 7 | from docqa.data_processing.qa_training_data import ContextLenKey 8 | from docqa.dataset import ClusteredBatcher 9 | from docqa.encoder import DocumentAndQuestionEncoder, SingleSpanAnswerEncoder 10 | from docqa.evaluator import LossEvaluator, SpanEvaluator 11 | from docqa.elmo.elmo import ElmoLayer 12 | from docqa.elmo.lm_qa_models import AttentionWithElmo, SquadContextConcatSkip 13 | from docqa.model_dir import ModelDir 14 | from docqa.nn.attention import BiAttention, StaticAttentionSelf 15 | from docqa.nn.embedder import FixedWordEmbedder, CharWordEmbedder, LearnedCharEmbedder 16 | from docqa.nn.layers import FullyConnected, ChainBiMapper, NullBiMapper, MaxPool, Conv1d, SequenceMapperSeq, \ 17 | VariationalDropoutLayer, ResidualLayer, ConcatWithProduct, MapperSeq, DropoutLayer 18 | from docqa.nn.recurrent_layers import CudnnGru 19 | from docqa.nn.similarity_layers import TriLinear 20 | from docqa.nn.span_prediction import BoundsPredictor 21 | from docqa.squad.squad_data import SquadCorpus, DocumentQaTrainingData 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser("Train our ELMo model on SQuAD") 26 | parser.add_argument("output_dir") 27 | parser.add_argument("--dim", type=int, default=90) 28 | parser.add_argument("--l2", type=float, default=0) 29 | parser.add_argument("--mode", choices=["input", "output", "both", "none"], default="both") 30 | parser.add_argument("--top_layer_only", action="store_true") 31 | args = parser.parse_args() 32 | 33 | out = args.output_dir + "-" + datetime.now().strftime("%m%d-%H%M%S") 34 | 35 | dim = args.dim 36 | recurrent_layer = CudnnGru(dim, w_init=TruncatedNormal(stddev=0.05)) 37 | 38 | params = trainer.TrainParams(trainer.SerializableOptimizer("Adadelta", dict(learning_rate=1.0)), 39 | ema=0.999, max_checkpoints_to_keep=2, async_encoding=10, 40 | num_epochs=24, log_period=30, eval_period=1200, save_period=1200, 41 | best_weights=("dev", "b17/text-f1"), 42 | eval_samples=dict(dev=None, train=8000)) 43 | 44 | lm_reduce = MapperSeq( 45 | ElmoLayer(args.l2, layer_norm=False, top_layer_only=args.top_layer_only), 46 | DropoutLayer(0.5), 47 | ) 48 | 49 | model = AttentionWithElmo( 50 | encoder=DocumentAndQuestionEncoder(SingleSpanAnswerEncoder()), 51 | lm_model=SquadContextConcatSkip(), 52 | append_before_atten=(args.mode == "both" or args.mode == "output"), 53 | append_embed=(args.mode == "both" or args.mode == "input"), 54 | max_batch_size=128, 55 | word_embed=FixedWordEmbedder(vec_name="glove.840B.300d", word_vec_init_scale=0, learn_unk=False, cpu=True), 56 | char_embed=CharWordEmbedder( 57 | LearnedCharEmbedder(word_size_th=14, char_th=49, char_dim=20, init_scale=0.05, force_cpu=True), 58 | MaxPool(Conv1d(100, 5, 0.8)), 59 | shared_parameters=True 60 | ), 61 | embed_mapper=SequenceMapperSeq( 62 | VariationalDropoutLayer(0.8), 63 | recurrent_layer, 64 | 
VariationalDropoutLayer(0.8), 65 | ), 66 | lm_reduce=None, 67 | lm_reduce_shared=lm_reduce, 68 | per_sentence=False, 69 | memory_builder=NullBiMapper(), 70 | attention=BiAttention(TriLinear(bias=True), True), 71 | match_encoder=SequenceMapperSeq(FullyConnected(dim * 2, activation="relu"), 72 | ResidualLayer(SequenceMapperSeq( 73 | VariationalDropoutLayer(0.8), 74 | recurrent_layer, 75 | VariationalDropoutLayer(0.8), 76 | StaticAttentionSelf(TriLinear(bias=True), ConcatWithProduct()), 77 | FullyConnected(dim * 2, activation="relu"), 78 | )), 79 | VariationalDropoutLayer(0.8)), 80 | predictor=BoundsPredictor(ChainBiMapper( 81 | first_layer=recurrent_layer, 82 | second_layer=recurrent_layer 83 | )) 84 | ) 85 | 86 | batcher = ClusteredBatcher(45, ContextLenKey(), False, False) 87 | data = DocumentQaTrainingData(SquadCorpus(), None, batcher, batcher) 88 | 89 | with open(__file__, "r") as f: 90 | notes = f.read() 91 | notes = str(sorted(args.__dict__.items(), key=lambda x:x[0])) + "\n" + notes 92 | 93 | trainer.start_training(data, model, params, 94 | [LossEvaluator(), SpanEvaluator(bound=[17], text_eval="squad")], 95 | ModelDir(out), notes) 96 | 97 | if __name__ == "__main__": 98 | main() -------------------------------------------------------------------------------- /docqa/elmo/elmo.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | from docqa.nn.layers import Mapper 5 | 6 | 7 | class ElmoLayer(Mapper): 8 | def __init__(self, l2_coef: float, layer_norm: bool, top_layer_only: bool): 9 | self.l2_coef = l2_coef 10 | self.layer_norm = layer_norm 11 | self.top_layer_only = top_layer_only 12 | 13 | def apply(self, is_train, x, mask=None): 14 | mask = tf.sequence_mask(mask, tf.shape(x)[1]) 15 | output = weight_layers(1, x, mask, self.l2_coef, do_layer_norm=self.layer_norm, 16 | use_top_only=self.top_layer_only)["weighted_ops"][0] 17 | return output 18 | 19 | def __setstate__(self, state): 20 | if "softmax" not in state: 21 | state["softmax"] = True 22 | if "layer_norm" not in state: 23 | state["layer_norm"] = True 24 | if "top_layer_only" not in state: 25 | state["top_layer_only"] = False 26 | super().__setstate__(state) 27 | 28 | 29 | def weight_layers(n_out_layers, lm_embeddings, mask, l2_coef=None, 30 | use_top_only=False, do_layer_norm=True): 31 | ''' 32 | Weight the layers of a biLM with trainable scalar weights. 33 | For each output layer, this returns two ops. The first computes 34 | a layer specific weighted average of the biLM layers, and 35 | the second the l2 regularizer loss term. 36 | The regularization terms are also added to tf.GraphKeys.REGULARIZATION_LOSSES 37 | Input: 38 | n_out_layers: the number of weighted output layers 39 | lm_embeddings: the biLM layer activations, shape (batch, n_layers, n_tokens, dim); mask: a mask of valid time steps 40 | l2_coef: the l2 regularization coefficient 41 | Output: 42 | { 43 | 'weighted_ops': [ 44 | op to compute weighted average for output layer1, 45 | op to compute weighted average for output layer2, 46 | ... 47 | ], 48 | 'regularization_ops': [ 49 | op to compute regularization term for output layer1, 50 | op to compute regularization term for output layer2, 51 | ...
52 | ] 53 | } 54 | ''' 55 | 56 | def _l2_regularizer(weights): 57 | return l2_coef * tf.reduce_sum(tf.square(weights)) 58 | 59 | # Get ops for computing LM embeddings and mask 60 | n_lm_layers = int(lm_embeddings.get_shape()[1]) 61 | lm_dim = int(lm_embeddings.get_shape()[3]) 62 | 63 | if not tf.get_variable_scope().reuse: 64 | prefix = "monitor/" 65 | if "weight_embed" in tf.get_variable_scope().name: 66 | prefix += "input/" 67 | else: 68 | prefix += "output/" 69 | else: 70 | prefix = None 71 | 72 | with tf.control_dependencies([lm_embeddings, mask]): 73 | # Cast the mask and broadcast for layer use. 74 | mask_float = tf.cast(mask, 'float32') 75 | broadcast_mask = tf.expand_dims(mask_float, axis=-1) 76 | 77 | def _do_ln(x): 78 | # do layer normalization excluding the mask 79 | x_masked = x * broadcast_mask 80 | N = tf.reduce_sum(mask_float) * lm_dim 81 | mean = tf.reduce_sum(x_masked) / N 82 | variance = tf.reduce_sum(((x_masked - mean) * broadcast_mask) ** 2 83 | ) / N 84 | return tf.nn.batch_normalization( 85 | x, mean, variance, None, None, 1E-12 86 | ) 87 | 88 | ret = {'weighted_ops': [], 'regularization_ops': []} 89 | for k in range(n_out_layers): 90 | if use_top_only: 91 | layers = tf.split(lm_embeddings, n_lm_layers, axis=1) 92 | # just the top layer 93 | sum_pieces = tf.squeeze(layers[-1], squeeze_dims=1) 94 | # no regularization 95 | reg = [0.0] 96 | else: 97 | W = tf.get_variable( 98 | 'ELMo_W_{}'.format(k), 99 | shape=(n_lm_layers,), 100 | initializer=tf.zeros_initializer, 101 | regularizer=_l2_regularizer, 102 | trainable=True, 103 | ) 104 | 105 | if prefix is not None: 106 | for i in range(3): 107 | print("Monitoring " + prefix + "%d/" % i) 108 | tf.add_to_collection(prefix + "%d/" % i, W[i]) 109 | 110 | # normalize the weights 111 | normed_weights = tf.split( 112 | tf.nn.softmax(W + 1.0 / n_lm_layers), n_lm_layers 113 | ) 114 | # split LM layers 115 | layers = tf.split(lm_embeddings, n_lm_layers, axis=1) 116 | 117 | # compute the weighted, normalized LM activations 118 | pieces = [] 119 | for w, t in zip(normed_weights, layers): 120 | if do_layer_norm: 121 | pieces.append(w * _do_ln(tf.squeeze(t, squeeze_dims=1))) 122 | else: 123 | pieces.append(w * tf.squeeze(t, squeeze_dims=1)) 124 | sum_pieces = tf.add_n(pieces) 125 | 126 | # get the regularizer 127 | reg = [ 128 | r for r in tf.get_collection( 129 | tf.GraphKeys.REGULARIZATION_LOSSES) 130 | if r.name.find('ELMo_W_{}/'.format(k)) >= 0 131 | ] 132 | 133 | # scale the weighted sum by gamma 134 | gamma = tf.get_variable( 135 | 'ELMo_gamma_{}'.format(k), 136 | shape=(1,), 137 | initializer=tf.ones_initializer, 138 | regularizer=None, 139 | trainable=True, 140 | ) 141 | 142 | if prefix is not None: 143 | tf.add_to_collection(prefix + "gamma", gamma[0]) 144 | 145 | weighted_lm_layers = sum_pieces * gamma 146 | 147 | ret['weighted_ops'].append(weighted_lm_layers) 148 | ret['regularization_ops'].append(reg[0]) 149 | 150 | return ret -------------------------------------------------------------------------------- /docqa/elmo/eval_elmo_minimal.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from os.path import join 4 | 5 | import nltk 6 | import tensorflow as tf 7 | 8 | from docqa.data_processing.qa_training_data import ParagraphAndQuestionDataset, ContextLenKey 9 | from docqa.data_processing.text_utils import NltkAndPunctTokenizer 10 | from docqa.data_processing.word_vectors import load_word_vector_file 11 | from docqa.dataset import ClusteredBatcher 12 | 
from docqa.model_dir import ModelDir
13 | from docqa.squad.build_squad_dataset import parse_squad_data
14 | from docqa.squad.squad_data import split_docs
15 | from docqa.utils import ResourceLoader
16 |
17 | """
18 | Used to submit our official SQuAD scores via codalab
19 | """
20 |
21 |
22 | def run():
23 |     parser = argparse.ArgumentParser()
24 |     parser.add_argument("input_data")
25 |     parser.add_argument("output_data")
26 |     parser.add_argument("--n", type=int, default=None)
27 |     parser.add_argument("-b", "--batch_size", type=int, default=100)
28 |     parser.add_argument("--ema", action="store_true")
29 |     args = parser.parse_args()
30 |
31 |     input_data = args.input_data
32 |     output_path = args.output_data
33 |     model_dir = ModelDir("model")
34 |     nltk.data.path.append("nltk_data")
35 |
36 |     print("Loading data")
37 |     docs = parse_squad_data(input_data, "", NltkAndPunctTokenizer(), False)
38 |     pairs = split_docs(docs)
39 |     dataset = ParagraphAndQuestionDataset(pairs, ClusteredBatcher(args.batch_size, ContextLenKey(), False, True))
40 |
41 |     print("Done, init model")
42 |     model = model_dir.get_model()
43 |     # A small hack: just load the vector file at its expected location rather than using the config location
44 |     loader = ResourceLoader(lambda a, b: load_word_vector_file("glove.840B.300d.txt", b))
45 |     lm_model = model.lm_model
46 |     basedir = "lm"
47 |     lm_model.lm_vocab_file = join(basedir, "squad_train_dev_all_unique_tokens.txt")
48 |     lm_model.options_file = join(basedir, "options_squad_lm_2x4096_512_2048cnn_2xhighway_skip.json")
49 |     lm_model.weight_file = join(basedir, "squad_context_concat_lm_2x4096_512_2048cnn_2xhighway_skip.hdf5")
50 |     lm_model.embed_weights_file = None
51 |
52 |     model.set_inputs([dataset], loader)
53 |
54 |     print("Done, building graph")
55 |     sess = tf.Session()
56 |     with sess.as_default():
57 |         pred = model.get_prediction()
58 |         best_span = pred.get_best_span(17)[0]
59 |
60 |     all_vars = tf.global_variables() + tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS)
61 |     dont_restore_names = {x.name for x in all_vars if x.name.startswith("bilm")}
62 |     print(sorted(dont_restore_names))
63 |     vars = [x for x in all_vars if x.name not in dont_restore_names]
64 |
65 |     print("Done, loading weights")
66 |     checkpoint = model_dir.get_best_weights()
67 |     if checkpoint is None:
68 |         print("Loading most recent checkpoint")
69 |         checkpoint = model_dir.get_latest_checkpoint()
70 |     else:
71 |         print("Loading best weights")
72 |
73 |     saver = tf.train.Saver(vars)
74 |     saver.restore(sess, checkpoint)
75 |
76 |     if args.ema:
77 |         ema = tf.train.ExponentialMovingAverage(0)
78 |         saver = tf.train.Saver({ema.average_name(x): x for x in tf.trainable_variables()})
79 |         saver.restore(sess, checkpoint)
80 |
81 |     sess.run(tf.variables_initializer([x for x in all_vars if x.name in dont_restore_names]))
82 |
83 |     print("Done, starting evaluation")
84 |     out = {}
85 |     for i, batch in enumerate(dataset.get_epoch()):
86 |         if args.n is not None and i == args.n:
87 |             break
88 |         print("On batch: %d" % (i + 1))
89 |         enc = model.encode(batch, False)
90 |         spans = sess.run(best_span, feed_dict=enc)
91 |         for (s, e), point in zip(spans, batch):
92 |             out[point.question_id] = point.get_original_text(s, e)
93 |
94 |     sess.close()
95 |
96 |     print("Done, saving")
97 |     with open(output_path, "w") as f:
98 |         json.dump(out, f)
99 |
100 |     print("Mission accomplished!")
101 |
102 |
103 | if __name__ == "__main__":
104 |     run()
105 |
106 |
107 |
108 |
--------------------------------------------------------------------------------
/docqa/elmo/run_on_user_text.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import tensorflow as tf
4 |
5 | from docqa.data_processing.qa_training_data import ParagraphAndQuestion, ParagraphAndQuestionSpec
6 | from docqa.data_processing.text_utils import NltkAndPunctTokenizer
7 | from docqa.elmo.lm_qa_models import ElmoQaModel
8 | from docqa.model_dir import ModelDir
9 |
10 | """
11 | Script to run a model on user-provided question/context input.
12 | Its main purpose is to be an example of how to use the model on new question/context pairs.
13 | """
14 |
15 |
16 | def main():
17 |     parser = argparse.ArgumentParser(description="Run an ELMo model on user input")
18 |     parser.add_argument("model", help="Model directory")
19 |     parser.add_argument("question", help="Question to answer")
20 |     parser.add_argument("context", help="Context to answer the question with")
21 |     args = parser.parse_args()
22 |
23 |     # Tokenize the input; the model expects data to be tokenized using `NltkAndPunctTokenizer`
24 |     # Note the model expects case-sensitive input
25 |     tokenizer = NltkAndPunctTokenizer()
26 |     question = tokenizer.tokenize_paragraph_flat(args.question)
27 |     context = tokenizer.tokenize_paragraph_flat(args.context)
28 |
29 |     print("Loading model")
30 |     model_dir = ModelDir(args.model)
31 |     model = model_dir.get_model()
32 |     if not isinstance(model, ElmoQaModel):
33 |         raise ValueError("This script is built to work for ElmoQaModel models only")
34 |
35 |     # Important! This tells the language model not to use the pre-computed word vectors,
36 |     # which are only applicable for the SQuAD dev/train sets.
37 |     # Instead the language model will use its character-level CNN to compute
38 |     # the word vectors dynamically.
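    # (Computing the vectors with the character-level CNN makes the model
    # vocabulary-independent, presumably at the cost of some extra computation.)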
39 |     model.lm_model.embed_weights_file = None
40 |
41 |     # Tell the model the batch size and vocab to expect. This will load the needed
42 |     # word vectors and fix the batch size when building the graph / encoding the input
43 |     print("Setting up model")
44 |     voc = set(question)
45 |     voc.update(context)
46 |     model.set_input_spec(ParagraphAndQuestionSpec(batch_size=1), voc)
47 |
48 |     # Now we build the actual tensorflow graph; `best_span` and `conf` are
49 |     # tensors holding the predicted span (inclusive) and confidence scores for each
50 |     # element in the input batch
51 |     print("Build tf graph")
52 |     sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
53 |     with sess.as_default():
54 |         # 17 means to limit the span to size 17 or less
55 |         best_spans, conf = model.get_prediction().get_best_span(17)
56 |
57 |     # Now restore the weights; this is a bit fiddly since we need to avoid restoring the
58 |     # bilm weights, and instead load them from the pre-computed data
59 |     all_vars = tf.global_variables() + tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS)
60 |     lm_var_names = {x.name for x in all_vars if x.name.startswith("bilm")}
61 |     vars = [x for x in all_vars if x.name not in lm_var_names]
62 |     model_dir.restore_checkpoint(sess, vars)
63 |
64 |     # Run the initializer of the lm weights, which will load them from the lm directory
65 |     sess.run(tf.variables_initializer([x for x in all_vars if x.name in lm_var_names]))
66 |
67 |     # Now the model is ready to run
68 |     # The model takes input in the form of `ContextAndQuestion` objects, for example:
69 |     data = [ParagraphAndQuestion(context, question, None, "user-question1")]
70 |
71 |     print("Starting run")
72 |     # The model is run in two steps: first it "encodes" the paragraph/context pairs
73 |     # into numpy arrays, then we use `sess` to run the actual model and get the predictions
74 |     encoded = model.encode(data, is_train=False)  # batch of `ContextAndQuestion` -> feed_dict
75 |     best_spans, conf = sess.run([best_spans, conf], feed_dict=encoded)  # feed_dict -> predictions
76 |     print("Best span: " + str(best_spans[0]))
77 |     print("Answer text: " + " ".join(context[best_spans[0][0]:best_spans[0][1]+1]))
78 |     print("Confidence: " + str(conf[0]))
79 |
80 |
81 | if __name__ == "__main__":
82 |     main()
--------------------------------------------------------------------------------
/docqa/elmo/show_weights.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import tensorflow as tf
3 | from docqa.model_dir import ModelDir
4 | import numpy as np
5 |
6 |
7 | def softmax(x):
8 |     x = np.exp(x)
9 |     return x / x.sum()
10 |
11 |
12 | def main():
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument("model")
15 |     args = parser.parse_args()
16 |
17 |     model_dir = ModelDir(args.model)
18 |     checkpoint = model_dir.get_best_weights()
19 |     reader = tf.train.NewCheckpointReader(checkpoint)
20 |
21 |     if reader.has_tensor("weight_embed_context_lm/layer_0/w"):  # checkpoints store the layer weights under one of two variable names
22 |         x = "w"
23 |     else:
24 |         x = "ELMo_W_0"
25 |
26 |     for i in reader.get_variable_to_shape_map().items():
27 |         print(i)
28 |
29 |     input_w = reader.get_tensor("weight_embed_lm/layer_0/%s/ExponentialMovingAverage" % x)
30 |     output_w = reader.get_tensor("weight_lm/layer_0/%s/ExponentialMovingAverage" % x)
31 |
32 |     print("Input")
33 |     print(input_w)
34 |     print("(Softmax): " + str(softmax(input_w)))
35 |
36 |     print("Output")
37 |     print(output_w)
38 |     print("(Softmax): " + str(softmax(output_w)))
39 |
40 | if __name__ == "__main__":
41 |     main()
42 |
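# Example usage (the model directory below is hypothetical):
#   python docqa/elmo/show_weights.py models/squad-elmo-0101-120000
# This prints every variable stored in the checkpoint, followed by the raw and
# softmax-normalized ELMo layer-mixing weights for the input and output sides.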
43 |
--------------------------------------------------------------------------------
/docqa/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/eval/__init__.py
--------------------------------------------------------------------------------
/docqa/eval/eval_squad_minimal.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import tensorflow as tf
4 | import nltk
5 |
6 | from docqa.data_processing.qa_training_data import ParagraphAndQuestionDataset, ContextLenKey
7 | from docqa.data_processing.text_utils import NltkAndPunctTokenizer
8 | from docqa.data_processing.word_vectors import load_word_vectors, load_word_vector_file
9 | from docqa.dataset import ClusteredBatcher
10 | from docqa.squad.build_squad_dataset import parse_squad_data
11 | from docqa.squad.squad_data import split_docs
12 | from docqa.model_dir import ModelDir
13 | from docqa.utils import ResourceLoader
14 |
15 |
16 | """
17 | Used to submit our official SQuAD scores via codalab
18 | """
19 |
20 |
21 | def run():
22 |     input_data = sys.argv[1]
23 |     output_path = sys.argv[2]
24 |     model_dir = ModelDir("model")
25 |     nltk.data.path.append("nltk_data")
26 |
27 |     print("Loading data")
28 |     docs = parse_squad_data(input_data, "", NltkAndPunctTokenizer(), False)
29 |     pairs = split_docs(docs)
30 |     dataset = ParagraphAndQuestionDataset(pairs, ClusteredBatcher(100, ContextLenKey(), False, True))
31 |
32 |     print("Done, init model")
33 |     model = model_dir.get_model()
34 |     # A small hack: just load the vector file at its expected location rather than using the config location
35 |     loader = ResourceLoader(lambda a, b: load_word_vector_file("glove.840B.300d.txt", b))
36 |     model.set_inputs([dataset], loader)
37 |
38 |     print("Done, building graph")
39 |     sess = tf.Session()
40 |     with sess.as_default():
41 |         pred = model.get_prediction()
42 |         best_span = pred.get_best_span(17)[0]
43 |
44 |     print("Done, loading weights")
45 |     checkpoint = model_dir.get_latest_checkpoint()
46 |     saver = tf.train.Saver()
47 |     saver.restore(sess, checkpoint)
48 |     ema = tf.train.ExponentialMovingAverage(0)
49 |     saver = tf.train.Saver({ema.average_name(x): x for x in tf.trainable_variables()})
50 |     saver.restore(sess, checkpoint)
51 |
52 |     print("Done, starting evaluation")
53 |     out = {}
54 |     for batch in dataset.get_epoch():
55 |         enc = model.encode(batch, False)
56 |         spans = sess.run(best_span, feed_dict=enc)
57 |         for (s, e), point in zip(spans, batch):
58 |             out[point.question_id] = point.get_original_text(s, e)
59 |
60 |     sess.close()
61 |
62 |     print("Done, saving")
63 |     with open(output_path, "w") as f:
64 |         json.dump(out, f)
65 |
66 |     print("Mission accomplished!")
67 |
68 |
69 | if __name__ == "__main__":
70 |     run()
71 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
/docqa/eval/ranked_scores.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import OrderedDict
3 |
4 | import numpy as np
5 | import pandas as pd
6 |
7 | from docqa.utils import print_table
8 |
9 |
10 | def compute_ranked_scores(df, max_over, target_score, group_cols):
11 |     scores = []
12 |     for _, group in df[[max_over, target_score] + group_cols].groupby(group_cols):
13 |         if target_score == max_over:
14 |             scores.append(group[target_score].cummax().values)
15 |         else:
16 |             used_predictions = group[max_over].expanding().apply(lambda x: x.argmax())
17 |             scores.append(group[target_score].iloc[used_predictions].values)
18 |
19 |     max_para = max(len(x) for x in scores)
20 |     summed_scores = np.zeros(max_para)
21 |     for s in scores:
22 |         summed_scores[:len(s)] += s
23 |         summed_scores[len(s):] += s[-1]
24 |     return summed_scores / len(scores)
25 |
26 |
27 | def show_scores_table(df, cols):
28 |     rows = [["Rank"] + cols]
29 |     for i in range(len(df)):
30 |         rows.append(["%d" % (i+1)] + ["%.4f" % df[k].iloc[i] for k in cols])
31 |     print_table(rows)
32 |
33 |
34 | def main():
35 |     parser = argparse.ArgumentParser(description=
36 |         "Compute scores as more paragraphs are used, using "
37 |         "a per-paragraph csv file as built from our evaluation scripts")
38 |     parser.add_argument('answers', help='answer file(s)', nargs="+")
39 |     parser.add_argument('--per_doc', action="store_true",
40 |                         help="Show scores treating each (question, document) pair as a "
41 |                              "datapoint, instead of each question. Should be used for the TriviaQA Web"
42 |                              " dataset")
43 |     args = parser.parse_args()
44 |
45 |     print("Loading answers...")
46 |     answer_dfs = []
47 |     for filename in args.answers:
48 |         answer_dfs.append(pd.read_csv(filename))
49 |
50 |     print("Computing ranks...")
51 |     if args.per_doc:
52 |         group_by = ["question_id", "doc_id"]
53 |     else:
54 |         group_by = ["question_id"]
55 |
56 |     data = OrderedDict()
57 |     for i, answer_df in enumerate(answer_dfs):
58 |         answer_df.sort_values(["question_id", "rank"], inplace=True)
59 |         model_scores = compute_ranked_scores(answer_df, "predicted_score", "text_em", group_by)
60 |         data["answers_%d_em" % i] = model_scores
61 |         model_scores = compute_ranked_scores(answer_df, "predicted_score", "text_f1", group_by)
62 |         data["answers_%d_f1" % i] = model_scores
63 |
64 |     show_scores_table(pd.DataFrame(data),
65 |                       sorted(data.keys(), key=lambda x: (0, x) if x.endswith("em") else (1, x)))
66 |
67 |
68 | if __name__ == "__main__":
69 |     main()
70 |
71 |
72 |
--------------------------------------------------------------------------------
/docqa/eval/squad_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from typing import List
4 |
5 | import numpy as np
6 |
7 | from docqa import trainer
8 | from docqa.data_processing.qa_training_data import ParagraphAndQuestionDataset, ContextAndQuestion
9 | from docqa.dataset import FixedOrderBatcher
10 | from docqa.evaluator import Evaluator, Evaluation, SpanEvaluator
11 | from docqa.model_dir import ModelDir
12 | from docqa.squad.squad_data import SquadCorpus, split_docs
13 | from docqa.utils import transpose_lists, print_table
14 |
15 | """
16 | Run an evaluation on squad and record the official output
17 | """
18 |
19 |
20 | class RecordSpanPrediction(Evaluator):
21 |     def __init__(self, bound: int):
22 |         self.bound = bound
23 |
24 |     def tensors_needed(self, prediction):
25 |         span, score = prediction.get_best_span(self.bound)
26 |         return dict(spans=span, model_scores=score)
27 |
28 |     def evaluate(self, data: List[ContextAndQuestion], true_len, **kargs):
29 |         spans, model_scores = kargs["spans"], kargs["model_scores"]
30 |         results = {"model_conf": model_scores,
31 |                    "predicted_span": spans,
32 |                    "question_id": [x.question_id for x in data]}
33 |         return Evaluation({}, results)
34 |
35 |
36 | def main():
37 |     parser = argparse.ArgumentParser(description='Evaluate a model on SQuAD')
38 |     parser.add_argument('model', help='model directory to evaluate')
parser.add_argument("-o", "--official_output", type=str, help="where to output an official result file") 40 | parser.add_argument('-n', '--sample_questions', type=int, default=None, 41 | help="(for testing) run on a subset of questions") 42 | parser.add_argument('--answer_bounds', nargs='+', type=int, default=[17], 43 | help="Max size of answer") 44 | parser.add_argument('-b', '--batch_size', type=int, default=200, 45 | help="Batch size, larger sizes can be faster but uses more memory") 46 | parser.add_argument('-s', '--step', default=None, 47 | help="Weights to load, can be a checkpoint step or 'latest'") 48 | parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev") 49 | parser.add_argument('--no_ema', action="store_true", help="Don't use EMA weights even if they exist") 50 | args = parser.parse_args() 51 | 52 | model_dir = ModelDir(args.model) 53 | 54 | corpus = SquadCorpus() 55 | if args.corpus == "dev": 56 | questions = corpus.get_dev() 57 | else: 58 | questions = corpus.get_train() 59 | questions = split_docs(questions) 60 | 61 | if args.sample_questions: 62 | np.random.RandomState(0).shuffle(sorted(questions, key=lambda x: x.question_id)) 63 | questions = questions[:args.sample_questions] 64 | 65 | questions.sort(key=lambda x:x.n_context_words, reverse=True) 66 | dataset = ParagraphAndQuestionDataset(questions, FixedOrderBatcher(args.batch_size, True)) 67 | 68 | evaluators = [SpanEvaluator(args.answer_bounds, text_eval="squad")] 69 | if args.official_output is not None: 70 | evaluators.append(RecordSpanPrediction(args.answer_bounds[0])) 71 | 72 | if args.step is not None: 73 | if args.step == "latest": 74 | checkpoint = model_dir.get_latest_checkpoint() 75 | else: 76 | checkpoint = model_dir.get_checkpoint(int(args.step)) 77 | else: 78 | checkpoint = model_dir.get_best_weights() 79 | if checkpoint is not None: 80 | print("Using best weights") 81 | else: 82 | print("Using latest checkpoint") 83 | checkpoint = model_dir.get_latest_checkpoint() 84 | 85 | model = model_dir.get_model() 86 | 87 | evaluation = trainer.test(model, evaluators, {args.corpus: dataset}, 88 | corpus.get_resource_loader(), checkpoint, not args.no_ema)[args.corpus] 89 | 90 | # Print the scalar results in a two column table 91 | scalars = evaluation.scalars 92 | cols = list(sorted(scalars.keys())) 93 | table = [cols] 94 | header = ["Metric", ""] 95 | table.append([("%s" % scalars[x] if x in scalars else "-") for x in cols]) 96 | print_table([header] + transpose_lists(table)) 97 | 98 | # Save the official output 99 | if args.official_output is not None: 100 | quid_to_para = {} 101 | for x in questions: 102 | quid_to_para[x.question_id] = x.paragraph 103 | 104 | q_id_to_answers = {} 105 | q_ids = evaluation.per_sample["question_id"] 106 | spans = evaluation.per_sample["predicted_span"] 107 | for q_id, (start, end) in zip(q_ids, spans): 108 | text = quid_to_para[q_id].get_original_text(start, end) 109 | q_id_to_answers[q_id] = text 110 | 111 | with open(args.official_output, "w") as f: 112 | json.dump(q_id_to_answers, f) 113 | 114 | if __name__ == "__main__": 115 | main() 116 | # tmp() 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /docqa/model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from docqa.dataset import Dataset 4 | from tensorflow import Tensor 5 | from docqa.utils import ResourceLoader 6 | 7 | from docqa.configurable import Configurable 8 | 9 
9 |
10 | class Prediction(object):
11 |     """ Prediction from a model; subclasses should provide access to a tensor
12 |     representation of the model's output """
13 |     pass
14 |
15 |
16 | class Model(Configurable):
17 |     """
18 |     Our most general specification of a model/neural network; for our purposes a model
19 |     is basically a pair of functions:
20 |     1) a way to map an (unspecified) kind of python object to numpy tensors and
21 |     2) a tensorflow function that maps those kinds of tensors to a set of (also unspecified) output tensors
22 |
23 |     For convenience, models maintain a set of input placeholders that clients can make use of to
24 |     feed the tensorflow function (or reference to construct their own tensor inputs).
25 |
26 |     Models have two stages of initialization. First, a model needs
27 |     to be initialized with the training data using `init` (typically this does things like deciding what
28 |     words/chars to train embeddings for). This should only be done once for this object's lifetime.
29 |
30 |     Afterwards, use `set_inputs` to specify the input format; this does things like determine the batch size
31 |     or the vocabulary that will be used.
32 |
33 |     After initialization, `encode` will produce a map of placeholder -> numpy array,
34 |     which can be used directly as a feed dict for the output of `get_prediction`
35 |
36 |     For more advanced usage, `get_predictions_for` can be used with any tensors of the
37 |     same shape/dtype as the input placeholders. Clients should pass in a dict mapping
38 |     the placeholders to the input tensors they want to use instead.
39 |
40 |     The `get_predictions_for` method behaves like any other tensorflow function, in that it will
41 |     load/initialize/reuse variables depending on the current tensorflow scope and can add
42 |     to tf.collections. Our trainer method makes use of some of these collections, including:
43 |         tf.GraphKeys.LOSSES
44 |         tf.GraphKeys.REGULARIZATION_LOSSES
45 |         tf.GraphKeys.SUMMARIES
46 |         tf.GraphKeys.SAVEABLE_OBJECTS
47 |         tf.GraphKeys.TRAINABLE_VARIABLES
48 |         "monitor/*" collections, which will be summed, and the EMA result logged to tensorboard
49 |     """
50 |
51 |     @property
52 |     def name(self):
53 |         return self.__class__.__name__
54 |
55 |     def init(self, train_data, resource_loader: ResourceLoader):
56 |         raise NotImplementedError()
57 |
58 |     def set_inputs(self, datasets: List[Dataset], resource_loader: ResourceLoader) -> List[Tensor]:
59 |         raise NotImplementedError()
60 |
61 |     def get_prediction(self) -> Prediction:
62 |         return self.get_predictions_for({x: x for x in self.get_placeholders()})
63 |
64 |     def get_placeholders(self) -> List[Tensor]:
65 |         raise NotImplementedError()
66 |
67 |     def get_predictions_for(self, input_tensors: Dict[Tensor, Tensor]) -> Prediction:
68 |         raise NotImplementedError()
69 |
70 |     def encode(self, examples, is_train: bool) -> Dict[Tensor, object]:
71 |         raise NotImplementedError()
72 |
--------------------------------------------------------------------------------
/docqa/model_dir.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from genericpath import exists
4 | from os.path import isabs, join
5 |
6 | import tensorflow as tf
7 |
8 | from docqa.model import Model
9 |
10 |
11 | class ModelDir(object):
12 |     """ Wrapper for accessing a folder we are storing a model in """
13 |
14 |     def __init__(self, name: str):
15 |         if isabs(name):
16 |             print("WARNING: using an absolute path as the model name can break restoring "
17 |                   "the model in different directories after it has been checkpointed")
18 |             # why is this even a thing?
19 |         self.dir = name
20 |
21 |     def get_model(self) -> Model:
22 |         with open(join(self.dir, "model.pkl"), "rb") as f:
23 |             return pickle.load(f)
24 |
25 |     def get_eval_dir(self):
26 |         answer_dir = join(self.dir, "answers")
27 |         if not exists(answer_dir):
28 |             os.mkdir(answer_dir)
29 |         return answer_dir
30 |
31 |     def get_last_train_params(self):
32 |         last_train_file = None
33 |         last_train_step = -1
34 |         for file in os.listdir(self.dir):
35 |             if file.startswith("train_from_") and file.endswith("pkl"):
36 |                 step = int(file[11:file.rfind(".pkl")])
37 |                 if step > last_train_step:
38 |                     last_train_step = step
39 |                     last_train_file = join(self.dir, file)
40 |
41 |         print("Resuming using the parameters stored in: " + last_train_file)
42 |         with open(last_train_file, "rb") as f:
43 |             return pickle.load(f)
44 |
45 |     def get_latest_checkpoint(self):
46 |         return tf.train.latest_checkpoint(self.save_dir)
47 |
48 |     def get_checkpoint(self, step):
49 |         # I can't find much formal documentation on how to do this, but this seems to work
50 |         return join(self.save_dir, "checkpoint-%d-%d" % (step, step))
51 |
52 |     def get_best_weights(self):
53 |         if exists(self.best_weight_dir):
54 |             return tf.train.latest_checkpoint(self.best_weight_dir)
55 |         return None
56 |
57 |     def restore_checkpoint(self, sess, var_list=None, load_ema=True):
58 |         """
59 |         Restores either the best weights or the most recent checkpoint, assuming the correct
60 |         variables have already been added to the tf default graph, e.g., `.get_prediction()`
61 |         has been called on the model stored in `self`.
62 |         Automatically detects if EMA weights exist, and if they do, loads them instead
63 |         """
64 |         checkpoint = self.get_best_weights()
65 |         if checkpoint is None:
66 |             print("Loading most recent checkpoint")
67 |             checkpoint = self.get_latest_checkpoint()
68 |         else:
69 |             print("Loading best weights")
70 |
71 |         if load_ema:
72 |             if var_list is None:
73 |                 # Same default used by `Saver`
74 |                 var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + \
75 |                            tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS)
76 |
77 |             # Automatically check if there are EMA variables, if so use those
78 |             reader = tf.train.NewCheckpointReader(checkpoint)
79 |             ema = tf.train.ExponentialMovingAverage(0)
80 |             ema_names = {ema.average_name(x): x for x in var_list
81 |                          if reader.has_tensor(ema.average_name(x))}
82 |             if len(ema_names) > 0:
83 |                 print("Found EMA weights, loading them")
84 |                 ema_vars = set(x for x in ema_names.values())
85 |                 var_list = {v.op.name: v for v in var_list if v not in ema_vars}
86 |                 var_list.update(ema_names)
87 |
88 |         saver = tf.train.Saver(var_list)
89 |         saver.restore(sess, checkpoint)
90 |
91 |
92 |     @property
93 |     def save_dir(self):
94 |         # Stores training checkpoints
95 |         return join(self.dir, "save")
96 |
97 |     @property
98 |     def best_weight_dir(self):
99 |         # Stores the best-weights checkpoint
100 |         return join(self.dir, "best-weights")
101 |
102 |     @property
103 |     def log_dir(self):
104 |         return join(self.dir, "log")
--------------------------------------------------------------------------------
/docqa/nn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/nn/__init__.py
--------------------------------------------------------------------------------
/docqa/nn/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | VERY_NEGATIVE_NUMBER = -1e29
4 |
5 |
6 | def dropout(x, keep_prob, is_train, noise_shape=None, seed=None):
7 |     if keep_prob >= 1.0:
8 |         return x
9 |     return tf.cond(is_train, lambda: tf.nn.dropout(x, keep_prob, noise_shape=noise_shape, seed=seed), lambda: x)
10 |
11 |
12 | def segment_logsumexp(xs, segments):
13 |     """ Similar to tf.segment_sum, but computes logsumexp rather than sum """
14 |     # Stop gradients following the implementation of tf.reduce_logsumexp
15 |     maxs = tf.stop_gradient(tf.reduce_max(xs, axis=1))
16 |     segment_maxes = tf.segment_max(maxs, segments)
17 |     xs -= tf.expand_dims(tf.gather(segment_maxes, segments), 1)
18 |     sums = tf.reduce_sum(tf.exp(xs), axis=1)
19 |     return tf.log(tf.segment_sum(sums, segments)) + segment_maxes
20 |
21 |
22 | def exp_mask(val, mask):
23 |     mask = tf.cast(tf.sequence_mask(mask, tf.shape(val)[1]), tf.float32)
24 |     return val * mask + (1 - mask) * VERY_NEGATIVE_NUMBER
25 |
26 |
--------------------------------------------------------------------------------
/docqa/nn/similarity_layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from docqa.configurable import Configurable
4 | from docqa.nn.layers import get_keras_initialization, get_keras_activation
5 |
6 |
7 | def compute_attention_mask(x_mask, mem_mask, x_word_dim, key_word_dim):
8 |     """ computes a (batch, x_word_dim, key_word_dim) bool mask for clients that want masking """
9 |     if x_mask is None and mem_mask is None:
10 |         return None
11 |     elif x_mask is None or mem_mask is None:
12 |         raise NotImplementedError()
13 |
14 |     x_mask = tf.sequence_mask(x_mask, x_word_dim)
15 |     mem_mask = tf.sequence_mask(mem_mask, key_word_dim)
16 |     join_mask = tf.logical_and(tf.expand_dims(x_mask, 2), tf.expand_dims(mem_mask, 1))
17 |     return join_mask
18 |
19 |
20 | class SimilarityFunction(Configurable):
21 |     """
22 |     Computes a pairwise score between elements in each sequence:
23 |     (batch, time1, dim1), (batch, time2, dim2) -> (batch, time1, time2)
24 |     """
25 |     def get_scores(self, tensor_1, tensor_2):
26 |         raise NotImplementedError
27 |
28 |     def get_one_sided_scores(self, tensor_1, tensor_2):
29 |         return tf.squeeze(self.get_scores(tf.expand_dims(tensor_1, 1), tensor_2), squeeze_dims=[1])
30 |
31 |
32 | class _WithBias(SimilarityFunction):
33 |     def __init__(self, bias: bool):
34 |         # Note: since we typically apply a softmax to the result, having a bias is usually redundant
35 |         self.bias = bias
36 |
37 |     def get_scores(self, tensor_1, tensor_2):
38 |         out = self._distance_logits(tensor_1, tensor_2)
39 |         if self.bias:
40 |             bias = tf.get_variable("bias", shape=(), dtype=tf.float32)
41 |             out += bias
42 |         return out
43 |
44 |     def _distance_logits(self, tensor_1, tensor_2):
45 |         raise NotImplementedError()
46 |
47 |
48 | class DotProduct(_WithBias):
49 |     """ Dot-product attention with scaling as seen in https://arxiv.org/pdf/1706.03762.pdf """
50 |
51 |     def __init__(self, bias: bool=False, scale: bool=False):
52 |         super().__init__(bias)
53 |         self.scale = scale
54 |
55 |     def _distance_logits(self, tensor_1, tensor_2):
56 |         dots = tf.matmul(tensor_1, tensor_2, transpose_b=True)
57 |         if self.scale:
58 |             last_dim = dots.shape.as_list()[-1]
59 |             if last_dim is None:
60 |                 last_dim = tf.shape(dots)[-1]
61 |             dots /= tf.sqrt(tf.cast(last_dim, tf.float32))  # the cast handles both the static (int) and dynamic (tensor) case
62 |         return dots
63 |
64 |
65 | class DotProductProject(_WithBias):
66 |     """ Dot-product attention while projecting the input layers """
67 |
68 |     def __init__(self, project_size, bias: bool=False, scale: bool=False,
69 |                  project_bias: bool=False, init="glorot_uniform", share_project=False):
70 |         super().__init__(bias)
71 |         self.project_bias = project_bias
72 |         self.init = init
73 |         self.scale = scale
74 |         self.project_size = project_size
75 |         self.share_project = share_project
76 |
77 |     def _distance_logits(self, x1, x2):
78 |         init = get_keras_initialization(self.init)
79 |
80 |         project1 = tf.get_variable("project1", (x1.shape.as_list()[-1], self.project_size), initializer=init)
81 |         x1 = tf.tensordot(x1, project1, [[2], [0]])
82 |
83 |         if self.share_project:
84 |             if x2.shape.as_list()[-1] != x1.shape.as_list()[-1]:
85 |                 raise ValueError()
86 |             project2 = project1
87 |         else:
88 |             project2 = tf.get_variable("project2", (x2.shape.as_list()[-1], self.project_size), initializer=init)
89 |         x2 = tf.tensordot(x2, project2, [[2], [0]])
90 |
91 |         if self.project_bias:
92 |             x1 += tf.get_variable("bias1", (1, 1, self.project_size), initializer=tf.zeros_initializer())
93 |             x2 += tf.get_variable("bias2", (1, 1, self.project_size), initializer=tf.zeros_initializer())
94 |
95 |         dots = tf.matmul(x1, x2, transpose_b=True)
96 |         if self.scale:
97 |             dots /= tf.sqrt(tf.cast(self.project_size, tf.float32))
98 |         return dots
99 |
100 |
101 | class BiLinearSum(_WithBias):
102 |
103 |     def __init__(self, bias: bool=False, init="glorot_uniform"):
104 |         self.init = init
105 |         super().__init__(bias)
106 |
107 |     def _distance_logits(self, x, keys):
108 |         init = get_keras_initialization(self.init)
109 |         key_w = tf.get_variable("key_w", shape=keys.shape.as_list()[-1], initializer=init, dtype=tf.float32)
110 |         key_logits = tf.tensordot(keys, key_w, axes=[[2], [0]])  # (batch, key_len)
111 |
112 |         x_w = tf.get_variable("x_w", shape=x.shape.as_list()[-1], initializer=init, dtype=tf.float32)
113 |         x_logits = tf.tensordot(x, x_w, axes=[[2], [0]])  # (batch, x_len)
114 |
115 |         # Broadcasting will expand the arrays to (batch, x_len, key_len)
116 |         return tf.expand_dims(x_logits, axis=2) + tf.expand_dims(key_logits, axis=1)
117 |
118 |
119 | class BiLinear(_WithBias):
120 |
121 |     def __init__(self, projected_size: int, activation="tanh", bias: bool=False,
122 |                  init="glorot_uniform", shared_projection=False):
123 |         self.init = init
124 |         self.activation = activation
125 |         self.shared_project = shared_projection
126 |         self.projected_size = projected_size
127 |         super().__init__(bias)
128 |
129 |     def _distance_logits(self, x, keys):
130 |         init = get_keras_initialization(self.init)
131 |         key_w = tf.get_variable("key_w", shape=(keys.shape.as_list()[-1], self.projected_size), initializer=init, dtype=tf.float32)
132 |         key_logits = tf.tensordot(keys, key_w, axes=[[2], [0]])  # (batch, key_len, projected_size)
133 |
134 |         if self.shared_project:
135 |             x_w = key_w
136 |         else:
137 |             x_w = tf.get_variable("x_w", shape=(x.shape.as_list()[-1], self.projected_size), initializer=init, dtype=tf.float32)
138 |
139 |         x_logits = tf.tensordot(x, x_w, axes=[[2], [0]])  # (batch, x_len, projected_size)
140 |
141 |         summed = tf.expand_dims(x_logits, axis=2) + tf.expand_dims(key_logits, axis=1)  # (batch, x_len, key_len, projected_size)
142 |
143 |         summed = get_keras_activation(self.activation)(summed)
144 |
145 |         combine_w = tf.get_variable("combine_w", shape=self.projected_size, initializer=init, dtype=tf.float32)
146 |
147 |         return tf.tensordot(summed, combine_w, axes=[[3], [0]])  # (batch, x_len, key_len)
148 |
149 |
150 | class TriLinear(_WithBias):
151 |     """ Function used by BiDAF; bi-linear with an extra component for the dots of the vectors """
152 |     def __init__(self, init="glorot_uniform", bias=False):
153 |         super().__init__(bias)
154 |         self.init = init
155 |
156 |     def _distance_logits(self, x, keys):
157 |         init = get_keras_initialization(self.init)
158 |
159 |         key_w = tf.get_variable("key_w", shape=keys.shape.as_list()[-1], initializer=init, dtype=tf.float32)
160 |         key_logits = tf.tensordot(keys, key_w, axes=[[2], [0]])  # (batch, key_len)
161 |
162 |         x_w = tf.get_variable("input_w", shape=x.shape.as_list()[-1], initializer=init, dtype=tf.float32)
163 |         x_logits = tf.tensordot(x, x_w, axes=[[2], [0]])  # (batch, x_len)
164 |
165 |         dot_w = tf.get_variable("dot_w", shape=x.shape.as_list()[-1], initializer=init, dtype=tf.float32)
166 |
167 |         # Compute x * dot_w first, then batch-multiply with the keys
168 |         x_dots = x * tf.expand_dims(tf.expand_dims(dot_w, 0), 0)
169 |         dot_logits = tf.matmul(x_dots, keys, transpose_b=True)
170 |
171 |         return dot_logits + tf.expand_dims(key_logits, 1) + tf.expand_dims(x_logits, 2)
172 |
173 |     @property
174 |     def version(self):
175 |         return 1
176 |
--------------------------------------------------------------------------------
/docqa/nn/span_prediction_ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | """
6 | Some utility functions for dealing with span prediction in tensorflow
7 | """
8 |
9 |
10 | def best_span_from_bounds(start_logits, end_logits, bound=None):
11 |     """
12 |     Brute force approach to finding the best span from start/end logits in tensorflow, still usually
13 |     faster than the python dynamic-programming version
14 |     """
15 |     b = tf.shape(start_logits)[0]
16 |
17 |     # Using `top_k` to get the index and value at once is faster
18 |     # than using argmax and then gather to get the value
19 |     top_k = tf.nn.top_k(start_logits + end_logits, k=1)
20 |     values, indices = [tf.squeeze(x, axis=[1]) for x in top_k]
21 |
22 |     # Convert to (start_position, length) format
23 |     indices = tf.stack([indices, tf.fill((b,), 0)], axis=1)
24 |
25 |     # TODO Might be better to build the batch x n_word x n_word
26 |     # matrix and use tf.matrix_band_part to zero out the unwanted ones...
27 |
28 |     if bound is None:
29 |         n_lengths = tf.shape(start_logits)[1]
30 |     else:
31 |         # take the min in case the bound > the context length
32 |         n_lengths = tf.minimum(bound, tf.shape(start_logits)[1])
33 |
34 |     def compute(i, values, indices):
35 |         top_k = tf.nn.top_k(start_logits[:, :-i] + end_logits[:, i:])
36 |         b_values, b_indices = [tf.squeeze(x, axis=[1]) for x in top_k]
37 |
38 |         b_indices = tf.stack([b_indices, tf.fill((b,), i)], axis=1)
39 |         indices = tf.where(b_values > values, b_indices, indices)
40 |         values = tf.maximum(values, b_values)
41 |         return i+1, values, indices
42 |
43 |     _, values, indices = tf.while_loop(
44 |         lambda ix, values, indices: ix < n_lengths,
45 |         compute,
46 |         [1, values, indices],
47 |         back_prop=False)
48 |
49 |     spans = tf.stack([indices[:, 0], indices[:, 0] + indices[:, 1]], axis=1)
50 |     return spans, values
51 |
52 |
53 | def packed_span_f1_mask(spans, l, bound):
54 |     starts = []
55 |     ends = []
56 |     for i in range(bound):
57 |         s = tf.range(0, l - i, dtype=tf.int32)
58 |         starts.append(s)
59 |         ends.append(s + i)
60 |     starts = tf.concat(starts, axis=0)
61 |     ends = tf.concat(ends, axis=0)
62 |     starts = tf.tile(tf.expand_dims(starts, 0), [tf.shape(spans)[0], 1])
63 |     ends = tf.tile(tf.expand_dims(ends, 0), [tf.shape(spans)[0], 1])
64 |
65 |     pred_len = tf.cast(ends - starts + 1, tf.float32)
66 |
67 |     span_start = tf.maximum(starts, spans[:, 0:1])
68 |     span_stop = tf.minimum(ends, spans[:, 1:2])
69 |
70 |     overlap_len = tf.cast(span_stop - span_start + 1, tf.float32)
71 |     true_len = tf.cast(spans[:, 1:2] - spans[:, 0:1] + 1, tf.float32)
72 |
73 |     p = overlap_len / pred_len
74 |     r = overlap_len / true_len
75 |     return tf.where(overlap_len > 0, 2 * p * r / (p + r), tf.zeros(tf.shape(starts)))
76 |
77 |
78 | def to_packed_coordinates(spans, l, bound=None):
79 |     """ Converts the spans to a vector of packed coordinates; in the packed format
80 |     spans are indexed first by length, then by start position. If bound is given,
81 |     spans are truncated to be of `bound` length """
82 |     lens = spans[:, 1] - spans[:, 0]
83 |     if bound is not None:
84 |         lens = tf.minimum(lens, bound-1)
85 |     return spans[:, 0] + l * lens - lens * (lens - 1) // 2
86 |
87 |
88 | def to_packed_coordinates_np(spans, l, bound=None):
89 |     """ Converts the spans to a vector of packed coordinates; in the packed format
90 |     spans are indexed first by length, then by start position in a flattened array.
91 | If bound is given spans are truncated to be of `bound` length """ 92 | lens = spans[:, 1] - spans[:, 0] 93 | if bound is not None: 94 | lens = np.minimum(lens, bound-1) 95 | return spans[:, 0] + l * lens - lens * (lens - 1) // 2 96 | 97 | 98 | def to_unpacked_coordinates(ix, l, bound): 99 | ix = tf.cast(ix, tf.int32) 100 | # You can actually compute the lens in closed form: 101 | # lens = tf.floor(0.5 * (-tf.sqrt(4 * tf.square(l) + 4 * l - 8 * ix + 1) + 2 * l + 1)) 102 | # but it is very ugly and rounding errors could cause problems, so this approach seems safer 103 | lens = [] 104 | for i in range(bound): 105 | lens.append(tf.fill((l - i,), i)) 106 | lens = tf.concat(lens, axis=0) 107 | lens = tf.gather(lens, ix) 108 | answer_start = ix - l * lens + lens * (lens - 1) // 2 109 | return tf.stack([answer_start, answer_start+lens], axis=1) 110 | 111 | -------------------------------------------------------------------------------- /docqa/scripts/ablate_squad.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | 4 | from docqa import model_dir 5 | from docqa import trainer 6 | from docqa.data_processing.multi_paragraph_qa import StratifyParagraphSetsBuilder, StratifyParagraphsBuilder, \ 7 | RandomParagraphSetDatasetBuilder 8 | from docqa.data_processing.preprocessed_corpus import PreprocessedData 9 | from docqa.data_processing.qa_training_data import ContextLenBucketedKey, ContextLenKey 10 | from docqa.data_processing.text_utils import NltkPlusStopWords 11 | from docqa.dataset import ClusteredBatcher 12 | from docqa.evaluator import LossEvaluator, MultiParagraphSpanEvaluator, SpanEvaluator 13 | from docqa.scripts.ablate_triviaqa import get_model 14 | from docqa.squad.squad_data import SquadCorpus, DocumentQaTrainingData 15 | from docqa.squad.squad_document_qa import SquadTfIdfRanker 16 | from docqa.text_preprocessor import WithIndicators 17 | from docqa.trainer import TrainParams, SerializableOptimizer 18 | 19 | 20 | def train_params(n_epochs): 21 | return TrainParams(SerializableOptimizer("Adadelta", dict(learning_rate=1.0)), 22 | ema=0.999, max_checkpoints_to_keep=3, async_encoding=10, 23 | num_epochs=n_epochs, log_period=30, eval_period=1200, save_period=1200, 24 | eval_samples=dict(dev=None, train=8000)) 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser(description='Train a model on document-level SQuAD') 29 | parser.add_argument('mode', choices=["paragraph", "confidence", "shared-norm", "merge", "sigmoid"]) 30 | parser.add_argument("name", help="Output directory") 31 | args = parser.parse_args() 32 | mode = args.mode 33 | out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S") 34 | 35 | corpus = SquadCorpus() 36 | if mode == "merge": 37 | # Adds paragraph start tokens, since we will be concatenating paragraphs together 38 | pre = WithIndicators(True, para_tokens=False, doc_start_token=False) 39 | else: 40 | pre = None 41 | 42 | model = get_model(50, 100, args.mode, pre) 43 | 44 | if mode == "paragraph": 45 | # Run in the "standard" known-paragraph setting 46 | if model.preprocessor is not None: 47 | raise NotImplementedError() 48 | n_epochs = 26 49 | 50 | train_batching = ClusteredBatcher(45, ContextLenBucketedKey(3), True, False) 51 | eval_batching = ClusteredBatcher(45, ContextLenKey(), False, False) 52 | data = DocumentQaTrainingData(corpus, None, train_batching, eval_batching) 53 | eval = [LossEvaluator(), SpanEvaluator(bound=[17], text_eval="squad")] 54 | else: 55 | 
        eval_set_mode = {
56 |             "confidence": "flatten",
57 |             "sigmoid": "flatten",
58 |             "shared-norm": "group",
59 |             "merge": "merge"}[mode]
60 |         eval_dataset = RandomParagraphSetDatasetBuilder(100, eval_set_mode, True, 0)
61 |
62 |         if mode == "confidence" or mode == "sigmoid":
63 |             if mode == "sigmoid":
64 |                 # needs to be trained for a really long time for reasons unknown, even this might be too small
65 |                 n_epochs = 100
66 |             else:
67 |                 n_epochs = 50  # more epochs since we only "see" the label every other epoch or so
68 |             train_batching = ClusteredBatcher(45, ContextLenBucketedKey(3), True, False)
69 |             data = PreprocessedData(
70 |                 SquadCorpus(),
71 |                 SquadTfIdfRanker(NltkPlusStopWords(True), 4, True, model.preprocessor),
72 |                 StratifyParagraphsBuilder(train_batching, 1),
73 |                 eval_dataset,
74 |                 eval_on_verified=False,
75 |             )
76 |         else:
77 |             n_epochs = 26
78 |             data = PreprocessedData(
79 |                 SquadCorpus(),
80 |                 SquadTfIdfRanker(NltkPlusStopWords(True), 4, True, model.preprocessor),
81 |                 StratifyParagraphSetsBuilder(25, args.mode == "merge", True, 1),
82 |                 eval_dataset,
83 |                 eval_on_verified=False,
84 |             )
85 |
86 |     eval = [LossEvaluator(), MultiParagraphSpanEvaluator(17, "squad")]
87 |     data.preprocess(1)
88 |
89 |     with open(__file__, "r") as f:
90 |         notes = f.read()
91 |     notes = args.mode + "\n" + notes
92 |
93 |     params = train_params(n_epochs)
94 |     if mode == "paragraph":
95 |         params.best_weights = ("dev", "b17/text-f1")
96 |
97 |     trainer.start_training(data, model, params, eval, model_dir.ModelDir(out), notes)
98 |
99 |
100 | if __name__ == "__main__":
101 |     main()
--------------------------------------------------------------------------------
/docqa/scripts/ablate_triviaqa_unfiltered.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 |
4 | from docqa import model_dir
5 | from docqa import trainer
6 | from docqa.data_processing.document_splitter import MergeParagraphs, ShallowOpenWebRanker
7 | from docqa.data_processing.multi_paragraph_qa import StratifyParagraphsBuilder, \
8 |     StratifyParagraphSetsBuilder, RandomParagraphSetDatasetBuilder
9 | from docqa.data_processing.preprocessed_corpus import PreprocessedData
10 | from docqa.data_processing.qa_training_data import ContextLenBucketedKey
11 | from docqa.dataset import ClusteredBatcher
12 | from docqa.evaluator import LossEvaluator, MultiParagraphSpanEvaluator
13 | from docqa.scripts.ablate_triviaqa import get_model
14 | from docqa.text_preprocessor import WithIndicators
15 | from docqa.trainer import SerializableOptimizer, TrainParams
16 | from docqa.triviaqa.build_span_corpus import TriviaQaOpenDataset
17 | from docqa.triviaqa.training_data import ExtractMultiParagraphsPerQuestion
18 |
19 |
20 | def main():
21 |     parser = argparse.ArgumentParser(description='Train a model on TriviaQA unfiltered')
22 |     parser.add_argument('mode', choices=["confidence", "merge", "shared-norm",
23 |                                          "sigmoid", "paragraph"])
24 |     parser.add_argument("name", help="Where to store the model")
25 |     parser.add_argument("-t", "--n_tokens", default=400, type=int,
26 |                         help="Paragraph size")
27 |     parser.add_argument('-n', '--n_processes', type=int, default=2,
28 |                         help="Number of processes to use when preprocessing the data "
29 |                              "(i.e., selecting which paragraphs to train on)"
30 |                         )
31 |     args = parser.parse_args()
32 |     mode = args.mode
33 |
34 |     out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")
35 |
36 |     model = get_model(100, 140, mode, WithIndicators())
37 |
38 |     extract =
ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens), ShallowOpenWebRanker(16), 39 | model.preprocessor, intern=True) 40 | 41 | eval = [LossEvaluator(), MultiParagraphSpanEvaluator(8, "triviaqa", mode != "merge", per_doc=False)] 42 | oversample = [1] * 4 43 | 44 | if mode == "paragraph": 45 | n_epochs = 120 46 | test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample) 47 | train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True), 48 | oversample, only_answers=True) 49 | elif mode == "confidence" or mode == "sigmoid": 50 | if mode == "sigmoid": 51 | n_epochs = 640 52 | else: 53 | n_epochs = 160 54 | test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample) 55 | train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True), oversample) 56 | else: 57 | n_epochs = 80 58 | test = RandomParagraphSetDatasetBuilder(120, "merge" if mode == "merge" else "group", True, oversample) 59 | train = StratifyParagraphSetsBuilder(30, mode == "merge", True, oversample) 60 | 61 | data = TriviaQaOpenDataset() 62 | 63 | params = TrainParams( 64 | SerializableOptimizer("Adadelta", dict(learning_rate=1)), 65 | num_epochs=n_epochs, ema=0.999, max_checkpoints_to_keep=2, 66 | async_encoding=10, log_period=30, eval_period=1800, save_period=1800, 67 | eval_samples=dict(dev=None, train=6000) 68 | ) 69 | 70 | data = PreprocessedData(data, extract, train, test, eval_on_verified=False) 71 | 72 | data.preprocess(args.n_processes, 1000) 73 | 74 | with open(__file__, "r") as f: 75 | notes = f.read() 76 | notes = "Mode: " + args.mode + "\n" + notes 77 | 78 | trainer.start_training(data, model, params, eval, model_dir.ModelDir(out), notes) 79 | 80 | 81 | if __name__ == "__main__": 82 | main() -------------------------------------------------------------------------------- /docqa/scripts/ablate_triviaqa_wiki.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | 4 | from docqa import model_dir 5 | from docqa import trainer 6 | from docqa.data_processing.document_splitter import MergeParagraphs, ShallowOpenWebRanker 7 | from docqa.data_processing.multi_paragraph_qa import StratifyParagraphsBuilder, \ 8 | StratifyParagraphSetsBuilder, RandomParagraphSetDatasetBuilder 9 | from docqa.data_processing.preprocessed_corpus import PreprocessedData 10 | from docqa.data_processing.qa_training_data import ContextLenBucketedKey 11 | from docqa.dataset import ClusteredBatcher 12 | from docqa.evaluator import LossEvaluator, MultiParagraphSpanEvaluator 13 | from docqa.scripts.ablate_triviaqa import get_model 14 | from docqa.text_preprocessor import WithIndicators 15 | from docqa.trainer import SerializableOptimizer, TrainParams 16 | from docqa.triviaqa.build_span_corpus import TriviaQaOpenDataset, TriviaQaWikiDataset 17 | from docqa.triviaqa.training_data import ExtractMultiParagraphsPerQuestion 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser(description='Train a model on TriviaQA wiki') 22 | parser.add_argument('mode', choices=["confidence", "merge", "shared-norm", 23 | "sigmoid", "paragraph"]) 24 | # Note I haven't tested modes other than `shared-norm` on this corpus, so 25 | # some things might need adjusting 26 | parser.add_argument("name", help="Where to store the model") 27 | parser.add_argument("-t", "--n_tokens", default=400, type=int, 28 | help="Paragraph size") 29 | parser.add_argument('-n', '--n_processes', type=int, default=2, 30 | 
help="Number of processes (i.e., select which paragraphs to train on) " 31 | "the data with" 32 | ) 33 | args = parser.parse_args() 34 | mode = args.mode 35 | 36 | out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S") 37 | 38 | model = get_model(100, 140, mode, WithIndicators()) 39 | 40 | extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens), 41 | ShallowOpenWebRanker(16), 42 | model.preprocessor, intern=True) 43 | 44 | eval = [LossEvaluator(), MultiParagraphSpanEvaluator(8, "triviaqa", mode != "merge", per_doc=False)] 45 | oversample = [1] * 2 # Sample the top two answer-containing paragraphs twice 46 | 47 | if mode == "paragraph": 48 | n_epochs = 120 49 | test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample) 50 | train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True), 51 | oversample, only_answers=True) 52 | elif mode == "confidence" or mode == "sigmoid": 53 | if mode == "sigmoid": 54 | n_epochs = 640 55 | else: 56 | n_epochs = 160 57 | test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample) 58 | train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True), oversample) 59 | else: 60 | n_epochs = 80 61 | test = RandomParagraphSetDatasetBuilder(120, "merge" if mode == "merge" else "group", True, oversample) 62 | train = StratifyParagraphSetsBuilder(30, mode == "merge", True, oversample) 63 | 64 | data = TriviaQaWikiDataset() 65 | 66 | params = TrainParams( 67 | SerializableOptimizer("Adadelta", dict(learning_rate=1)), 68 | num_epochs=n_epochs, ema=0.999, max_checkpoints_to_keep=2, 69 | async_encoding=10, log_period=30, eval_period=1800, save_period=1800, 70 | best_weights=("dev", "b8/question-text-f1"), 71 | eval_samples=dict(dev=None, train=6000) 72 | ) 73 | 74 | data = PreprocessedData(data, extract, train, test, eval_on_verified=False) 75 | 76 | data.preprocess(args.n_processes, 1000) 77 | 78 | with open(__file__, "r") as f: 79 | notes = f.read() 80 | notes = "Mode: " + args.mode + "\n" + notes 81 | 82 | trainer.start_training(data, model, params, eval, model_dir.ModelDir(out), notes) 83 | 84 | 85 | if __name__ == "__main__": 86 | main() -------------------------------------------------------------------------------- /docqa/scripts/build_pruned_voc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | from docqa.data_processing.word_vectors import load_word_vectors 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("vecs") 10 | parser.add_argument("vocab") 11 | parser.add_argument("output") 12 | args = parser.parse_args() 13 | 14 | voc = set() 15 | with open(args.vocab) as f: 16 | for line in f: 17 | voc.add(line.strip()) 18 | 19 | voc = load_word_vectors(args.vecs, voc) 20 | with open(args.output, "wb") as f: 21 | pickle.dump(voc, f) 22 | 23 | 24 | if __name__ == "__main__": 25 | main() -------------------------------------------------------------------------------- /docqa/scripts/continue.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from docqa.trainer import resume_training 4 | from docqa.model_dir import ModelDir 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description='') 9 | parser.add_argument('name', help='name of output to exmaine') 10 | parser.add_argument('--eval', "-e", action="store_true") 11 | args = parser.parse_args() 12 | 13 | 
13 |     resume_training(ModelDir(args.name), start_eval=args.eval)
14 |
15 |
16 | if __name__ == "__main__":
17 |     main()
--------------------------------------------------------------------------------
/docqa/scripts/convert_to_cpu.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | from os import mkdir, listdir
4 | from os.path import exists, isfile, join
5 | from shutil import copyfile
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 | from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
10 |
11 | from docqa.data_processing.qa_training_data import ParagraphAndQuestionSpec, ParagraphAndQuestion
12 | from docqa.model_dir import ModelDir
13 | from docqa.nn.recurrent_layers import BiRecurrentMapper, CompatGruCellSpec
14 | from docqa.utils import ResourceLoader
15 |
16 | """
17 | Script to convert our models to a cpu version. Due to what appears to be a bug in tf
18 | (https://github.com/tensorflow/tensorflow/issues/13254) RNNParamsSaveable is not working for me;
19 | if it were, we could probably implement this in the Cudnn layers. Instead we complete the transform
20 | manually, which means this will only work for our models.
21 | """
22 |
23 |
24 | def convert(model_dir, output_dir, best_weights=False):
25 |     print("Load model")
26 |     md = ModelDir(model_dir)
27 |     model = md.get_model()
28 |     dim = model.embed_mapper.layers[1].n_units
29 |     global_step = tf.get_variable('global_step', shape=[], dtype='int32',
30 |                                   initializer=tf.constant_initializer(0), trainable=False)
31 |
32 |     print("Setting up cudnn version")
33 |     # global_step = tf.get_variable('global_step', shape=[], dtype='int32', trainable=False)
34 |     sess = tf.Session()
35 |     with sess.as_default():
36 |         model.set_input_spec(ParagraphAndQuestionSpec(1, None, None, 14),
37 |                              {"the"},
38 |                              ResourceLoader(lambda a, b: {"the": np.zeros(300, np.float32)}))
39 |
40 |         print("Building graph")
41 |         pred = model.get_prediction()
42 |
43 |     test_questions = ParagraphAndQuestion(["Harry", "Potter", "was", "written", "by", "JK"],
44 |                                           ["Who", "wrote", "Harry", "Potter", "?"],
45 |                                           None, "test_questions")
46 |
47 |     print("Load vars")
48 |     md.restore_checkpoint(sess)
49 |
50 |     feed = model.encode([test_questions], False)
51 |     cudnn_out = sess.run([pred.start_logits, pred.end_logits], feed_dict=feed)
52 |
53 |     print("Done, copying files...")
54 |     if not exists(output_dir):
55 |         mkdir(output_dir)
56 |     for file in listdir(model_dir):
57 |         if isfile(file) and file != "model.npy":
58 |             copyfile(join(model_dir, file), join(output_dir, file))
59 |
60 |     print("Done, mapping tensors...")
61 |     to_save = []
62 |     to_init = []
63 |     for x in tf.trainable_variables():
64 |         if x.name.endswith("/gru_parameters:0"):
65 |             key = x.name[:-len("/gru_parameters:0")]
66 |             fw_params = x
67 |             if "map_embed" in x.name:
68 |                 c = cudnn_rnn_ops.CudnnGRU(1, dim, 400)
69 |             elif "chained-out" in x.name:
70 |                 c = cudnn_rnn_ops.CudnnGRU(1, dim, dim * 4)
71 |             else:
72 |                 c = cudnn_rnn_ops.CudnnGRU(1, dim, dim * 2)
73 |             params_saveable = cudnn_rnn_ops.RNNParamsSaveable(
74 |                 c, c.params_to_canonical,
75 |                 c.canonical_to_params, [fw_params],
76 |                 key)
77 |
78 |             for spec in params_saveable.specs:
79 |                 if spec.name.endswith("bias_cudnn 0") or \
80 |                         spec.name.endswith("bias_cudnn 1"):
81 |                     # ??? What do these even do?
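                    # (These appear to be cudnn's extra per-gate bias vectors, which have no
                    # counterpart in the compat GRU cell, so skipping them seems safe.)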
82 |                     continue
83 |                 name = spec.name.split("/")
84 |                 name.remove("cell_0")
85 |                 if "forward" in name:
86 |                     ix = name.index("forward")
87 |                     name.insert(ix + 2, "fw")
88 |                 else:
89 |                     ix = name.index("backward")
90 |                     name.insert(ix + 2, "bw")
91 |                 del name[ix]
92 |
93 |                 ix = name.index("multi_rnn_cell")
94 |                 name[ix] = "bidirectional_rnn"
95 |                 name = "/".join(name)
96 |                 v = tf.Variable(sess.run(spec.tensor), name=name)
97 |                 to_init.append(v)
98 |                 to_save.append(v)
99 |
100 |         else:
101 |             to_save.append(x)
102 |
103 |     other = [x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if x not in tf.trainable_variables()]
104 |     print(other)
105 |     sess.run(tf.initialize_variables(to_init))
106 |     saver = tf.train.Saver(to_save + other)
107 |     save_dir = join(output_dir, "save")
108 |     if not exists(save_dir):
109 |         mkdir(save_dir)
110 |
111 |     saver.save(sess, join(save_dir, "checkpoint"), sess.run(global_step))
112 |
113 |     sess.close()
114 |     tf.reset_default_graph()
115 |
116 |     print("Updating model...")
117 |     model.embed_mapper.layers = [model.embed_mapper.layers[0],
118 |                                  BiRecurrentMapper(CompatGruCellSpec(dim))]
119 |     model.match_encoder.layers = list(model.match_encoder.layers)
120 |     other = model.match_encoder.layers[1].other
121 |     other.layers = list(other.layers)
122 |     other.layers[1] = BiRecurrentMapper(CompatGruCellSpec(dim))
123 |
124 |     pred = model.predictor.predictor
125 |     pred.first_layer = BiRecurrentMapper(CompatGruCellSpec(dim))
126 |     pred.second_layer = BiRecurrentMapper(CompatGruCellSpec(dim))
127 |
128 |     with open(join(output_dir, "model.pkl"), "wb") as f:
129 |         pickle.dump(model, f)
130 |
131 |     print("Testing...")
132 |     with open(join(output_dir, "model.pkl"), "rb") as f:
133 |         model = pickle.load(f)
134 |
135 |     sess = tf.Session()
136 |
137 |     model.set_input_spec(ParagraphAndQuestionSpec(1, None, None, 14),
138 |                          {"the"},
139 |                          ResourceLoader(lambda a, b: {"the": np.zeros(300, np.float32)}))
140 |     pred = model.get_prediction()
141 |
142 |     print("Rebuilding")
143 |     saver = tf.train.Saver()
144 |     saver.restore(sess, tf.train.latest_checkpoint(save_dir))
145 |
146 |     feed = model.encode([test_questions], False)
147 |     cpu_out = sess.run([pred.start_logits, pred.end_logits], feed_dict=feed)
148 |
149 |     print("These should be close:")
150 |     print([np.allclose(a, b) for a, b in zip(cpu_out, cudnn_out)])
151 |     print(cpu_out)
152 |     print(cudnn_out)
153 |
154 |
155 | def main():
156 |     parser = argparse.ArgumentParser()
157 |     parser.add_argument("target_model")
158 |     parser.add_argument("output_dir")
159 |     parser.add_argument("--best_weights", action="store_true")
160 |     args = parser.parse_args()
161 |     convert(args.target_model, args.output_dir, args.best_weights)
162 |
163 | if __name__ == "__main__":
164 |     main()
165 |
166 |
--------------------------------------------------------------------------------
/docqa/scripts/run_on_user_documents.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os.path import isfile
3 |
4 | import re
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from docqa.data_processing.document_splitter import MergeParagraphs, TopTfIdf, ShallowOpenWebRanker, PreserveParagraphs
9 | from docqa.data_processing.qa_training_data import ParagraphAndQuestion, ParagraphAndQuestionSpec
10 | from docqa.data_processing.text_utils import NltkAndPunctTokenizer, NltkPlusStopWords
11 | from docqa.doc_qa_models import ParagraphQuestionModel
12 | from docqa.model_dir import ModelDir
flatten_iterable 14 | 15 | """ 16 | Script to run a model on user-provided question/context documents. 17 | This demonstrates how to use our document pipeline on new input 18 | """ 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser(description="Run a model on a user-provided question and documents") 23 | parser.add_argument("model", help="Model directory") 24 | parser.add_argument("question", help="Question to answer") 25 | parser.add_argument("documents", help="List of text documents to answer the question with", nargs='+') 26 | args = parser.parse_args() 27 | 28 | print("Preprocessing...") 29 | 30 | # Load the model 31 | model_dir = ModelDir(args.model) 32 | model = model_dir.get_model() 33 | if not isinstance(model, ParagraphQuestionModel): 34 | raise ValueError("This script is built to work for ParagraphQuestionModel models only") 35 | 36 | # Read the documents 37 | documents = [] 38 | for doc in args.documents: 39 | if not isfile(doc): 40 | raise ValueError(doc + " does not exist") 41 | with open(doc, "r") as f: 42 | documents.append(f.read()) 43 | print("Loaded %d documents" % len(documents)) 44 | 45 | # Split documents into lists of paragraphs 46 | documents = [re.split("\s*\n\s*", doc) for doc in documents] 47 | 48 | # Tokenize the input; the model expects data to be tokenized using `NltkAndPunctTokenizer` 49 | # Note the model expects case-sensitive input 50 | tokenizer = NltkAndPunctTokenizer() 51 | question = tokenizer.tokenize_paragraph_flat(args.question) # List of words 52 | # Now list of document->paragraph->sentence->word 53 | documents = [[tokenizer.tokenize_paragraph(p) for p in doc] for doc in documents] 54 | 55 | # Now group the documents into paragraphs; this returns `ExtractedParagraph` objects 56 | # that additionally remember the start/end token of the paragraph within the source document 57 | splitter = MergeParagraphs(400) 58 | # splitter = PreserveParagraphs() # Uncomment to use the natural paragraph grouping 59 | documents = [splitter.split(doc) for doc in documents] 60 | 61 | # Now select the top paragraphs using a `ParagraphFilter` 62 | if len(documents) == 1: 63 | # Use TF-IDF to select top paragraphs from the document 64 | selector = TopTfIdf(NltkPlusStopWords(True), n_to_select=5) 65 | context = selector.prune(question, documents[0]) 66 | else: 67 | # Use a linear classifier to select top paragraphs among all the documents 68 | selector = ShallowOpenWebRanker(n_to_select=10) 69 | context = selector.prune(question, flatten_iterable(documents)) 70 | 71 | print("Selected %d paragraphs" % len(context)) 72 | 73 | if model.preprocessor is not None: 74 | # Models are allowed to define an additional pre-processing step 75 | # This will turn the `ExtractedParagraph` objects back into simple lists of tokens 76 | context = [model.preprocessor.encode_text(question, x) for x in context] 77 | else: 78 | # Otherwise just use flattened text 79 | context = [flatten_iterable(x.text) for x in context] 80 | 81 | print("Setting up model") 82 | # Tell the model the batch size (can be None) and vocab to expect. This will load the 83 | # needed word vectors and fix the batch size to use when building the graph / encoding the input 84 | voc = set(question) 85 | for txt in context: 86 | voc.update(txt) 87 | model.set_input_spec(ParagraphAndQuestionSpec(batch_size=len(context)), voc) 88 | 89 | # Now we build the actual tensorflow graph, `best_span` and `conf` are 90 | # tensors holding the predicted span (inclusive) and confidence scores for each 91 | # element in the input batch, confidence
scores being the pre-softmax logit for the span 92 | print("Build tf graph") 93 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 94 | # We need to use sess.as_default when working with the cuDNN stuff, since we need an active 95 | # session to figure out the # of parameters needed for each layer. The cpu-compatible models don't need this. 96 | with sess.as_default(): 97 | # 8 means to limit the span to size 8 or less 98 | best_spans, conf = model.get_prediction().get_best_span(8) 99 | 100 | # Loads the saved weights 101 | model_dir.restore_checkpoint(sess) 102 | 103 | # Now the model is ready to run 104 | # The model takes input in the form of `ContextAndQuestion` objects, for example: 105 | data = [ParagraphAndQuestion(x, question, None, "user-question%d"%i) 106 | for i, x in enumerate(context)] 107 | 108 | print("Starting run") 109 | # The model is run in two steps: first it "encodes" a batch of paragraph/context pairs 110 | # into numpy arrays, then we use `sess` to run the actual model and get the predictions 111 | encoded = model.encode(data, is_train=False) # batch of `ContextAndQuestion` -> feed_dict 112 | best_spans, conf = sess.run([best_spans, conf], feed_dict=encoded) # feed_dict -> predictions 113 | 114 | best_para = np.argmax(conf) # We get output for each paragraph, select the most-confident one to print 115 | print("Best Paragraph: " + str(best_para)) 116 | print("Best span: " + str(best_spans[best_para])) 117 | print("Answer text: " + " ".join(context[best_para][best_spans[best_para][0]:best_spans[best_para][1]+1])) 118 | print("Confidence: " + str(conf[best_para])) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() -------------------------------------------------------------------------------- /docqa/scripts/show_parameters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | from docqa.model_dir import ModelDir 4 | import numpy as np 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description='Show the parameters stored in a model checkpoint') 9 | parser.add_argument("model") 10 | args = parser.parse_args() 11 | 12 | model_dir = ModelDir(args.model) 13 | checkpoint = model_dir.get_best_weights() 14 | print(checkpoint) 15 | if checkpoint is None: 16 | print("Showing latest checkpoint") 17 | checkpoint = model_dir.get_latest_checkpoint() 18 | else: 19 | print("Showing best weights") 20 | 21 | reader = tf.train.NewCheckpointReader(checkpoint) 22 | param_map = reader.get_variable_to_shape_map() 23 | total = 0 24 | for k in sorted(param_map): 25 | v = param_map[k] 26 | print('%s: %s' % (k, str(v))) 27 | total += np.prod(v) 28 | 29 | print("%d total" % total) 30 | 31 | 32 | 33 | if __name__ == "__main__": 34 | main() 35 | 36 | 37 | -------------------------------------------------------------------------------- /docqa/scripts/train_bidaf.py: -------------------------------------------------------------------------------- 1 | from docqa import model_dir 2 | from docqa import trainer 3 | from docqa.data_processing.qa_training_data import ContextLenBucketedKey, ContextLenKey 4 | from docqa.dataset import ClusteredBatcher 5 | from docqa.doc_qa_models import Attention 6 | from docqa.encoder import DocumentAndQuestionEncoder, SingleSpanAnswerEncoder 7 | from docqa.evaluator import LossEvaluator, SpanEvaluator 8 | from docqa.nn.attention import BiAttention 9 | from docqa.nn.embedder import FixedWordEmbedder, CharWordEmbedder, LearnedCharEmbedder 10 | from docqa.nn.layers import NullBiMapper, NullMapper,
SequenceMapperSeq, ReduceLayer, Conv1d, HighwayLayer, ChainConcat, \ 11 | DropoutLayer 12 | from docqa.nn.recurrent_layers import CudnnLstm 13 | from docqa.nn.similarity_layers import TriLinear 14 | from docqa.nn.span_prediction import BoundsPredictor 15 | from docqa.squad.build_squad_dataset import SquadCorpus 16 | from docqa.squad.squad_data import DocumentQaTrainingData 17 | from docqa.trainer import SerializableOptimizer, TrainParams 18 | from docqa.utils import get_output_name_from_cli 19 | 20 | 21 | def main(): 22 | """ 23 | A close-as-possible implementation of BiDaF; it is based on the `dev` tensorflow 1.1 branch of Ming's repo 24 | which, in particular, uses Adam not Adadelta. I was not able to replicate the results in the paper using Adadelta, 25 | but with Adam I was able to get to 78.0 F1 on the dev set with this script. I believe this approach is 26 | an exact reproduction of the code in the repo, up to initializations. 27 | 28 | Notes: Exponential Moving Average is very important, as is early stopping. This is also, in particular, best run 29 | on a GPU due to the large number of parameters and batch size involved. 30 | """ 31 | out = get_output_name_from_cli() 32 | 33 | train_params = TrainParams(SerializableOptimizer("Adam", dict(learning_rate=0.001)), 34 | num_epochs=12, ema=0.999, async_encoding=10, 35 | log_period=30, eval_period=1000, save_period=1000, 36 | eval_samples=dict(dev=None, train=8000)) 37 | 38 | # recurrent_layer = BiRecurrentMapper(LstmCellSpec(100, keep_probs=0.8)) 39 | # recurrent_layer = FusedLstm() 40 | recurrent_layer = SequenceMapperSeq(DropoutLayer(0.8), CudnnLstm(100)) 41 | 42 | model = Attention( 43 | encoder=DocumentAndQuestionEncoder(SingleSpanAnswerEncoder()), 44 | word_embed=FixedWordEmbedder(vec_name="glove.6B.100d", word_vec_init_scale=0, learn_unk=False), 45 | char_embed=CharWordEmbedder( 46 | embedder=LearnedCharEmbedder(16, 49, 8), 47 | layer=ReduceLayer("max", Conv1d(100, 5, 0.8), mask=False), 48 | shared_parameters=True 49 | ), 50 | word_embed_layer=None, 51 | embed_mapper=SequenceMapperSeq( 52 | HighwayLayer(activation="relu"), HighwayLayer(activation="relu"), 53 | recurrent_layer), 54 | preprocess=None, 55 | question_mapper=None, 56 | context_mapper=None, 57 | memory_builder=NullBiMapper(), 58 | attention=BiAttention(TriLinear(bias=True), True), 59 | match_encoder=NullMapper(), 60 | predictor=BoundsPredictor( 61 | ChainConcat( 62 | start_layer=SequenceMapperSeq( 63 | recurrent_layer, 64 | recurrent_layer), 65 | end_layer=recurrent_layer 66 | ) 67 | ), 68 | 69 | ) 70 | 71 | with open(__file__, "r") as f: 72 | notes = f.read() 73 | 74 | eval = [LossEvaluator(), SpanEvaluator(bound=[17], text_eval="squad")] 75 | 76 | corpus = SquadCorpus() 77 | train_batching = ClusteredBatcher(60, ContextLenBucketedKey(3), True, False) 78 | eval_batching = ClusteredBatcher(60, ContextLenKey(), False, False) 79 | data = DocumentQaTrainingData(corpus, None, train_batching, eval_batching) 80 | 81 | trainer.start_training(data, model, train_params, eval, model_dir.ModelDir(out), notes) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /docqa/server/README.md: -------------------------------------------------------------------------------- 1 | ## Server 2 | This contains our code to run the demo server.
It uses [bing search](https://azure.microsoft.com/en-us/services/cognitive-services/bing-web-search-api/) 3 | to locate web documents related to the input question, and/or [TAGME](https://tagme.d4science.org/tagme/) to 4 | locate relevant Wikipedia documents. Both services require API keys; TAGME is free, but Bing charges a small 5 | fee for each search. 6 | 7 | Running this code requires some additional dependencies, which can be installed with 8 | 9 | `pip install -r docqa/server/requirements.txt` 10 | 11 | docqa/server/qa_system.py contains the end-to-end question answering system that can use these services to answer questions. 12 | 13 | docqa/server/server.py contains the code to run the server. 14 | 15 | -------------------------------------------------------------------------------- /docqa/server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/server/__init__.py -------------------------------------------------------------------------------- /docqa/server/boilerpipe.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/server/boilerpipe.jar -------------------------------------------------------------------------------- /docqa/server/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm>=4.14.0 2 | nltk>=3.2.4 3 | ujson>=1.3 4 | beautifulsoup4>=4.0.0 5 | requests>=2.18.0 6 | lxml>=3.0.0 7 | aiohttp==2.2.5 8 | sanic==0.6.0 -------------------------------------------------------------------------------- /docqa/server/static/about.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 |

Introduction

7 | This is a demonstration of a question answering system from the Allen Institute for Artificial Intelligence 8 | 9 |

How to Use it

10 | Type a question into the search box and click the search icon or press enter; the system will return its best guess as to the answer. 11 | Below that we show the text that was searched to find the answer. The answer returned will be highlighted in red; other 12 | candidate locations that were given a comparable score will be highlighted in lighter shades. 13 | Be patient: answering a question can take 30+ seconds. 14 |

15 | To try to extract the answer from your own document, instead of using the web, set Search to "Document" 16 | and upload a document, or set Search to "Text" and type/copy in your own document. The document must be plain text 17 | with paragraphs separated by newlines. Your results may vary using this approach 18 | since the model was only optimized for using web search results. 19 | 20 |

How it Works

21 | The system will (unless you provided a document) run a web search on the question, and additionally 22 | try to identify Wikipedia articles about entities mentioned in the question. The resulting documents will 23 | be passed to a machine learning algorithm which will try to read the text and identify a span of text within one 24 | of the documents that answers your question. No knowledge bases or other sources of information are used. 25 | 26 |

Example Questions

27 | 36 | 37 |

Weaknesses/Limitations

38 | The system can answer short answer questions; most other forms of questions are unlikely to work, including: 39 | 46 | 47 | The system has some weaknesses you might observe: 48 | 53 | 54 |

References

55 | 59 | 60 | 61 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /docqa/server/web_searcher.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | 3 | import logging 4 | 5 | import ujson 6 | import asyncio 7 | from aiohttp import ClientSession 8 | from os.path import exists 9 | 10 | BING_API = "https://api.cognitive.microsoft.com/bing/" 11 | 12 | 13 | class AsyncWebSearcher(object): 14 | """ Runs search requests and returns the results """ 15 | 16 | def __init__(self, bing_api, bing_version, loop=None): 17 | if bing_api is None or not isinstance(bing_api, str): 18 | raise ValueError("Need a string Bing API key") 19 | self.bing_api = bing_api 20 | self.url = BING_API + bing_version + "/search" 21 | self.cl_sess = ClientSession(headers={"Ocp-Apim-Subscription-Key": self.bing_api}, loop=loop) 22 | 23 | async def run_search(self, question: str, n_docs: int) -> List[Dict]: 24 | # avoid quoting the entire question, some triviaqa questions have this form 25 | # TODO is this the right place to do this? 26 | question = question.strip("\"\' ") 27 | async with self.cl_sess.get(url=self.url, params=dict(count=n_docs, q=question, mkt="en-US")) as resp: 28 | data = await resp.json() 29 | if resp.status != 200: 30 | raise ValueError("Web search error %s" % data) 31 | 32 | if "webPages" not in data: 33 | return [] 34 | else: 35 | return data["webPages"]["value"] 36 | 37 | def close(self): 38 | self.cl_sess.close() 39 | 40 | 41 | class ExtractedWebDoc(object): 42 | def __init__(self, url: str, text: str): 43 | self.url = url 44 | self.text = text 45 | 46 | 47 | class AsyncBoilerpipeCliExtractor(object): 48 | """ 49 | Downloads documents from URLs and returns the extracted text 50 | 51 | TriviaQA used boilerpipe (https://github.com/kohlschutter/boilerpipe) to extract the 52 | "main" pieces of text from web documents. There is, as far as I can tell, no complete 53 | python re-implementation, so for the moment we shell out to a jar file (boilerpipe.jar) 54 | which downloads files from the given URLs and runs them through boilerpipe's extraction code 55 | using multiple threads. 56 | """ 57 | 58 | JAR = "docqa/server/boilerpipe.jar" 59 | 60 | def __init__(self, n_threads: int=10, timeout: int=None, 61 | process_additional_timeout: Optional[int]=5): 62 | """ 63 | :param n_threads: Number of threads to use when downloading urls 64 | :param timeout: Time to wait while downloading urls; if the time limit is reached 65 | downloads that are still hanging will be returned as errors 66 | :param process_additional_timeout: How long to wait for the downloading sub-process to return, 67 | in addition to `timeout`.
If this timeout is hit no results will 68 | be returned, so this is a last resort to stop the server from freezing 69 | """ 70 | self.log = logging.getLogger('downloader') 71 | if not exists(self.JAR): 72 | raise ValueError("Could not find boilerpipe jar") 73 | self.timeout = timeout 74 | self.n_threads = n_threads 75 | if self.timeout is None: 76 | self.proc_timeout = None 77 | else: 78 | self.proc_timeout = timeout + process_additional_timeout 79 | 80 | async def get_text(self, urls: List[str]) -> List[ExtractedWebDoc]: 81 | process = await asyncio.create_subprocess_exec( 82 | "java", "-jar", self.JAR, *urls, "-t", str(self.n_threads), 83 | "-l", str(self.timeout), 84 | stdout=asyncio.subprocess.PIPE) 85 | stdout, stderr = await asyncio.wait_for(process.communicate(), 86 | timeout=self.proc_timeout) 87 | text = stdout.decode("utf-8") 88 | data = ujson.loads(text) 89 | ex = data["extracted"] 90 | errors = data["error"] 91 | if len(errors) > 0: 92 | self.log.info("%d extraction errors: %s" % (len(errors), str(list(errors.items())))) 93 | return [ExtractedWebDoc(url, ex[url]) for url in urls if url in ex] 94 | -------------------------------------------------------------------------------- /docqa/squad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/squad/__init__.py -------------------------------------------------------------------------------- /docqa/squad/build_squad_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import urllib 4 | from os import listdir, mkdir 5 | from os.path import expanduser, join, exists 6 | from typing import List 7 | 8 | from tqdm import tqdm 9 | 10 | from docqa import config 11 | from docqa.squad.squad_data import Question, Document, Paragraph, SquadCorpus 12 | from docqa.data_processing.span_data import ParagraphSpan, ParagraphSpans 13 | from docqa.data_processing.text_utils import get_word_span, space_re, NltkAndPunctTokenizer 14 | from docqa.utils import flatten_iterable 15 | 16 | """ 17 | Script to build a corpus from SQuAD training data 18 | """ 19 | 20 | 21 | def clean_title(title): 22 | """ Squad titles use URL escape formatting; this method undoes it to get the wiki-title""" 23 | return urllib.parse.unquote(title).replace("_", " ") 24 | 25 | 26 | def parse_squad_data(source, name, tokenizer, use_tqdm=True) -> List[Document]: 27 | with open(source, 'r') as f: 28 | source_data = json.load(f) 29 | 30 | if use_tqdm: 31 | iter_files = tqdm(source_data['data'], ncols=80) 32 | else: 33 | iter_files = source_data['data'] 34 | 35 | for article_ix, article in enumerate(iter_files): 36 | article_ix = "%s-%d" % (name, article_ix) 37 | 38 | paragraphs = [] 39 | 40 | for para_ix, para in enumerate(article['paragraphs']): 41 | questions = [] 42 | context = para['context'] 43 | 44 | tokenized = tokenizer.tokenize_with_inverse(context) 45 | # list of sentences + mapping from words -> original text index 46 | text, text_spans = tokenized.text, tokenized.spans 47 | flat_text = flatten_iterable(text) 48 | 49 | n_words = sum(len(sentence) for sentence in text) 50 | 51 | for question_ix, question in enumerate(para['qas']): 52 | # There are actually some multi-sentence questions, so we should have used 53 | # tokenizer.tokenize_paragraph_flat here which would have produced slightly better 54 | # results in a few cases.
However all the results we report were 55 | # done using `tokenize_sentence` so I am just going to leave it this way 56 | question_text = tokenizer.tokenize_sentence(question['question']) 57 | 58 | answer_spans = [] 59 | for answer_ix, answer in enumerate(question['answers']): 60 | answer_raw = answer['text'] 61 | 62 | answer_start = answer['answer_start'] 63 | answer_stop = answer_start + len(answer_raw) 64 | 65 | word_ixs = get_word_span(text_spans, answer_start, answer_stop) 66 | 67 | first_word = flat_text[word_ixs[0]] 68 | first_word_span = text_spans[word_ixs[0]] 69 | last_word = flat_text[word_ixs[-1]] 70 | last_word_span = text_spans[word_ixs[-1]] 71 | 72 | char_start = answer_start - first_word_span[0] 73 | char_end = answer_stop - last_word_span[0] 74 | 75 | # Sanity check to ensure we can rebuild the answer using the word and char indices 76 | # Since we might not be able to "undo" the tokenizing exactly we might not be able to exactly 77 | # rebuild 'answer_raw', so we just check that we can rebuild the answer minus spaces 78 | if len(word_ixs) == 1: 79 | if first_word[char_start:char_end] != answer_raw: 80 | raise ValueError() 81 | else: 82 | rebuild = first_word[char_start:] 83 | for word_ix in word_ixs[1:-1]: 84 | rebuild += flat_text[word_ix] 85 | rebuild += last_word[:char_end] 86 | if rebuild != space_re.sub("", tokenizer.clean_text(answer_raw)): 87 | raise ValueError(rebuild + " " + answer_raw) 88 | 89 | # Find the sentence and the in-sentence word offsets 90 | sent_start, sent_end, word_start, word_end = None, None, None, None 91 | on_word = 0 92 | for sent_ix, sent in enumerate(text): 93 | next_word = on_word + len(sent) 94 | if on_word <= word_ixs[0] < next_word: 95 | sent_start = sent_ix 96 | word_start = word_ixs[0] - on_word 97 | if on_word <= word_ixs[-1] < next_word: 98 | sent_end = sent_ix 99 | word_end = word_ixs[-1] - on_word 100 | break 101 | on_word = next_word 102 | 103 | # Sanity check these as well 104 | if text[sent_start][word_start] != flat_text[word_ixs[0]]: 105 | raise RuntimeError() 106 | if text[sent_end][word_end] != flat_text[word_ixs[-1]]: 107 | raise RuntimeError() 108 | 109 | span = ParagraphSpan( 110 | sent_start, word_start, char_start, 111 | sent_end, word_end, char_end, 112 | word_ixs[0], word_ixs[-1], 113 | answer_raw) 114 | if span.para_word_end >= n_words or \ 115 | span.para_word_start >= n_words: 116 | raise RuntimeError() 117 | answer_spans.append(span) 118 | 119 | questions.append(Question(question['id'], question_text, ParagraphSpans(answer_spans))) 120 | 121 | paragraphs.append(Paragraph(text, questions, article_ix, para_ix, context, text_spans)) 122 | 123 | yield Document(article_ix, article["title"], paragraphs) 124 | 125 | 126 | def main(): 127 | parser = argparse.ArgumentParser("Preprocess SQuAD data") 128 | parser.add_argument("--train_file", default=config.SQUAD_TRAIN) 129 | parser.add_argument("--dev_file", default=config.SQUAD_DEV) 130 | 131 | if not exists(config.CORPUS_DIR): 132 | mkdir(config.CORPUS_DIR) 133 | 134 | target_dir = join(config.CORPUS_DIR, SquadCorpus.NAME) 135 | if exists(target_dir) and len(listdir(target_dir)) > 0: 136 | raise ValueError("Files already exist in " + target_dir) 137 | 138 | args = parser.parse_args() 139 | tokenizer = NltkAndPunctTokenizer() 140 | 141 | print("Parsing train...") 142 | train = list(parse_squad_data(args.train_file, "train", tokenizer)) 143 | 144 | print("Parsing dev...") 145 | dev = list(parse_squad_data(args.dev_file, "dev", tokenizer)) 146 | 147 | print("Saving...") 148 |
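# make_corpus (defined in squad_data.py) pickles the tokenized train/dev Document lists into CORPUS_DIR/squad ("./data/squad" by default) for the training scripts to load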
SquadCorpus.make_corpus(train, dev) 149 | print("Done") 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /docqa/squad/document_rd_corpus.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from typing import Dict, List 3 | 4 | from docqa.config import DOCUMENT_READER_DB, CORPUS_DIR 5 | from docqa.data_processing.text_utils import NltkAndPunctTokenizer, ParagraphWithInverse 6 | from docqa.squad.build_squad_dataset import clean_title 7 | from docqa.squad.squad_data import SquadCorpus, Document 8 | import sqlite3 9 | 10 | 11 | """ 12 | Retrieve documents by title from a sqlite database; we used this to evaluate our 13 | model on the documents from https://github.com/facebookresearch/DrQA 14 | """ 15 | 16 | 17 | def build_corpus_subset(output): 18 | docs = SquadCorpus().get_dev() 19 | titles = [clean_title(doc.title) for doc in docs] 20 | for i, t in enumerate(titles): 21 | if t == "Sky (United Kingdom)": 22 | titles[i] = "Sky UK" 23 | 24 | with sqlite3.connect(DOCUMENT_READER_DB) as conn: 25 | c = conn.cursor() 26 | 27 | c.execute("CREATE TEMPORARY TABLE squad_docs(id)") 28 | c.executemany("INSERT INTO squad_docs VALUES (?)", [(x,) for x in titles]) 29 | 30 | c.execute("ATTACH DATABASE ? AS db2", (output, )) 31 | c.execute("CREATE TABLE db2.documents (id PRIMARY KEY, text);") 32 | 33 | c.execute("INSERT INTO db2.documents SELECT * FROM documents WHERE id in squad_docs") 34 | c.close() 35 | 36 | 37 | def get_doc_rd_doc(docs: List[Document]) -> Dict[str, List[ParagraphWithInverse]]: 38 | tokenizer = NltkAndPunctTokenizer() 39 | conn = sqlite3.connect(DOCUMENT_READER_DB) 40 | c = conn.cursor() 41 | titles = [clean_title(doc.title) for doc in docs] 42 | for i, t in enumerate(titles): 43 | # Had to manually resolve this (due to changes in Wikipedia?)
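# (presumably the article was renamed from "Sky (United Kingdom)" to "Sky UK" at some point after the DrQA database snapshot was built, so the title has to be mapped by hand)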
44 | if t == "Sky (United Kingdom)": 45 | titles[i] = "Sky UK" 46 | 47 | title_to_doc_id = {t: doc.title for t, doc in zip(titles, docs)} 48 | 49 | c.execute("CREATE TEMPORARY TABLE squad_docs(id)") 50 | c.executemany("INSERT INTO squad_docs VALUES (?)", [(x,) for x in titles]) 51 | 52 | c.execute("SELECT id, text FROM documents WHERE id IN squad_docs") 53 | 54 | documents = {} 55 | out = c.fetchall() 56 | conn.close() 57 | for title, text in out: 58 | paragraphs = [] 59 | for para in text.split("\n"): 60 | para = para.strip() 61 | if len(para) > 0: 62 | paragraphs.append(tokenizer.tokenize_with_inverse(para)) 63 | documents[title_to_doc_id[title]] = paragraphs 64 | 65 | return documents 66 | 67 | if __name__ == "__main__": 68 | build_corpus_subset(join(CORPUS_DIR, "doc-rd-subset.db")) 69 | -------------------------------------------------------------------------------- /docqa/squad/squad_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from os import makedirs, listdir 3 | from os.path import isfile, join, exists, isdir 4 | from typing import List, Optional 5 | 6 | import numpy as np 7 | from docqa.config import CORPUS_DIR 8 | from docqa.data_processing.text_utils import ParagraphWithInverse 9 | from docqa.utils import ResourceLoader, flatten_iterable 10 | 11 | from docqa.data_processing.qa_training_data import ParagraphAndQuestionSpec, Answer, ParagraphQaTrainingData, \ 12 | ContextAndQuestion 13 | from docqa.data_processing.span_data import ParagraphSpans 14 | from docqa.data_processing.word_vectors import load_word_vectors 15 | from docqa.configurable import Configurable 16 | 17 | """ 18 | Represent SQuAD data 19 | """ 20 | 21 | 22 | class Question(object): 23 | """ Question paired with its answer """ 24 | 25 | def __init__(self, question_id: str, words: List[str], answer: ParagraphSpans): 26 | self.question_id = question_id 27 | self.words = words 28 | self.answer = answer 29 | 30 | def __repr__(self) -> str: 31 | return " ".join(self.words) 32 | 33 | 34 | class Paragraph(ParagraphWithInverse): 35 | """ Context with multiple questions, optionally includes its "raw" untokenized/un-normalized text and the reverse 36 | mapping for the tokenized text -> raw text """ 37 | 38 | def __init__(self, 39 | context: List[List[str]], 40 | questions: List[Question], 41 | article_id: str, 42 | paragraph_num: int, 43 | original_text: Optional[str] = None, 44 | spans: Optional[np.ndarray] = None): 45 | super().__init__(context, original_text, spans) 46 | self.article_id = article_id 47 | self.questions = questions 48 | self.paragraph_num = paragraph_num 49 | 50 | def __repr__(self) -> str: 51 | return "Paragraph%d(%s...)" % (self.paragraph_num, self.text[0][:40]) 52 | 53 | def __setstate__(self, state): 54 | if "context" in state and "text" not in state: 55 | state["text"] = state["context"] 56 | del state["context"] 57 | self.__dict__ = state 58 | 59 | 60 | class Document(object): 61 | """ Collection of paragraphs """ 62 | 63 | def __init__(self, doc_id: str, title: str, paragraphs: List[Paragraph]): 64 | self.title = title 65 | self.doc_id = doc_id 66 | self.paragraphs = paragraphs 67 | 68 | def __repr__(self) -> str: 69 | return "Document(%s)" % self.title 70 | 71 | 72 | class DocParagraphAndQuestion(ContextAndQuestion): 73 | 74 | def __init__(self, question: List[str], answer: Optional[Answer], 75 | question_id: str, paragraph: Paragraph): 76 | super().__init__(question, answer, question_id) 77 | self.paragraph = paragraph 78 | 79 |
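# The accessors below just delegate to the wrapped Paragraph, so a prediction over the flattened token sequence can be mapped back to the original, untokenized text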
def get_original_text(self, para_start, para_end): 80 | return self.paragraph.get_original_text(para_start, para_end) 81 | 82 | def get_context(self): 83 | return flatten_iterable(self.paragraph.text) 84 | 85 | @property 86 | def sentences(self): 87 | return self.paragraph.text 88 | 89 | @property 90 | def n_context_words(self) -> int: 91 | return sum(len(s) for s in self.paragraph.text) 92 | 93 | @property 94 | def paragraph_num(self): 95 | return self.paragraph.paragraph_num 96 | 97 | @property 98 | def article_id(self): 99 | return self.paragraph.article_id 100 | 101 | 102 | def split_docs(docs: List[Document]) -> List[DocParagraphAndQuestion]: 103 | paras = [] 104 | for doc in docs: 105 | for i, para in enumerate(doc.paragraphs): 106 | for question in para.questions: 107 | paras.append(DocParagraphAndQuestion(question.words, question.answer, question.question_id, para)) 108 | return paras 109 | 110 | 111 | class SquadCorpus(Configurable): 112 | TRAIN_FILE = "train.pkl" 113 | DEV_FILE = "dev.pkl" 114 | NAME = "squad" 115 | 116 | VOCAB_FILE = "vocab.txt" 117 | WORD_VEC_SUFFIX = "_pruned" 118 | 119 | @staticmethod 120 | def make_corpus(train: List[Document], 121 | dev: List[Document]): 122 | dir = join(CORPUS_DIR, SquadCorpus.NAME) 123 | if isfile(dir) or (exists(dir) and len(listdir(dir)) > 0): 124 | raise ValueError("Directory %s already exists and is non-empty" % dir) 125 | if not exists(dir): 126 | makedirs(dir) 127 | 128 | for name, data in [(SquadCorpus.TRAIN_FILE, train), (SquadCorpus.DEV_FILE, dev)]: 129 | if data is not None: 130 | with open(join(dir, name), 'wb') as f: 131 | pickle.dump(data, f) 132 | 133 | def __init__(self): 134 | dir = join(CORPUS_DIR, self.NAME) 135 | if not exists(dir) or not isdir(dir): 136 | raise ValueError("No directory %s, corpus not built yet?" % dir) 137 | self.dir = dir 138 | 139 | @property 140 | def evidence(self): 141 | return None 142 | 143 | def get_vocab_file(self): 144 | self.get_vocab() 145 | return join(self.dir, self.VOCAB_FILE) 146 | 147 | def get_vocab(self): 148 | """ get all lower-cased unique words for this corpus; includes train/dev/test files """ 149 | voc_file = join(self.dir, self.VOCAB_FILE) 150 | if exists(voc_file): 151 | with open(voc_file, "r") as f: 152 | return [x.rstrip() for x in f] 153 | else: 154 | voc = set() 155 | for fn in [self.get_train, self.get_dev, self.get_test]: 156 | for doc in fn(): 157 | for para in doc.paragraphs: 158 | for sent in para.text: 159 | voc.update(x.lower() for x in sent) 160 | for question in para.questions: 161 | voc.update(x.lower() for x in question.words) 162 | voc.update(x.lower() for x in question.answer.get_vocab()) 163 | voc_list = sorted(list(voc)) 164 | with open(voc_file, "w") as f: 165 | for word in voc_list: 166 | f.write(word) 167 | f.write("\n") 168 | return voc_list 169 | 170 | def get_pruned_word_vecs(self, word_vec_name, voc=None): 171 | """ 172 | Loads word vectors that have been pruned to the case-insensitive vocab of this corpus. 173 | WARNING: this includes dev words 174 | 175 | This exists since loading word-vecs each time we start up can be a big pain, so 176 | we cache the pruned vecs on-disk as a .npy file we can re-load quickly.
177 | """ 178 | 179 | vec_file = join(self.dir, word_vec_name + self.WORD_VEC_SUFFIX + ".npy") 180 | if isfile(vec_file): 181 | print("Loading word vec %s for %s from cache" % (word_vec_name, self.name)) 182 | with open(vec_file, "rb") as f: 183 | return pickle.load(f) 184 | else: 185 | print("Building pruned word vec %s for %s" % (word_vec_name, self.name)) 186 | voc = self.get_vocab() 187 | vecs = load_word_vectors(word_vec_name, voc) 188 | with open(vec_file, "wb") as f: 189 | pickle.dump(vecs, f) 190 | return vecs 191 | 192 | def get_resource_loader(self): 193 | return ResourceLoader(self.get_pruned_word_vecs) 194 | 195 | def get_train(self) -> List[Document]: 196 | return self._load(join(self.dir, self.TRAIN_FILE)) 197 | 198 | def get_dev(self) -> List[Document]: 199 | return self._load(join(self.dir, self.DEV_FILE)) 200 | 201 | def get_test(self) -> List[Document]: 202 | return [] 203 | 204 | def _load(self, file) -> List[Document]: 205 | if not exists(file): 206 | return [] 207 | with open(file, "rb") as f: 208 | return pickle.load(f) 209 | 210 | 211 | class DocumentQaTrainingData(ParagraphQaTrainingData): 212 | def _preprocess(self, x): 213 | data = split_docs(x) 214 | return data, len(data) 215 | -------------------------------------------------------------------------------- /docqa/squad/squad_document_qa.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | from sklearn.feature_extraction.text import TfidfVectorizer 5 | from sklearn.metrics import pairwise_distances 6 | 7 | from docqa.data_processing.multi_paragraph_qa import ParagraphWithAnswers, MultiParagraphQuestion, TokenSpanGroup 8 | from docqa.data_processing.preprocessed_corpus import Preprocessor 9 | from docqa.data_processing.qa_training_data import ContextAndQuestion, Answer 10 | from docqa.data_processing.span_data import TokenSpans 11 | from docqa.squad.squad_data import Document 12 | from docqa.text_preprocessor import TextPreprocessor 13 | from docqa.utils import flatten_iterable 14 | 15 | """ 16 | Preprocessors for document-level question answering with SQuAD data 17 | """ 18 | 19 | 20 | class SquadParagraphWithAnswers(ParagraphWithAnswers): 21 | 22 | @classmethod 23 | def merge(cls, paras: List): 24 | paras.sort(key=lambda x: x.get_order()) 25 | answer_spans = [] 26 | text = [] 27 | original_text = "" 28 | spans = [] 29 | for para in paras: 30 | answer_spans.append(len(text) + para.answer_spans) 31 | spans.append(len(original_text) + para.spans) 32 | original_text += para.original_text 33 | text += para.text 34 | 35 | para = SquadParagraphWithAnswers(text, np.concatenate(answer_spans), 36 | paras[0].doc_id, paras[0].paragraph_num, 37 | original_text, np.concatenate(spans)) 38 | return para 39 | 40 | __slots__ = ["doc_id", "original_text", "paragraph_num", "spans"] 41 | 42 | def __init__(self, text: List[str], answer_spans: np.ndarray, doc_id: str, paragraph_num: int, 43 | original_text: str, spans: np.ndarray): 44 | super().__init__(text, answer_spans) 45 | self.doc_id = doc_id 46 | self.original_text = original_text 47 | self.paragraph_num = paragraph_num 48 | self.spans = spans 49 | 50 | def get_order(self): 51 | return self.paragraph_num 52 | 53 | def get_original_text(self, start, end): 54 | return self.original_text[self.spans[start][0]:self.spans[end][1]] 55 | 56 | def build_qa_pair(self, question, question_id, answer_text, group=None): 57 | if answer_text is None: 58 | ans = None 59 | elif group is None:
60 | ans = TokenSpans(answer_text, self.answer_spans) 61 | else: 62 | ans = TokenSpanGroup(answer_text, self.answer_spans, group) 63 | # returns a context-and-question equipped with a get_original_text method 64 | return QuestionAndSquadParagraph(question, ans, question_id, self) 65 | 66 | 67 | class QuestionAndSquadParagraph(ContextAndQuestion): 68 | def __init__(self, question: List[str], answer: Optional[Answer], question_id: str, para: SquadParagraphWithAnswers): 69 | super().__init__(question, answer, question_id, para.doc_id) 70 | self.para = para 71 | 72 | def get_original_text(self, start, end): 73 | return self.para.get_original_text(start, end) 74 | 75 | def get_context(self): 76 | return self.para.text 77 | 78 | @property 79 | def n_context_words(self) -> int: 80 | return len(self.para.text) 81 | 82 | 83 | class SquadTfIdfRanker(Preprocessor): 84 | """ 85 | TF-IDF ranking for SQuAD; this does the same thing as `TopTfIdf`, but it supports efficient usage 86 | when we have many questions per document 87 | """ 88 | 89 | def __init__(self, stop, n_to_select: int, force_answer: bool, text_process: TextPreprocessor=None): 90 | self.stop = stop 91 | self.n_to_select = n_to_select 92 | self.force_answer = force_answer 93 | self.text_process = text_process 94 | self._tfidf = TfidfVectorizer(strip_accents="unicode", stop_words=self.stop.words) 95 | 96 | def preprocess(self, question: List[Document], evidence): 97 | return self.ranked_questions(question) 98 | 99 | def rank(self, questions: List[List[str]], paragraphs: List[List[List[str]]]): 100 | tfidf = self._tfidf 101 | para_features = tfidf.fit_transform([" ".join(" ".join(s) for s in x) for x in paragraphs]) 102 | q_features = tfidf.transform([" ".join(q) for q in questions]) 103 | scores = pairwise_distances(q_features, para_features, "cosine") 104 | return scores 105 | 106 | def ranked_questions(self, docs: List[Document]) -> List[MultiParagraphQuestion]: 107 | out = [] 108 | for doc in docs: 109 | scores = self.rank(flatten_iterable([q.words for q in x.questions] for x in doc.paragraphs), 110 | [x.text for x in doc.paragraphs]) 111 | q_ix = 0 112 | for para_ix, para in enumerate(doc.paragraphs): 113 | for q in para.questions: 114 | para_scores = scores[q_ix] 115 | para_ranks = np.argsort(para_scores) 116 | selection = [i for i in para_ranks[:self.n_to_select]] 117 | 118 | if self.force_answer and para_ix not in selection: 119 | selection[-1] = para_ix 120 | 121 | para = [] 122 | for ix in selection: 123 | if ix == para_ix: 124 | ans = q.answer.answer_spans 125 | else: 126 | ans = np.zeros((0, 2), dtype=np.int32) 127 | p = doc.paragraphs[ix] 128 | if self.text_process: 129 | text, ans, inv = self.text_process.encode_paragraph(q.words, [flatten_iterable(p.text)], 130 | p.paragraph_num == 0, ans, p.spans) 131 | para.append(SquadParagraphWithAnswers(text, ans, doc.doc_id, 132 | ix, p.original_text, inv)) 133 | else: 134 | para.append(SquadParagraphWithAnswers(flatten_iterable(p.text), ans, doc.doc_id, 135 | ix, p.original_text, p.spans)) 136 | 137 | out.append(MultiParagraphQuestion(q.question_id, q.words, q.answer.answer_text, para)) 138 | q_ix += 1 139 | return out 140 | -------------------------------------------------------------------------------- /docqa/squad/squad_official_evaluation.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset.
""" 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | def remove_articles(text): 14 | return re.sub(r'\b(a|an|the)\b', ' ', text) 15 | 16 | def white_space_fix(text): 17 | return ' '.join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | 29 | def f1_score(prediction, ground_truth): 30 | prediction_tokens = normalize_answer(prediction).split() 31 | ground_truth_tokens = normalize_answer(ground_truth).split() 32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 33 | num_same = sum(common.values()) 34 | if num_same == 0: 35 | return 0 36 | precision = 1.0 * num_same / len(prediction_tokens) 37 | recall = 1.0 * num_same / len(ground_truth_tokens) 38 | f1 = (2 * precision * recall) / (precision + recall) 39 | return f1 40 | 41 | 42 | def exact_match_score(prediction, ground_truth): 43 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 44 | 45 | 46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 47 | scores_for_ground_truths = [] 48 | for ground_truth in ground_truths: 49 | score = metric_fn(prediction, ground_truth) 50 | scores_for_ground_truths.append(score) 51 | return max(scores_for_ground_truths) 52 | 53 | 54 | def evaluate(dataset, predictions): 55 | f1 = exact_match = total = 0 56 | for article in dataset: 57 | for paragraph in article['paragraphs']: 58 | for qa in paragraph['qas']: 59 | total += 1 60 | if qa['id'] not in predictions: 61 | message = 'Unanswered question ' + qa['id'] + \ 62 | ' will receive score 0.' 
63 | print(message, file=sys.stderr) 64 | continue 65 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 66 | prediction = predictions[qa['id']] 67 | exact_match += metric_max_over_ground_truths( 68 | exact_match_score, prediction, ground_truths) 69 | f1 += metric_max_over_ground_truths( 70 | f1_score, prediction, ground_truths) 71 | 72 | exact_match = 100.0 * exact_match / total 73 | f1 = 100.0 * f1 / total 74 | 75 | return {'exact_match': exact_match, 'f1': f1} 76 | 77 | 78 | if __name__ == '__main__': 79 | expected_version = '1.1' 80 | parser = argparse.ArgumentParser( 81 | description='Evaluation for SQuAD ' + expected_version) 82 | parser.add_argument('dataset_file', help='Dataset file') 83 | parser.add_argument('prediction_file', help='Prediction File') 84 | args = parser.parse_args() 85 | with open(args.dataset_file) as dataset_file: 86 | dataset_json = json.load(dataset_file) 87 | if (dataset_json['version'] != expected_version): 88 | print('Evaluation expects v-' + expected_version + 89 | ', but got dataset with v-' + dataset_json['version'], 90 | file=sys.stderr) 91 | dataset = dataset_json['data'] 92 | with open(args.prediction_file) as prediction_file: 93 | predictions = json.load(prediction_file) 94 | print(json.dumps(evaluate(dataset, predictions))) 95 | -------------------------------------------------------------------------------- /docqa/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/test/__init__.py -------------------------------------------------------------------------------- /docqa/test/test_batching.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from docqa.dataset import ClusteredBatcher, ShuffledBatcher 6 | 7 | 8 | class TestBatches(unittest.TestCase): 9 | 10 | def assert_unique(self, batches): 11 | values, c = np.unique(np.concatenate(batches), return_counts=True) 12 | self.assertTrue(np.all(c == 1)) 13 | 14 | def test_unique_samples(self): 15 | batchers = [ShuffledBatcher(5), ClusteredBatcher(5, lambda x:x)] 16 | test_u = self.assert_unique 17 | for batcher in batchers: 18 | batches = list(batcher.get_epoch(list(np.arange(21)))) 19 | self.assertEqual(len(batches), 4) 20 | self.assertEqual(batcher.epoch_size(21), 4) 21 | test_u(batches) 22 | 23 | batches = list(batcher.get_epoch(list(np.arange(25)))) 24 | self.assertEqual(len(batches), 5) 25 | self.assertEqual(batcher.epoch_size(25), 5) 26 | test_u(batches) 27 | 28 | batches = list(batcher.get_epoch(list(np.arange(5)))) 29 | self.assertEqual(len(batches), 1) 30 | self.assertEqual(batcher.epoch_size(5), 1) 31 | test_u(batches) 32 | 33 | def test_truncate_samples(self): 34 | batchers = [ShuffledBatcher(5, truncate_batches=True), ClusteredBatcher(5, lambda x: x, truncate_batches=True)] 35 | test_u = self.assert_unique 36 | for batcher in batchers: 37 | batches = list(batcher.get_epoch(list(np.arange(21)))) 38 | self.assertEqual(len(batches), 5) 39 | self.assertEqual(batcher.epoch_size(21), 5) 40 | test_u(batches) 41 | 42 | batches = list(batcher.get_epoch(list(np.arange(4)))) 43 | self.assertEqual(len(batches), 1) 44 | self.assertEqual(batcher.epoch_size(4), 1) 45 | test_u(batches) 46 | 47 | batches = list(batcher.get_epoch(list(np.arange(10)))) 48 | self.assertEqual(len(batches), 2) 49 | self.assertEqual(batcher.epoch_size(10), 2) 50 | test_u(batches) 51 | 52 | 
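# ClusteredBatcher sorts examples by the given key before batching, so with an identity key each (truncated) batch should contain consecutive integers in increasing order; the check below relies on that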
def test_order(self): 53 | batch = list(np.arange(103)) 54 | np.random.shuffle(batch) 55 | batches = list(ClusteredBatcher(10, lambda x: x, truncate_batches=True).get_epoch(batch)) 56 | self.assertEqual(len(batches), 11) 57 | for batch in batches: 58 | for i in range(0, len(batch)-1): 59 | if batch[i] != batch[i+1]-1: 60 | raise ValueError("Out of order point") 61 | 62 | -------------------------------------------------------------------------------- /docqa/test/test_embedder.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from docqa.nn.embedder import FixedWordEmbedder, shrink_embed 6 | 7 | 8 | class MockLoader(object): 9 | def __init__(self, name, vec): 10 | self.name = name 11 | self.vec = vec 12 | 13 | def load_word_vec(self, name, voc): 14 | if name != self.name: 15 | raise ValueError() 16 | return self.vec 17 | 18 | 19 | class TestEmbed(unittest.TestCase): 20 | 21 | def test_shrink_embed(self): 22 | with tf.Session().as_default(): 23 | original_mat = np.arange(9).reshape((9, 1)).astype(np.float32) 24 | original_word_ix = [(np.array([[0, 3, 0], [8, 3, 8]]))] 25 | mat, word_ix = shrink_embed(original_mat, original_word_ix) 26 | self.assertEqual(list(mat.eval().ravel()), [0, 3, 8]) 27 | self.assertEqual(list(word_ix[0].eval().ravel()), [0, 1, 0, 2, 1, 2]) 28 | self.assertEqual(word_ix[0].eval().shape, original_word_ix[0].shape) 29 | 30 | def test_shrink_embed_rng(self): 31 | with tf.Session().as_default(): 32 | n_words = 100 33 | original_mat = tf.constant(np.arange(n_words).reshape((n_words, 1)).astype(np.float32)) 34 | for i in range(20): 35 | original_word_ix = [np.random.randint(0, n_words, (2, 5)), 36 | np.random.randint(0, n_words, (3, 5, 2))] 37 | mat, word_ix = shrink_embed(original_mat, original_word_ix) 38 | mat = mat.eval() 39 | self.assertTrue(np.array_equal( 40 | mat[word_ix[0].eval().ravel()].reshape(word_ix[0].shape), 41 | original_word_ix[0])) 42 | self.assertTrue(np.array_equal( 43 | mat[word_ix[1].eval().ravel()].reshape(word_ix[1].shape), 44 | original_word_ix[1])) 45 | 46 | def test_fixed_embed(self): 47 | loader = MockLoader("v1", dict( 48 | red=np.array([0, 1], dtype=np.float32), 49 | the=np.array([1, 1], dtype=np.float32), 50 | fish=np.array([1, 0], dtype=np.float32), 51 | one=np.array([1, 0], dtype=np.float32))) 52 | 53 | emb = FixedWordEmbedder("v1") 54 | emb.init(loader, {"red", "cat", "decoy", "one", "fish"}) 55 | 56 | out = [emb.context_word_to_ix(x, True) for x in ["red", "one", "fish"]] 57 | self.assertEqual(set(out), {2, 3, 4}) 58 | 59 | out = [emb.context_word_to_ix(x, True) for x in ["decoy", "??", "the"]] 60 | self.assertEqual(list(out), [1, 1, 1]) 61 | -------------------------------------------------------------------------------- /docqa/test/test_evaluator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | from docqa.data_processing.span_data import get_best_span_bounded, get_best_span_from_sent_predictions, get_best_span, \ 5 | get_best_in_sentence_span 6 | 7 | 8 | class TestEvaluator(unittest.TestCase): 9 | 10 | @staticmethod 11 | def best_span_brute_force(p1, p2): 12 | best_val = -1 13 | best_span = -1 14 | for i in range(len(p1)): 15 | for j in range(i, len(p2)): 16 | val = p1[i] * p2[j] 17 | if val > best_val: 18 | best_val = val 19 | best_span = (i, j) 20 | return best_span, best_val 21 | 22 | def test_best_span_rng(self): 23 | rng = 
np.random.RandomState(0) 24 | for test_num in range(0, 100): 25 | rng.seed(test_num) 26 | p1 = rng.uniform(0, 1, 15) 27 | p2 = rng.uniform(0, 1, 15) 28 | best_span, best_val = self.best_span_brute_force(p1, p2) 29 | pred_span, pred_val = get_best_span(p1, p2) 30 | self.assertEqual(best_span, pred_span) 31 | self.assertAlmostEqual(best_val, pred_val, 10) 32 | 33 | def test_best_restricted_span_rng(self): 34 | rng = np.random.RandomState(0) 35 | for test_num in range(200): 36 | rng.seed(test_num) 37 | lens = rng.random_integers(1, 4, size=3) 38 | p1 = [rng.uniform(0, 1, x) for x in lens] 39 | p2 = [rng.uniform(0, 1, x) for x in lens] 40 | best_span, best_val = None, -1 41 | offset = 0 42 | for i in range(len(lens)): 43 | span, val = self.best_span_brute_force(p1[i], p2[i]) 44 | span = span[0] + offset, span[1] + offset 45 | offset += lens[i] 46 | if val > best_val: 47 | best_span = span 48 | best_val = val 49 | 50 | pred_span, pred_val = get_best_in_sentence_span(np.concatenate(p1), np.concatenate(p2), lens) 51 | 52 | self.assertEqual(best_span, pred_span) 53 | self.assertAlmostEqual(best_val, pred_val, 10) 54 | 55 | def test_best_sent_span_rng(self): 56 | rng = np.random.RandomState(0) 57 | for test_num in range(200): 58 | rng.seed(test_num) 59 | n_sent = 3 60 | sen_lengths = rng.random_integers(1, 15, size=n_sent) 61 | 62 | p1 = rng.uniform(0, 1, (n_sent, sen_lengths.max()+1)) 63 | p2 = rng.uniform(0, 1, (n_sent, sen_lengths.max()+1)) 64 | best_span, best_val = None, -1 65 | 66 | offset = 0 67 | for sent_ix,sent_len in enumerate(sen_lengths): 68 | span, val = self.best_span_brute_force(p1[sent_ix][:sent_len], p2[sent_ix][:sent_len]) 69 | span = span[0] + offset, span[1] + offset 70 | offset += sent_len 71 | if val > best_val: 72 | best_span = span 73 | best_val = val 74 | 75 | pred_span, pred_val = get_best_span_from_sent_predictions(p1, p2, sen_lengths) 76 | 77 | self.assertEqual(best_span, pred_span) 78 | self.assertAlmostEqual(best_val, pred_val, 10) 79 | 80 | def test_bounded_span(self): 81 | p1 = np.array([0.5, 0.1, 0, 0.2]) 82 | p2 = np.array([0.6, 0.1, 0, 0.9]) 83 | 84 | self.assertEqual(list(get_best_span_bounded(p1, p2, 2)[0]), [0, 0]) 85 | self.assertEqual(list(get_best_span_bounded(p1, p2, 12)[0]), [0, 3]) 86 | -------------------------------------------------------------------------------- /docqa/test/test_lstm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.python.ops import init_ops 6 | 7 | from docqa.nn.recurrent_layers import _compute_gates 8 | 9 | 10 | class TestInitLstm(unittest.TestCase): 11 | 12 | def test_forget_bias(self): 13 | """ 14 | Make sure the forget bias is only being applied to the forget gate 15 | """ 16 | batches = 1 17 | num_units = 5 18 | num_inputs = 5 19 | 20 | hidden_size = (batches, num_units) 21 | input_size = (batches, num_inputs) 22 | 23 | inputs = tf.placeholder(dtype='float32', shape=input_size) 24 | h = tf.placeholder(dtype='float32', shape=hidden_size) 25 | with tf.variable_scope("test_bias"): 26 | i_t, j_t, f_t, o_t = _compute_gates(inputs, h, 4 * num_units, 1, 27 | init_ops.zeros_initializer(), init_ops.zeros_initializer()) 28 | gates = [i_t, j_t, f_t, o_t] 29 | 30 | sess = tf.Session() 31 | sess.run(tf.global_variables_initializer()) 32 | 33 | # Make sure the bias is ONLY getting applied to the forget gate 34 | [i,j,f,o] = sess.run(gates, feed_dict={inputs: np.zeros(input_size), h: np.ones(hidden_size)}) 35 | 
self.assertTrue(np.allclose(f, np.ones(f.shape), rtol=0)) 36 | for x in [i,j,o]: 37 | self.assertTrue(np.allclose(x, np.zeros(x.shape), rtol=0)) 38 | 39 | def test_inits(self): 40 | """ 41 | Make sure the initializers effects the correct weights 42 | """ 43 | batches = 1 44 | num_units = 2 45 | num_inputs = 3 46 | 47 | hidden_size = (batches, num_units) 48 | input_size = (batches, num_inputs) 49 | 50 | inputs = tf.placeholder(dtype='float32', shape=input_size) 51 | h = tf.placeholder(dtype='float32', shape=hidden_size) 52 | with tf.variable_scope("test_inits"): 53 | i_t, j_t, f_t, o_t = _compute_gates(inputs, h, num_units, 0, 54 | init_ops.constant_initializer(1), init_ops.constant_initializer(100)) 55 | gates = [i_t, j_t, f_t, o_t] 56 | 57 | sess = tf.Session() 58 | sess.run(tf.global_variables_initializer()) 59 | 60 | inputs_init = np.zeros(input_size) 61 | hidden_init = np.zeros(hidden_size) 62 | inputs_init[0] = 1 63 | i, j, f, o = sess.run(gates, feed_dict={inputs: inputs_init, h: hidden_init}) 64 | self.assertTrue(np.allclose(i, np.full(i.shape, num_inputs), rtol=0)) 65 | 66 | inputs_init[0] = 0 67 | hidden_init[0] = 1 68 | i, j, f, o = sess.run(gates, feed_dict={inputs: inputs_init, h: hidden_init}) 69 | self.assertTrue(np.allclose(i, np.full(i.shape, num_units*100), rtol=0)) 70 | 71 | hidden_init[0] = 0 72 | inputs_init[0, 0] = -2 73 | hidden_init[0, 0] = 1 74 | i, j, f, o = sess.run(gates, feed_dict={inputs: inputs_init, h: hidden_init}) 75 | self.assertTrue(np.allclose(i, np.full(i.shape, 98), rtol=0)) 76 | -------------------------------------------------------------------------------- /docqa/test/test_span_prediction.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from docqa.nn.span_prediction import packed_span_f1_mask, to_unpacked_coordinates 6 | from docqa.nn.span_prediction_ops import best_span_from_bounds 7 | from docqa.utils import flatten_iterable 8 | 9 | from docqa.data_processing.span_data import get_best_span_bounded, span_f1, top_disjoint_spans 10 | from docqa.nn.ops import segment_logsumexp 11 | 12 | 13 | class TestBestSpan(unittest.TestCase): 14 | 15 | def setUp(self): 16 | self.sess = tf.Session() 17 | 18 | def test_segment_log_sum_exp(self): 19 | sess = self.sess 20 | with sess.as_default(): 21 | for i in range(10): 22 | groups = [] 23 | for group_id in range(10): 24 | group = [] 25 | for _ in range(np.random.randint(1, 5)): 26 | group.append(np.random.normal(0, 2, 10)) 27 | groups.append(group) 28 | 29 | flat_groups = np.stack(flatten_iterable(groups), axis=0) 30 | semgents = np.array(flatten_iterable([ix]*len(g) for ix, g in enumerate(groups))) 31 | actual = sess.run(segment_logsumexp(flat_groups, semgents)) 32 | expected = [np.log(np.sum(np.exp(np.concatenate(g, axis=0)))) for g in groups] 33 | self.assertTrue(np.allclose(actual, expected)) 34 | 35 | def test_top_n_simple(self): 36 | spans, scores = top_disjoint_spans(np.array([ 37 | 1, 0, 0, 10, 38 | 2, 2, 0, 0, 39 | 0, 0, 3, 0, 40 | 1, 0, 5, 4, 41 | ]).reshape((4, 4)), 3, 2) 42 | self.assertEqual(list(scores), [4, 3]) 43 | self.assertEqual(spans.tolist(), [[3, 3], [2, 2]]) 44 | 45 | def test_top_n_overlap(self): 46 | spans, scores = top_disjoint_spans(np.array([ 47 | 4, 4, 5, 4, 48 | 0, 4, 4, 4, 49 | 0, 0, 0, 4, 50 | 0, 0, 0, 2, 51 | ]).reshape((4, 4)), 10, 5) 52 | self.assertEqual(list(scores), [5, 2]) 53 | self.assertEqual(spans.tolist(), [[0, 2], [3, 3]]) 54 | 55 | def test_best_span(self): 
56 | bound = 5 57 | start_pl = tf.placeholder(tf.float32, (None, None)) 58 | end_pl = tf.placeholder(tf.float32, (None, None)) 59 | best_span, best_val = best_span_from_bounds(start_pl, end_pl, bound) 60 | sess = self.sess 61 | 62 | for i in range(0, 20): 63 | rng = np.random.RandomState(i) 64 | l = rng.randint(50, 200) 65 | batch = rng.randint(1, 60) 66 | 67 | start = rng.uniform(size=(batch, l)) 68 | end = rng.uniform(size=(batch, l)) 69 | 70 | # exp since the tf version uses logits and the py version use probabilities 71 | expected_span, expected_score = zip(*[get_best_span_bounded(np.exp(start[i]), np.exp(end[i]), bound) 72 | for i in range(batch)]) 73 | 74 | actual_span, actuals_score = sess.run([best_span, best_val], {start_pl:start, end_pl:end}) 75 | 76 | self.assertTrue(np.all(np.array(expected_span) == actual_span)) 77 | self.assertTrue(np.allclose(expected_score, np.exp(actuals_score))) 78 | 79 | def test_span_f1(self): 80 | bound = 15 81 | batch_size = 5 82 | l = 20 83 | 84 | spans_pl = tf.placeholder(tf.int32, (None, 2)) 85 | coordinate_pl = tf.placeholder(tf.int32, (1,)) 86 | 87 | mask = packed_span_f1_mask(spans_pl, 15, bound) 88 | coordinates = to_unpacked_coordinates(coordinate_pl, 15, bound)[0] 89 | sess = self.sess 90 | 91 | for i in range(0, 20): 92 | rng = np.random.RandomState(i) 93 | starts = rng.randint(0, l, batch_size) 94 | ends = [rng.randint(0, l-x) + x for x in starts] 95 | spans = np.stack([starts, np.array(ends)], axis=1) 96 | 97 | f1_mask = sess.run(mask, {spans_pl:spans}) 98 | 99 | for i in range(batch_size): 100 | coord = np.random.randint(0, f1_mask.shape[1]) 101 | x,y = sess.run(coordinates, {coordinate_pl:[coord]}) 102 | expected = span_f1(spans[i], (x, y)) 103 | actual = f1_mask[i, coord] 104 | self.assertAlmostEqual(expected, actual, places=5) -------------------------------------------------------------------------------- /docqa/test/test_splitter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import List 3 | import numpy as np 4 | 5 | from docqa.data_processing.document_splitter import DocumentSplitter, ExtractedParagraph, extract_tokens 6 | from docqa.data_processing.text_utils import NltkAndPunctTokenizer 7 | from docqa.utils import flatten_iterable 8 | 9 | 10 | class RandomSplitter(DocumentSplitter): 11 | def split(self, doc: List[List[List[str]]]) -> List[ExtractedParagraph]: 12 | words = flatten_iterable(flatten_iterable(doc)) 13 | on_word = 0 14 | out = [] 15 | while True: 16 | end_word = on_word + np.random.randint(1, 7) 17 | if on_word + end_word > len(words): 18 | out.append(ExtractedParagraph([words[on_word:]], on_word, len(words))) 19 | return out 20 | out.append(ExtractedParagraph([words[on_word:end_word]], on_word, end_word)) 21 | on_word = end_word 22 | 23 | 24 | class TestSplitter(unittest.TestCase): 25 | 26 | def test_split_inv(self): 27 | paras = [ 28 | "One fish two fish. Red fish blue fish", 29 | "Just one sentence", 30 | "How will an overhead score? The satisfactory juice returns against an inviting protein. " 31 | "How can a rat expand? The subway fishes throughout a struggle. The guaranteed herd pictures an " 32 | "episode into the accustomed damned. 
-------------------------------------------------------------------------------- /docqa/test/test_splitter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import List 3 | import numpy as np 4 | 5 | from docqa.data_processing.document_splitter import DocumentSplitter, ExtractedParagraph, extract_tokens 6 | from docqa.data_processing.text_utils import NltkAndPunctTokenizer 7 | from docqa.utils import flatten_iterable 8 | 9 | 10 | class RandomSplitter(DocumentSplitter): 11 | def split(self, doc: List[List[List[str]]]) -> List[ExtractedParagraph]: 12 | words = flatten_iterable(flatten_iterable(doc)) 13 | on_word = 0 14 | out = [] 15 | while True: 16 | end_word = on_word + np.random.randint(1, 7) 17 | if end_word > len(words): 18 | out.append(ExtractedParagraph([words[on_word:]], on_word, len(words))) 19 | return out 20 | out.append(ExtractedParagraph([words[on_word:end_word]], on_word, end_word)) 21 | on_word = end_word 22 | 23 | 24 | class TestSplitter(unittest.TestCase): 25 | 26 | def test_split_inv(self): 27 | paras = [ 28 | "One fish two fish. Red fish blue fish", 29 | "Just one sentence", 30 | "How will an overhead score? The satisfactory juice returns against an inviting protein. " 31 | "How can a rat expand? The subway fishes throughout a struggle. The guaranteed herd pictures an " 32 | "episode into the accustomed damned. The garbage reigns beside the component!", 33 | ] 34 | tok = NltkAndPunctTokenizer() 35 | tokenized = [tok.tokenize_with_inverse(x) for x in paras] 36 | inv_split = RandomSplitter().split_inverse(tokenized) 37 | for para in inv_split: 38 | self.assertTrue(flatten_iterable(para.text) == [para.original_text[s:e] for s,e in para.spans]) 39 | 40 | -------------------------------------------------------------------------------- /docqa/test/test_ut_coordinates.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from docqa.nn.span_prediction import to_packed_coordinates, to_unpacked_coordinates 6 | 7 | 8 | class TestPackedCoordinates(unittest.TestCase): 9 | 10 | def test_random(self): 11 | matrix_size = 100 12 | sess = tf.Session() 13 | 14 | matrix_size_placeholder = tf.placeholder(np.int32, ()) 15 | span_placeholder = tf.placeholder(np.int32, [None, 2]) 16 | # round-trip: pack the spans, then unpack them, which should recover the original coordinates 17 | rebuilt = to_unpacked_coordinates(to_packed_coordinates(span_placeholder, 18 | matrix_size_placeholder), matrix_size_placeholder, matrix_size) 19 | 20 | for i in range(0, 1000): 21 | rng = np.random.RandomState(i) 22 | n_elements = 20 23 | 24 | start = rng.randint(0, matrix_size-1, size=n_elements) 25 | end = np.zeros_like(start) 26 | for k in range(n_elements): 27 | end[k] = start[k] + rng.randint(0, matrix_size - start[k]) 28 | spans = np.stack([start, end], axis=1) 29 | 30 | r = sess.run(rebuilt, {span_placeholder:spans, matrix_size_placeholder:matrix_size}) 31 | self.assertTrue(np.all(r == spans))
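The packed representation round-tripped above stores each span as a single index by laying spans out in blocks, one block per span length. A NumPy sketch of one such scheme, assuming blocks are ordered by length first and by start position second (the library's own packing may order the blocks differently):

```python
import numpy as np

def pack(spans, l):
    # Spans covering one token come first, then spans covering two, etc.;
    # within a block, spans are ordered by start position.
    lens = spans[:, 1] - spans[:, 0]
    return lens * l - lens * (lens - 1) // 2 + spans[:, 0]

def unpack(packed, l):
    out = np.zeros((len(packed), 2), dtype=np.int64)
    for i, p in enumerate(packed):
        length = 0
        while p >= l - length:  # walk past each length block in turn
            p -= l - length
            length += 1
        out[i] = (p, p + length)
    return out

spans = np.array([[0, 0], [2, 5], [7, 9]])
assert np.array_equal(unpack(pack(spans, 10), 10), spans)
```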
-------------------------------------------------------------------------------- /docqa/test/test_word_features.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from docqa.data_processing.text_features import extract_year, is_number, BasicWordFeatures 4 | 5 | 6 | class TestWordFeatures(unittest.TestCase): 7 | def test_regex(self): 8 | fe = BasicWordFeatures() 9 | self.assertIsNone(fe.punc_regex.match("the")) 10 | self.assertIsNotNone(fe.punc_regex.match("!!!")) 11 | self.assertIsNotNone(fe.punc_regex.match("!.,()\"\'")) 12 | 13 | def test_non_english(self): 14 | tmp = BasicWordFeatures() 15 | self.assertIsNone(tmp.non_english.match("51,419,420")) 16 | self.assertIsNone(tmp.non_english.match("cat")) 17 | self.assertIsNone(tmp.non_english.match(",")) 18 | self.assertIsNotNone(tmp.non_english.match("لجماهيري")) 19 | 20 | def test_date(self): 21 | self.assertEqual(extract_year("sdf"), None) 22 | self.assertEqual(extract_year("-1"), None) 23 | self.assertEqual(extract_year("1990"), 1990) 24 | self.assertEqual(extract_year("1990s"), 1990) 25 | self.assertEqual(extract_year("90s"), 1990) 26 | 27 | def test_any_num(self): 28 | tmp = BasicWordFeatures() 29 | self.assertIsNone(tmp.any_num_regex.match("cat")) 30 | self.assertIsNotNone(tmp.any_num_regex.match("c3at")) 31 | 32 | def test_numbers(self): 33 | self.assertIsNotNone(is_number("90,000")) 34 | self.assertIsNotNone(is_number("90,000,112")) 35 | self.assertIsNone(is_number("90,000112")) 36 | self.assertIsNone(is_number("101,1")) 37 | self.assertIsNone(is_number("1,2,3")) 38 | self.assertIsNone(is_number("0.1th")) 39 | 40 | self.assertIsNotNone(is_number("90,000")) 41 | self.assertIsNotNone(is_number("90,000.01")) 42 | self.assertIsNotNone(is_number("91234.01")) 43 | self.assertIsNotNone(is_number("91234st")) 44 | self.assertIsNotNone(is_number("91,234th")) 45 | self.assertIsNotNone(is_number(".034")) 46 | self.assertIsNotNone(is_number(".034km")) 47 | -------------------------------------------------------------------------------- /docqa/text_preprocessor.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from typing import List, Optional, Tuple 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | from docqa.utils import flatten_iterable 7 | 8 | from docqa.data_processing.document_splitter import ExtractedParagraphWithAnswers, MergeParagraphs, ExtractedParagraph 9 | from docqa.data_processing.multi_paragraph_qa import ParagraphWithAnswers 10 | from docqa.configurable import Configurable 11 | from docqa.squad.squad_data import SquadCorpus 12 | from docqa.triviaqa.build_span_corpus import TriviaQaWebDataset 13 | 14 | 15 | class TextPreprocessor(Configurable): 16 | """ Preprocess text input, must be deterministic. Only used thus far for adding special indicator tokens """ 17 | 18 | def encode_extracted_paragraph(self, question: List[str], paragraph: ExtractedParagraphWithAnswers): 19 | text, answers, _ = self.encode_paragraph(question, paragraph.text, 20 | paragraph.start == 0, paragraph.answer_spans) 21 | return ParagraphWithAnswers(text, answers) 22 | 23 | def encode_text(self, question: List[str], paragraph: ExtractedParagraph): 24 | text, _, _ = self.encode_paragraph(question, paragraph.text, paragraph.start == 0, 25 | np.zeros((0, 2), dtype=np.int32)) 26 | return text 27 | 28 | def encode_paragraph(self, question: List[str], paragraphs: List[List[str]], 29 | is_first, answer_spans: np.ndarray, 30 | token_spans=None) -> Tuple[List[str], np.ndarray, Optional[np.ndarray]]: 31 | """ 32 | Returns updated (and flattened) text, answer_spans, and token_spans 33 | """ 34 | raise NotImplementedError() 35 | 36 | def special_tokens(self) -> List[str]: 37 | return [] 38 | 39 | 40 | class WithIndicators(TextPreprocessor): 41 | """ 42 | Adds a document or group start token before the text, and a paragraph token 43 | between each paragraph.
44 | """ 45 | 46 | PARAGRAPH_TOKEN = "%%PARAGRAPH%%" 47 | DOCUMENT_START_TOKEN = "%%DOCUMENT%%" 48 | PARAGRAPH_GROUP = "%%PARAGRAPH_GROUP%%" 49 | 50 | def __init__(self, remove_cross_answer: bool=True, para_tokens: bool=True, doc_start_token: bool=True): 51 | self.remove_cross_answer = remove_cross_answer 52 | self.doc_start_token = doc_start_token 53 | self.para_tokens = para_tokens 54 | 55 | def special_tokens(self) -> List[str]: 56 | tokens = [self.PARAGRAPH_GROUP] 57 | if self.doc_start_token: 58 | tokens.append(self.DOCUMENT_START_TOKEN) 59 | if self.para_tokens: 60 | tokens.append(self.PARAGRAPH_TOKEN) 61 | return tokens 62 | 63 | def encode_paragraph(self, question: List[str], paragraphs: List[List[str]], is_first, answer_spans: np.ndarray, inver=None): 64 | out = [] 65 | 66 | offset = 0 67 | if self.doc_start_token and is_first: 68 | out.append(self.DOCUMENT_START_TOKEN) 69 | else: 70 | out.append(self.PARAGRAPH_GROUP) 71 | 72 | if inver is not None: 73 | inv_out = [np.zeros((1, 2), dtype=np.int32)] 74 | else: 75 | inv_out = None 76 | 77 | offset += 1 78 | spans = answer_spans + offset  # shift the spans past the leading indicator token 79 | 80 | out += paragraphs[0] 81 | offset += len(paragraphs[0]) 82 | on_ix = len(paragraphs[0]) 83 | if inv_out is not None: 84 | inv_out.append(inver[:len(paragraphs[0])]) 85 | 86 | for sent in paragraphs[1:]: 87 | if self.remove_cross_answer: 88 | remove = np.logical_and(spans[:, 0] < offset, spans[:, 1] >= offset)  # answers crossing the paragraph boundary 89 | spans = spans[np.logical_not(remove)] 90 | 91 | if self.para_tokens: 92 | spans[spans[:, 0] >= offset, 0] += 1  # make room for the paragraph token inserted below 93 | spans[spans[:, 1] >= offset, 1] += 1 94 | 95 | out.append(self.PARAGRAPH_TOKEN) 96 | if inv_out is not None: 97 | if len(inv_out) == 0 or len(inv_out[-1]) == 0: 98 | inv_out.append(np.zeros((1, 2), dtype=np.int32)) 99 | else: 100 | inv_out.append(np.full((1, 2), inv_out[-1][-1][1], dtype=np.int32))  # indicator maps to an empty character span 101 | offset += 1 102 | 103 | out += sent 104 | offset += len(sent) 105 | if inv_out is not None: 106 | inv_out.append(inver[on_ix:on_ix+len(sent)]) 107 | on_ix += len(sent) 108 | 109 | return out, spans, None if inv_out is None else np.concatenate(inv_out) 110 | 111 | def __setstate__(self, state): 112 | if "state" in state: 113 | state["state"]["doc_start_token"] = True 114 | state["state"]["para_tokens"] = True 115 | else: 116 | if "doc_start_token" not in state: 117 | state["doc_start_token"] = True 118 | if "para_tokens" not in state: 119 | state["para_tokens"] = True 120 | super().__setstate__(state) 121 | 122 | 123 | def check_preprocess(): 124 | data = TriviaQaWebDataset() 125 | merge = MergeParagraphs(400) 126 | questions = data.get_dev() 127 | pre = WithIndicators(False) 128 | remove_cross = WithIndicators(True) 129 | rng = np.random.RandomState(0) 130 | rng.shuffle(questions) 131 | 132 | for q in tqdm(questions[:1000]): 133 | doc = rng.choice(q.all_docs, 1)[0] 134 | text = data.evidence.get_document(doc.doc_id, n_tokens=800) 135 | paras = merge.split_annotated(text, doc.answer_spans) 136 | para = paras[np.random.randint(0, len(paras))] 137 | built = pre.encode_extracted_paragraph(q.question, para) 138 | 139 | expected_text = flatten_iterable(para.text) 140 | if expected_text != [x for x in built.text if x not in pre.special_tokens()]: 141 | raise ValueError("Encoded text does not match the original tokens") 142 | 143 | expected = [expected_text[s:e+1] for s, e in para.answer_spans] 144 | expected = Counter([tuple(x) for x in expected]) 145 | 146 | actual = [tuple(built.text[s:e+1]) for s,e in built.answer_spans] 147 | actual_cleaned = Counter(tuple(z for z in x if z not in pre.special_tokens()) for x in actual) 148 | if actual_cleaned != expected: 149 | raise ValueError("Encoded answer spans do not match the original answers") 150 | 151 | r_built = remove_cross.encode_extracted_paragraph(q.question, para) 152 | rc = Counter(tuple(r_built.text[s:e + 1]) for s, e in r_built.answer_spans) 153 | removed = Counter() 154 | for w in actual: 155 | if all(x not in pre.special_tokens() for x in w): 156 | removed[w] += 1 157 | 158 | if rc != removed: 159 | raise ValueError("Cross-paragraph answers were not removed as expected") 160 | 161 | 162 | def check_preprocess_squad(): 163 | data = SquadCorpus().get_train() 164 | remove_cross = WithIndicators(True) 165 | 166 | for doc in tqdm(data): 167 | for para in doc.paragraphs: 168 | q = para.questions[np.random.randint(0, len(para.questions))] 169 | 170 | text, ans, inv = remove_cross.encode_paragraph(q.words, para.text, para.paragraph_num == 0, 171 | q.answer.answer_spans, para.spans) 172 | if len(inv) != len(text): 173 | raise ValueError("Inverse mapping must have one entry per token") 174 | for i in range(len(inv)-1): 175 | if inv[i, 0] > inv[i+1, 0]: 176 | raise ValueError("Inverse mapping must be monotonic") 177 | for (s1, e1), (s2, e2) in zip(ans, q.answer.answer_spans): 178 | if tuple(inv[s1]) != tuple(para.spans[s2]): 179 | raise ValueError("Answer start does not map back to the original span") 180 | if tuple(inv[e1]) != tuple(para.spans[e2]): 181 | raise ValueError("Answer end does not map back to the original span") 182 | 183 | 184 | if __name__ == "__main__": 185 | check_preprocess_squad()
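A small worked example makes the span bookkeeping above concrete. With the default flags (`doc_start_token=True`, `para_tokens=True`), tracing `encode_paragraph` by hand gives:

```python
import numpy as np
from docqa.text_preprocessor import WithIndicators

# Two paragraphs, with answers "a b" (tokens 0-1) and "c d" (tokens 2-3)
pre = WithIndicators()
paragraphs = [["a", "b"], ["c", "d"]]
answer_spans = np.array([[0, 1], [2, 3]])

text, spans, _ = pre.encode_paragraph([], paragraphs, True, answer_spans)
print(text)   # ['%%DOCUMENT%%', 'a', 'b', '%%PARAGRAPH%%', 'c', 'd']
print(spans)  # [[1 2] [4 5]] -- spans shifted past the inserted indicator tokens
```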
-------------------------------------------------------------------------------- /docqa/triviaqa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/document-qa/2f9fa6878b60ed8a8a31bcf03f802cde292fe48b/docqa/triviaqa/__init__.py -------------------------------------------------------------------------------- /docqa/triviaqa/build_complete_vocab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from os.path import exists 4 | 5 | from docqa.triviaqa.build_span_corpus import TriviaQaOpenDataset 6 | from docqa.triviaqa.evidence_corpus import get_evidence_voc 7 | 8 | """ 9 | Build vocab of all words in the triviaqa dataset, including 10 | all documents and all train questions.
11 | """ 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("output") 17 | parser.add_argument("-m", "--min_count", type=int, default=1) 18 | parser.add_argument("-n", "--n_processes", type=int, default=1) 19 | args = parser.parse_args() 20 | 21 | if exists(args.output): 22 | raise ValueError("Output file %s already exists" % args.output) 23 | 24 | data = TriviaQaOpenDataset() 25 | corpus_voc = get_evidence_voc(data.evidence, args.n_processes) 26 | 27 | print("Adding question voc...") 28 | train = data.get_train() 29 | for q in train: 30 | corpus_voc.update(q.question) 31 | 32 | print("Saving...") 33 | with open(args.output, "w") as f: 34 | for word, c in corpus_voc.items(): 35 | if c >= args.min_count: 36 | f.write(word) 37 | f.write("\n") 38 | 39 | 40 | if __name__ == "__main__": 41 | main()
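Since the script writes one vocabulary word per line, reading the pruned vocabulary back takes only a couple of lines. A sketch; `voc.txt` stands in for whatever output path was passed to the script:

```python
# Load the vocabulary written by build_complete_vocab.py
# ("voc.txt" is a placeholder for the script's output argument).
def load_vocab(path):
    with open(path) as f:
        return {line.rstrip("\n") for line in f}

voc = load_vocab("voc.txt")
print(len(voc), "words kept after min_count pruning")
```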
-------------------------------------------------------------------------------- /docqa/triviaqa/build_span_corpus.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pickle 4 | import unicodedata 5 | from itertools import islice 6 | from os import mkdir 7 | from os.path import join, exists 8 | from typing import List, Optional, Dict 9 | 10 | from docqa.config import CORPUS_DIR, TRIVIA_QA, TRIVIA_QA_UNFILTERED 11 | from docqa.configurable import Configurable 12 | from docqa.data_processing.text_utils import NltkAndPunctTokenizer 13 | from docqa.triviaqa.answer_detection import compute_answer_spans_par, FastNormalizedAnswerDetector 14 | from docqa.triviaqa.evidence_corpus import TriviaQaEvidenceCorpusTxt 15 | from docqa.triviaqa.read_data import iter_trivia_question, TriviaQaQuestion 16 | from docqa.utils import ResourceLoader 17 | 18 | """ 19 | Build span-level training data from the raw trivia-qa inputs; in particular, load the questions 20 | from the json file, annotate each question/doc with the places where the question's answers occur 21 | within the document, and save the results in our format. Assumes the evidence corpus has 22 | already been preprocessed. 23 | """ 24 | 25 | 26 | def build_dataset(name: str, tokenizer, train_files: Dict[str, str], 27 | answer_detector, n_process: int, prune_unmapped_docs=True, 28 | sample=None): 29 | out_dir = join(CORPUS_DIR, "triviaqa", name) 30 | if not exists(out_dir): 31 | mkdir(out_dir) 32 | 33 | file_map = {} # maps document_id -> filename 34 | 35 | for name, filename in train_files.items(): 36 | print("Loading %s questions" % name) 37 | if sample is None: 38 | questions = list(iter_trivia_question(filename, file_map, False)) 39 | else: 40 | if isinstance(sample, int): 41 | questions = list(islice(iter_trivia_question(filename, file_map, False), sample)) 42 | elif isinstance(sample, dict): 43 | questions = list(islice(iter_trivia_question(filename, file_map, False), sample[name])) 44 | else: 45 | raise ValueError("sample must be None, an int, or a dict") 46 | 47 | if prune_unmapped_docs: 48 | for q in questions: 49 | if q.web_docs is not None: 50 | q.web_docs = [x for x in q.web_docs if x.doc_id in file_map] 51 | q.entity_docs = [x for x in q.entity_docs if x.doc_id in file_map] 52 | 53 | print("Adding answers for %s questions" % name) 54 | corpus = TriviaQaEvidenceCorpusTxt(file_map) 55 | questions = compute_answer_spans_par(questions, corpus, tokenizer, answer_detector, n_process) 56 | for q in questions: # Sanity check, we should have answers for everything (even if of size 0) 57 | if q.answer is None: 58 | continue 59 | for doc in q.all_docs: 60 | if doc.doc_id in file_map: 61 | if doc.answer_spans is None: 62 | raise RuntimeError() 63 | 64 | print("Saving %s questions" % name) 65 | with open(join(out_dir, name + ".pkl"), "wb") as f: 66 | pickle.dump(questions, f) 67 | 68 | print("Dumping file mapping") 69 | with open(join(out_dir, "file_map.json"), "w") as f: 70 | json.dump(file_map, f) 71 | 72 | print("Complete") 73 | 74 | 75 | class TriviaQaSpanCorpus(Configurable): 76 | def __init__(self, corpus_name): 77 | self.corpus_name = corpus_name 78 | self.dir = join(CORPUS_DIR, "triviaqa", corpus_name) 79 | with open(join(self.dir, "file_map.json"), "r") as f: 80 | file_map = json.load(f) 81 | for k, v in file_map.items(): 82 | file_map[k] = unicodedata.normalize("NFD", v) 83 | self.evidence = TriviaQaEvidenceCorpusTxt(file_map) 84 | 85 | def get_train(self) -> List[TriviaQaQuestion]: 86 | with open(join(self.dir, "train.pkl"), "rb") as f: 87 | return pickle.load(f) 88 | 89 | def get_dev(self) -> List[TriviaQaQuestion]: 90 | with open(join(self.dir, "dev.pkl"), "rb") as f: 91 | return pickle.load(f) 92 | 93 | def get_test(self) -> List[TriviaQaQuestion]: 94 | with open(join(self.dir, "test.pkl"), "rb") as f: 95 | return pickle.load(f) 96 | 97 | def get_verified(self) -> Optional[List[TriviaQaQuestion]]: 98 | verified_dir = join(self.dir, "verified.pkl") 99 | if not exists(verified_dir): 100 | return None 101 | with open(verified_dir, "rb") as f: 102 | return pickle.load(f) 103 | 104 | def get_resource_loader(self): 105 | return ResourceLoader() 106 | 107 | @property 108 | def name(self): 109 | return self.corpus_name 110 | 111 | 112 | class TriviaQaWebDataset(TriviaQaSpanCorpus): 113 | def __init__(self): 114 | super().__init__("web") 115 | 116 | 117 | class TriviaQaWikiDataset(TriviaQaSpanCorpus): 118 | def __init__(self): 119 | super().__init__("wiki") 120 | 121 | 122 | class TriviaQaOpenDataset(TriviaQaSpanCorpus): 123 | def __init__(self): 124 | super().__init__("web-open") 125 | 126 | 127 | class TriviaQaSampleWebDataset(TriviaQaSpanCorpus): 128 | def __init__(self): 129 |
super().__init__("web-sample") 130 | 131 | 132 | def build_wiki_corpus(n_processes): 133 | build_dataset("wiki", NltkAndPunctTokenizer(), 134 | dict( 135 | verified=join(TRIVIA_QA, "qa", "verified-wikipedia-dev.json"), 136 | dev=join(TRIVIA_QA, "qa", "wikipedia-dev.json"), 137 | train=join(TRIVIA_QA, "qa", "wikipedia-train.json"), 138 | test=join(TRIVIA_QA, "qa", "wikipedia-test-without-answers.json") 139 | ), 140 | FastNormalizedAnswerDetector(), n_processes) 141 | 142 | 143 | def build_web_corpus(n_processes): 144 | build_dataset("web", NltkAndPunctTokenizer(), 145 | dict( 146 | verified=join(TRIVIA_QA, "qa", "verified-web-dev.json"), 147 | dev=join(TRIVIA_QA, "qa", "web-dev.json"), 148 | train=join(TRIVIA_QA, "qa", "web-train.json"), 149 | test=join(TRIVIA_QA, "qa", "web-test-without-answers.json") 150 | ), 151 | FastNormalizedAnswerDetector(), n_processes) 152 | 153 | 154 | def build_sample_corpus(n_processes): 155 | build_dataset("web-sample", NltkAndPunctTokenizer(), 156 | dict( 157 | dev=join(TRIVIA_QA, "qa", "web-dev.json"), 158 | train=join(TRIVIA_QA, "qa", "web-train.json"), 159 | ), 160 | FastNormalizedAnswerDetector(), n_processes, sample=1000) 161 | 162 | 163 | def build_unfiltered_corpus(n_processes): 164 | build_dataset("web-open", NltkAndPunctTokenizer(), 165 | dict( 166 | dev=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-dev.json"), 167 | train=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-train.json"), 168 | test=join(TRIVIA_QA_UNFILTERED, "unfiltered-web-test-without-answers.json") 169 | ), 170 | answer_detector=FastNormalizedAnswerDetector(), 171 | n_process=n_processes) 172 | 173 | 174 | def main(): 175 | parser = argparse.ArgumentParser("Pre-process TriviaQA data") 176 | parser.add_argument("corpus", choices=["web", "wiki", "web-open"]) 177 | parser.add_argument("-n", "--n_processes", type=int, default=1, help="Number of processes to use") 178 | args = parser.parse_args() 179 | if args.corpus == "web": 180 | build_web_corpus(args.n_processes) 181 | elif args.corpus == "wiki": 182 | build_wiki_corpus(args.n_processes) 183 | elif args.corpus == "web-open": 184 | build_unfiltered_corpus(args.n_processes) 185 | else: 186 | raise RuntimeError() 187 | 188 | 189 | if __name__ == "__main__": 190 | main() 191 | 192 | 193 | -------------------------------------------------------------------------------- /docqa/triviaqa/trivia_qa_eval.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.0 of the TriviaQA dataset. 2 | Extended from the evaluation script for v1.1 of the SQuAD dataset.
3 | 4 | (Additionally condensed into a single file) 5 | """ 6 | from __future__ import print_function 7 | 8 | import json 9 | from collections import Counter 10 | import string 11 | import re 12 | import sys 13 | import argparse 14 | 15 | import unicodedata 16 | from tqdm import tqdm 17 | 18 | 19 | def normalize_answer(s): 20 | """Lower text and remove punctuation, articles and extra whitespace.""" 21 | 22 | def remove_articles(text): 23 | return re.sub(r'\b(a|an|the)\b', ' ', text) 24 | 25 | def white_space_fix(text): 26 | return ' '.join(text.split()) 27 | 28 | def handle_punc(text): 29 | exclude = set(string.punctuation + "".join([u"‘", u"’", u"´", u"`"])) 30 | return ''.join(ch if ch not in exclude else ' ' for ch in text) 31 | 32 | def lower(text): 33 | return text.lower() 34 | 35 | def replace_underscore(text): 36 | return text.replace('_', ' ') 37 | 38 | return white_space_fix(remove_articles(handle_punc(lower(replace_underscore(s))))).strip() 39 | 40 | 41 | def f1_score(prediction, ground_truth): 42 | prediction_tokens = normalize_answer(prediction).split() 43 | ground_truth_tokens = normalize_answer(ground_truth).split() 44 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 45 | num_same = sum(common.values()) 46 | if num_same == 0: 47 | return 0 48 | precision = 1.0 * num_same / len(prediction_tokens) 49 | recall = 1.0 * num_same / len(ground_truth_tokens) 50 | f1 = (2 * precision * recall) / (precision + recall) 51 | return f1 52 | 53 | 54 | def exact_match_score(prediction, ground_truth): 55 | return normalize_answer(prediction) == normalize_answer(ground_truth) 56 | 57 | 58 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 59 | scores_for_ground_truths = [] 60 | for ground_truth in ground_truths: 61 | score = metric_fn(prediction, ground_truth) 62 | scores_for_ground_truths.append(score) 63 | return max(scores_for_ground_truths) 64 | 65 | 66 | def get_ground_truths(answer): 67 | return answer['NormalizedAliases'] + [normalize_answer(ans) for ans in answer.get('HumanAnswers', [])] 68 | 69 | 70 | def get_file_contents(filename, encoding='utf-8'): 71 | with open(filename, encoding=encoding) as f: 72 | content = f.read() 73 | return content 74 | 75 | 76 | def read_json(filename, encoding='utf-8'): 77 | contents = get_file_contents(filename, encoding=encoding) 78 | return json.loads(contents) 79 | 80 | 81 | def is_exact_match(answer_object, prediction): 82 | ground_truths = get_ground_truths(answer_object) 83 | for ground_truth in ground_truths: 84 | if exact_match_score(prediction, ground_truth): 85 | return True 86 | return False 87 | 88 | 89 | def has_exact_match(ground_truths, candidates): 90 | for ground_truth in ground_truths: 91 | if ground_truth in candidates: 92 | return True 93 | return False 94 | 95 | 96 | def get_key_to_ground_truth(data): 97 | if data['Domain'] == 'Wikipedia': 98 | return {datum['QuestionId']: datum['Answer'] for datum in data['Data']} 99 | else: 100 | return get_qd_to_answer(data) 101 | 102 | 103 | def get_question_doc_string(qid, doc_name): 104 | return '{}--{}'.format(qid, unicodedata.normalize("NFD", doc_name).lower()) 105 | 106 | 107 | def get_qd_to_answer(data): 108 | key_to_answer = {} 109 | for datum in data['Data']: 110 | for page in datum.get('EntityPages', []) + datum.get('SearchResults', []): 111 | qd_tuple = get_question_doc_string(datum['QuestionId'], page['Filename']) 112 | key_to_answer[qd_tuple] = datum['Answer'] 113 | return key_to_answer 114 | 115 | 116 | def evaluate_triviaqa(ground_truth, 
predicted_answers, qid_list=None, mute=False): 117 | f1 = exact_match = common = 0 118 | if qid_list is None: 119 | qid_list = ground_truth.keys() 120 | for qid in tqdm(qid_list, ncols=80): 121 | if qid not in predicted_answers: 122 | if not mute: 123 | message = 'Missed question {} will receive score 0.'.format(qid) 124 | print(message, file=sys.stderr) 125 | continue 126 | if qid not in ground_truth: 127 | if not mute: 128 | message = 'Irrelevant question {} will receive score 0.'.format(qid) 129 | print(message, file=sys.stderr) 130 | continue 131 | common += 1 132 | prediction = predicted_answers[qid] 133 | ground_truths = get_ground_truths(ground_truth[qid]) 134 | em_for_this_question = metric_max_over_ground_truths( 135 | exact_match_score, prediction, ground_truths) 136 | if em_for_this_question == 0 and not mute: 137 | print("em=0:", prediction, ground_truths) 138 | exact_match += em_for_this_question 139 | f1_for_this_question = metric_max_over_ground_truths( 140 | f1_score, prediction, ground_truths) 141 | f1 += f1_for_this_question 142 | 143 | exact_match = 100.0 * exact_match / len(qid_list) 144 | f1 = 100.0 * f1 / len(qid_list) 145 | 146 | return {'exact_match': exact_match, 'f1': f1, 'common': common, 'denominator': len(qid_list), 147 | 'pred_len': len(predicted_answers), 'gold_len': len(ground_truth)} 148 | 149 | 150 | def read_clean_part(datum): 151 | for key in ['EntityPages', 'SearchResults']: 152 | new_page_list = [] 153 | for page in datum.get(key, []): 154 | if page['DocPartOfVerifiedEval']: 155 | new_page_list.append(page) 156 | datum[key] = new_page_list 157 | assert len(datum['EntityPages']) + len(datum['SearchResults']) > 0 158 | return datum 159 | 160 | 161 | def read_triviaqa_data(qajson): 162 | data = read_json(qajson) 163 | # read only documents and questions that are a part of clean data set 164 | if data['VerifiedEval']: 165 | clean_data = [] 166 | for datum in data['Data']: 167 | if datum['QuestionPartOfVerifiedEval']: 168 | if data['Domain'] == 'Web': 169 | datum = read_clean_part(datum) 170 | clean_data.append(datum) 171 | data['Data'] = clean_data 172 | return data 173 | 174 | 175 | def get_args(): 176 | parser = argparse.ArgumentParser(description='Evaluation for TriviaQA') 177 | parser.add_argument('--dataset_file', help='Dataset file') 178 | parser.add_argument('--prediction_file', help='Prediction File') 179 | args = parser.parse_args() 180 | return args 181 | 182 | 183 | if __name__ == '__main__': 184 | expected_version = 1.0 185 | args = get_args() 186 | 187 | dataset_json = read_triviaqa_data(args.dataset_file) 188 | if dataset_json['Version'] != expected_version: 189 | print('Evaluation expects v-{}, but got dataset with v-{}'.format(expected_version, dataset_json['Version']), 190 | file=sys.stderr) 191 | key_to_ground_truth = get_key_to_ground_truth(dataset_json) 192 | predictions = read_json(args.prediction_file) 193 | eval_dict = evaluate_triviaqa(key_to_ground_truth, predictions, mute=True) 194 | print(eval_dict)
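For a quick smoke test, `evaluate_triviaqa` can be called directly on hand-built inputs; `get_ground_truths` only requires the `NormalizedAliases` field of each answer object. A minimal sketch:

```python
from docqa.triviaqa.trivia_qa_eval import evaluate_triviaqa

ground_truth = {"q1": {"NormalizedAliases": ["paris", "city of paris"]}}
predictions = {"q1": "Paris"}

print(evaluate_triviaqa(ground_truth, predictions, mute=True))
# -> {'exact_match': 100.0, 'f1': 100.0, 'common': 1, 'denominator': 1,
#     'pred_len': 1, 'gold_len': 1}
```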
-------------------------------------------------------------------------------- /docqa/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | from os.path import join 4 | from typing import List, TypeVar, Iterable 5 | 6 | from docqa.data_processing.word_vectors import load_word_vectors 7 | 8 | 9 | class ResourceLoader(object): 10 | """ 11 | Abstraction for models that need access to external resources for setup, currently just for 12 | word-vectors. 13 | """ 14 | 15 | def __init__(self, load_vec_fn=load_word_vectors): 16 | self.load_vec_fn = load_vec_fn 17 | 18 | def load_word_vec(self, vec_name, voc=None): 19 | return self.load_vec_fn(vec_name, voc) 20 | 21 | 22 | class LoadFromPath(object): 23 | def __init__(self, path): 24 | self.path = path 25 | 26 | def load_word_vec(self, vec_name, voc=None): 27 | return load_word_vectors(join(self.path, vec_name), voc, True) 28 | 29 | 30 | class CachingResourceLoader(ResourceLoader): 31 | 32 | def __init__(self, load_vec_fn=load_word_vectors): 33 | super().__init__(load_vec_fn) 34 | self.word_vec = {} 35 | 36 | def load_word_vec(self, vec_name, voc=None): 37 | if vec_name not in self.word_vec: 38 | self.word_vec[vec_name] = super().load_word_vec(vec_name) 39 | return self.word_vec[vec_name] 40 | 41 | 42 | def print_table(table: List[List[str]]): 43 | """ Print the lists with evenly spaced columns """ 44 | 45 | # print while padding each column to the max column length 46 | col_lens = [0] * len(table[0]) 47 | for row in table: 48 | for i,cell in enumerate(row): 49 | col_lens[i] = max(len(cell), col_lens[i]) 50 | 51 | formats = ["{0:<%d}" % x for x in col_lens] 52 | for row in table: 53 | print(" ".join(formats[i].format(row[i]) for i in range(len(row)))) 54 | 55 | T = TypeVar('T') 56 | 57 | 58 | def transpose_lists(lsts: List[List[T]]) -> List[List[T]]: 59 | return [list(i) for i in zip(*lsts)] 60 | 61 | 62 | def max_or_none(a, b): 63 | if a is None or b is None: 64 | return None 65 | return max(a, b) 66 | 67 | 68 | def flatten_iterable(listoflists: Iterable[Iterable[T]]) -> List[T]: 69 | return [item for sublist in listoflists for item in sublist] 70 | 71 | 72 | def split(lst: List[T], n_groups) -> List[List[T]]: 73 | """ partition `lst` into `n_groups` that are as evenly sized as possible """ 74 | per_group = len(lst) // n_groups 75 | remainder = len(lst) % n_groups 76 | groups = [] 77 | ix = 0 78 | for _ in range(n_groups): 79 | group_size = per_group 80 | if remainder > 0: 81 | remainder -= 1 82 | group_size += 1 83 | groups.append(lst[ix:ix + group_size]) 84 | ix += group_size 85 | return groups 86 | 87 | 88 | def group(lst: List[T], max_group_size) -> List[List[T]]: 89 | """ partition `lst` into the minimal number of groups that are as evenly sized 90 | as possible and are at most `max_group_size` in size """ 91 | if max_group_size is None: 92 | return [lst] 93 | n_groups = (len(lst)+max_group_size-1) // max_group_size 94 | per_group = len(lst) // n_groups 95 | remainder = len(lst) % n_groups 96 | groups = [] 97 | ix = 0 98 | for _ in range(n_groups): 99 | group_size = per_group 100 | if remainder > 0: 101 | remainder -= 1 102 | group_size += 1 103 | groups.append(lst[ix:ix + group_size]) 104 | ix += group_size 105 | return groups 106 | 107 | 108 | def get_output_name_from_cli(): 109 | parser = argparse.ArgumentParser(description='') 110 | parser.add_argument('--name', '-n', nargs=1, help='name of output to examine') 111 | 112 | args = parser.parse_args() 113 | if args.name: 114 | out = join(args.name[0] + "-" + datetime.now().strftime("%m%d-%H%M%S")) 115 | print("Starting run on: " + out) 116 | else: 117 | out = "out/run-" + datetime.now().strftime("%m%d-%H%M%S") 118 | print("Starting run on: " + out) 119 | return out 120 | 121 | -------------------------------------------------------------------------------- /requirements-exact.txt: -------------------------------------------------------------------------------- 1 | nltk==3.2.4 2 | numpy==1.13.1 3 | pandas==0.20.1 4 |
scikit-learn==0.18.1 5 | scipy==0.19.0 6 | tqdm==4.14.0 7 | ujson==1.35 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk 2 | tqdm 3 | ujson 4 | # We require >=0.18 to support the "expanding" function 5 | pandas>=0.18.0 6 | scikit-learn 7 | scipy 8 | --------------------------------------------------------------------------------
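As a closing sanity check of the partition helpers defined in docqa/utils.py above, a short interactive sketch:

```python
from docqa.utils import flatten_iterable, split, group

print(split(list(range(10)), 3))        # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(group(list(range(10)), 4))        # 3 groups of at most 4: same partition as above
print(flatten_iterable([[1, 2], [3]]))  # [1, 2, 3]
```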