├── squad ├── figures │ └── piqa_eval.jpg ├── static │ ├── files │ │ ├── pika.png │ │ ├── style.css │ │ ├── popper.min.js │ │ └── bootstrap.min.js │ └── index.html ├── requirements.txt ├── base │ ├── __init__.py │ ├── model.py │ ├── processor.py │ ├── argument_parser.py │ └── file_interface.py ├── baseline │ ├── __init__.py │ ├── argument_parser.py │ ├── file_interface.py │ ├── model.py │ └── processor.py ├── download.sh ├── scripts │ ├── benchmark.py │ └── tfidf.py ├── split.py ├── codalab.sh ├── evaluate.py ├── merge.py ├── piqa_evaluate.py ├── README.md └── main.py ├── README.md └── LICENSE /squad/figures/piqa_eval.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seominjoon/piqa/HEAD/squad/figures/piqa_eval.jpg -------------------------------------------------------------------------------- /squad/static/files/pika.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seominjoon/piqa/HEAD/squad/static/files/pika.png -------------------------------------------------------------------------------- /squad/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==0.4.1 2 | numpy==1.15.2 3 | scipy==1.1.0 4 | nltk==3.3 5 | allennlp==0.6.1 6 | 7 | tqdm 8 | gensim 9 | 10 | faiss 11 | 12 | tornado 13 | flask -------------------------------------------------------------------------------- /squad/base/__init__.py: -------------------------------------------------------------------------------- 1 | from base.argument_parser import ArgumentParser 2 | from base.file_interface import FileInterface 3 | from base.processor import Processor, Sampler 4 | from base.model import Model, Loss 5 | -------------------------------------------------------------------------------- /squad/baseline/__init__.py: -------------------------------------------------------------------------------- 1 | from baseline.argument_parser import ArgumentParser 2 | from baseline.file_interface import FileInterface 3 | from baseline.processor import Processor, Sampler 4 | from baseline.model import Model, Loss 5 | -------------------------------------------------------------------------------- /squad/static/files/style.css: -------------------------------------------------------------------------------- 1 | html { position: relative; min-height: 100%; } 2 | body { margin-bottom: 60px; } 3 | .footer { position: absolute; bottom: 0; width: 100%; height: 40px; line-height: 15px; background-color: #f5f5f5; padding-top: 5px; font-size: 12px; text-align: center;} 4 | label, footer { user-select: none; } 5 | .list-group-item:first-of-type { background-color: #e0f2f1; color: #00695c; } -------------------------------------------------------------------------------- /squad/base/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | 3 | from torch import nn 4 | 5 | 6 | class Model(nn.Module, metaclass=ABCMeta): 7 | def forward(self, *input): 8 | """ 9 | :param input: 10 | :return: a dict of tensors 11 | """ 12 | raise NotImplementedError() 13 | 14 | def init(self, metadata): 15 | raise NotImplementedError() 16 | 17 | def get_context(self, *args, **kwargs): 18 | raise NotImplementedError() 19 | 20 | def get_question(self, *args, **kwargs): 21 | raise NotImplementedError() 22 | 23 | 24 | class Loss(nn.Module, metaclass=ABCMeta): 25 | def forward(self, *input): 26 | """ 27 | :param input: 28 | 
:return: a scalar tensor for the loss 29 | """ 30 | raise NotImplementedError() 31 | -------------------------------------------------------------------------------- /squad/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Install requirements; assuming Python 3.6 3 | pip install nltk==3.3 numpy==1.15.2 scipy==1.1.0 torch==0.4.1 allennlp==0.6.1 tqdm gensim 4 | 5 | DATA_DIR=$HOME/data/ 6 | mkdir $DATA_DIR 7 | 8 | # Download GloVe 9 | GLOVE_DIR=$DATA_DIR/glove 10 | mkdir $GLOVE_DIR 11 | wget http://nlp.stanford.edu/data/glove.6B.zip -O $GLOVE_DIR/glove.6B.zip 12 | unzip $GLOVE_DIR/glove.6B.zip -d $GLOVE_DIR 13 | rm $GLOVE_DIR/glove.6B.zip 14 | 15 | # Download ELMo 16 | ELMO_DIR=$DATA_DIR/elmo 17 | mkdir $ELMO_DIR 18 | wget https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json -O $ELMO_DIR/options.json 19 | wget https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5 -O $ELMO_DIR/weights.hdf5 20 | 21 | -------------------------------------------------------------------------------- /squad/scripts/benchmark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import time 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--num_iters', default=10, type=int) 8 | parser.add_argument('--num_vecs', default=10000, type=int) 9 | parser.add_argument('--dim', default=1024, type=int) 10 | args = parser.parse_args() 11 | 12 | query = np.random.randn(args.dim) 13 | 14 | # numpy experiment 15 | docs = [np.random.randn(args.num_vecs, args.dim) for _ in range(args.num_iters)] 16 | 17 | start_time = time.time() 18 | for i, doc in enumerate(docs): 19 | ans = np.argmax(np.matmul(doc, np.expand_dims(query, -1)), 0) 20 | duration = time.time() - start_time 21 | speed = args.num_vecs * args.num_iters / duration 22 | print('numpy: %.3f ms per %d vecs of %dD, or %d vecs/s' % (duration * 1000 / args.num_iters, args.num_vecs, args.dim, speed)) 23 | 24 | 25 | -------------------------------------------------------------------------------- /squad/split.py: -------------------------------------------------------------------------------- 1 | """Official split script for PI-SQuAD v0.1""" 2 | 3 | 4 | import argparse 5 | import json 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='Official split script for PI-SQuAD v0.1') 9 | parser.add_argument('data_path', help='Dataset file path') 10 | parser.add_argument('context_path', help='Output path for context-only dataset') 11 | parser.add_argument('question_path', help='Output path for question-only dataset') 12 | args = parser.parse_args() 13 | 14 | with open(args.data_path, 'r') as fp: 15 | context = json.load(fp) 16 | with open(args.data_path, 'r') as fp: 17 | question = json.load(fp) 18 | 19 | for article in context['data']: 20 | for para in article['paragraphs']: 21 | del para['qas'] 22 | 23 | for article in question['data']: 24 | for para in article['paragraphs']: 25 | del para['context'] 26 | for qa in para['qas']: 27 | if 'answers' in qa: 28 | del qa['answers'] 29 | 30 | with open(args.context_path, 'w') as fp: 31 | json.dump(context, fp) 32 | 33 | with open(args.question_path, 'w') as fp: 34 | json.dump(question, fp) 35 | 
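# Example usage (same call as in codalab.sh; file names are illustrative):
#
#   python split.py dev-v1.1.json dev-v1.1-context.json dev-v1.1-question.json
#
# The context-only output drops each paragraph's 'qas'; the question-only output
# drops each paragraph's 'context' and any 'answers'.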
-------------------------------------------------------------------------------- /squad/base/processor.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | 3 | import torch.utils.data 4 | 5 | 6 | class Processor(metaclass=ABCMeta): 7 | def construct(self, examples, metadata): 8 | raise NotImplementedError() 9 | 10 | def state_dict(self): 11 | raise NotImplementedError() 12 | 13 | def load_state_dict(self, in_): 14 | raise NotImplementedError() 15 | 16 | def preprocess(self, example): 17 | raise NotImplementedError() 18 | 19 | def postprocess(self, example, model_output): 20 | raise NotImplementedError() 21 | 22 | def postprocess_batch(self, dataset, model_input, model_output): 23 | raise NotImplementedError() 24 | 25 | def postprocess_context(self, example, context_output): 26 | raise NotImplementedError() 27 | 28 | def postprocess_context_batch(self, dataset, model_input, context_output): 29 | raise NotImplementedError() 30 | 31 | def postprocess_question(self, example, question_output): 32 | raise NotImplementedError() 33 | 34 | def postprocess_question_batch(self, dataset, model_input, question_output): 35 | raise NotImplementedError() 36 | 37 | def collate(self, examples): 38 | raise NotImplementedError() 39 | 40 | def process_metadata(self, metadata): 41 | raise NotImplementedError() 42 | 43 | def get_dump(self, dataset, input_, output, results): 44 | raise NotImplementedError() 45 | 46 | 47 | class Sampler(torch.utils.data.Sampler, metaclass=ABCMeta): 48 | def __init__(self, dataset, data_type, **kwargs): 49 | self.dataset = dataset 50 | self.data_type = data_type 51 | -------------------------------------------------------------------------------- /squad/codalab.sh: -------------------------------------------------------------------------------- 1 | cl add bundle squad-data//dev-v1.1.json . 
2 | cl run piqa-master:piqa-master dev-v1.1.json:dev-v1.1.json "python piqa-master/squad/split.py dev-v1.1.json dev-v1.1-context.json dev-v1.1-question.json" -n run-split 3 | cl make run-split/dev-v1.1-context.json -n dev-v1.1-context.json 4 | cl make run-split/dev-v1.1-question.json -n dev-v1.1-question.json 5 | 6 | # For LSTM Model 7 | # cl run dev-v1.1-context.json:dev-v1.1-context.json piqa-master:piqa-master model.pt:model.pt "python piqa-master/squad/main.py baseline --cuda --static_dir /static --load_dir model.pt --mode embed_context --test_path dev-v1.1-context.json --context_emb_dir context" -n run-context --request-docker-image minjoon/research:180908 --request-memory 8g --request-gpus 1 8 | # cl run dev-v1.1-question.json:dev-v1.1-question.json piqa-master:piqa-master model.pt:model.pt "python piqa-master/squad/main.py baseline --cuda --batch_size 256 --static_dir /static --load_dir model.pt --mode embed_question --test_path dev-v1.1-question.json --question_emb_dir question" -n run-question --request-docker-image minjoon/research:180908 --request-memory 8g --request-gpus 1 9 | 10 | # For LSTM+SA+ELMo Model 11 | cl run dev-v1.1-context.json:dev-v1.1-context.json piqa-master:piqa-master model.pt:model.pt "python piqa-master/squad/main.py baseline --elmo --num_heads 2 --batch_size 32 --cuda --static_dir /static --load_dir model.pt --mode embed_context --test_path dev-v1.1-context.json --context_emb_dir context" -n run-context --request-docker-image minjoon/research:180908 --request-disk 4g --request-memory 8g --request-gpus 1 12 | cl run dev-v1.1-question.json:dev-v1.1-question.json piqa-master:piqa-master model.pt:model.pt "python piqa-master/squad/main.py baseline --elmo --num_heads 2 --batch_size 128 --cuda --static_dir /static --load_dir model.pt --mode embed_question --test_path dev-v1.1-question.json --question_emb_dir question" -n run-question --request-docker-image minjoon/research:180908 --request-disk 4g --request-memory 8g --request-gpus 1 13 | 14 | cl run piqa-master:piqa-master dev-v1.1.json:dev-v1.1.json run-context:run-context run-question:run-question "python piqa-master/squad/merge.py dev-v1.1.json run-context/context run-question/question pred.json" -n run-merge --request-docker-image minjoon/research:180908 --request-disk 4g 15 | cl make run-merge/pred.json -n predictions-ELMo 16 | 17 | cl macro squad-utils/dev-evaluate-v1.1 predictions-ELMo 18 | cl edit predictions-ELMo --tags squad-test-submit pi-squad-test-submit -------------------------------------------------------------------------------- /squad/baseline/argument_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base 3 | 4 | 5 | class ArgumentParser(base.ArgumentParser): 6 | def __init__(self, description='baseline', **kwargs): 7 | super(ArgumentParser, self).__init__(description=description) 8 | 9 | def add_arguments(self): 10 | super().add_arguments() 11 | 12 | home = os.path.expanduser('~') 13 | 14 | # Metadata paths 15 | self.add_argument('--static_dir', type=str, default=os.path.join(home, 'data')) 16 | self.add_argument('--glove_dir', type=str, default=None, help='location of GloVe') 17 | self.add_argument('--elmo_options_file', type=str, default=None) 18 | self.add_argument('--elmo_weights_file', type=str, default=None) 19 | 20 | # Model arguments 21 | self.add_argument('--word_vocab_size', type=int, default=10000) 22 | self.add_argument('--char_vocab_size', type=int, default=100) 23 | self.add_argument('--glove_vocab_size', 
type=int, default=400002) 24 | self.add_argument('--glove_size', type=int, default=200) 25 | self.add_argument('--hidden_size', type=int, default=128) 26 | self.add_argument('--batch_size', type=int, default=64, help='batch size') 27 | self.add_argument('--elmo', default=False, action='store_true') 28 | self.add_argument('--num_heads', type=int, default=1) 29 | self.add_argument('--max_pool', default=False, action='store_true') 30 | self.add_argument('--num_layers', type=int, default=1) 31 | 32 | # Training arguments. Only valid during training 33 | self.add_argument('--dropout', type=float, default=0.2) 34 | self.add_argument('--max_context_size', type=int, default=256) 35 | self.add_argument('--max_question_size', type=int, default=32) 36 | self.add_argument('--no_bucket', default=False, action='store_true') 37 | self.add_argument('--no_shuffle', default=False, action='store_true') 38 | 39 | # Other arguments 40 | self.add_argument('--glove_cuda', default=False, action='store_true') 41 | 42 | def parse_args(self, **kwargs): 43 | args = super().parse_args() 44 | 45 | if args.draft: 46 | args.glove_vocab_size = 102 47 | 48 | if args.glove_dir is None: 49 | args.glove_dir = os.path.join(args.static_dir, 'glove') 50 | if args.elmo_options_file is None: 51 | args.elmo_options_file = os.path.join(args.static_dir, 'elmo', 'options.json') 52 | if args.elmo_weights_file is None: 53 | args.elmo_weights_file = os.path.join(args.static_dir, 'elmo', 'weights.hdf5') 54 | 55 | args.embed_size = args.glove_size 56 | args.glove_cpu = not args.glove_cuda 57 | args.bucket = not args.no_bucket 58 | args.shuffle = not args.no_shuffle 59 | return args 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Phrase-Indexed Question Answering (PIQA) 2 | - This is the official github repository for [Phrase-Indexed Question Answering: A New Challenge for Scalable Document Comprehension][paper] (EMNLP 2018). 3 | - Much of the work and code is heavily influenced by our former [project][mipsqa] at Google AI. 4 | - For inquiries, please contact [Minjoon Seo][minjoon] ([@seominjoon][minjoon-github]). 5 | - For citation, please use: 6 | ``` 7 | @inproceedings{seo2018phrase, 8 | title={Phrase-Indexed Question Answering: A New Challenge for Scalable Document Comprehension}, 9 | author={Seo, Minjoon and Kwiatkowski, Tom and Parikh, Ankur P and Farhadi, Ali and Hajishirzi, Hannaneh}, 10 | booktitle={EMNLP}, 11 | year={2018} 12 | } 13 | ``` 14 | 15 | ## Introduction 16 | We will assume that you have read the [paper][paper], though we will try to recap it here. PIQA challenge is about approaching (existing) extractive question answering tasks via phrase retrieval mechanism (we plan to hold the challenge for several extractive QA datasets in near future, though we currently only support PIQA for [SQuAD 1.1][squad].). This means we need: 17 | 18 | 1. **document encoder**: enumerates a list of (phrase, vector) pairs from the document, 19 | 2. **question encoder**: maps each question to the same vector space, and 20 | 3. **retrieval**: retrieves the (phrasal) answer to the question by performing nearest neighbor search on the list. 
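To make the pipeline concrete, step 3 boils down to a nearest-neighbor search over the phrase vectors, as in the short sketch below. It mirrors the dense inner-product search performed by `squad/merge.py`; the phrases and vectors here are toy placeholders rather than real encoder outputs.

```python
import numpy as np

# Toy stand-ins for encoder outputs (illustrative only, not this repository's API).
phrases = ['Denver Broncos', 'gold', 'February 7, 2016']  # phrases enumerated by the document encoder
phrase_vecs = np.random.randn(len(phrases), 128)          # [N, d] phrase vectors
question_vec = np.random.randn(128)                       # [d] vector from the question encoder

# Retrieval: pick the phrase whose vector has the largest inner product with the question vector.
scores = phrase_vecs @ question_vec                       # shape [N]
print(phrases[int(scores.argmax())])                      # predicted (phrasal) answer
```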
21 |
22 | While the challenge shares some similarities with document retrieval, a classic problem in the information retrieval literature, a key difference is that the phrase representations need to be *context-based*, which is more challenging than obtaining an embedding from the phrase's *content* alone.
23 |
24 | An important aspect of the challenge is the constraint of *independence* between the **document encoder** and the **question encoder**. As we have noted in our paper, most existing models heavily rely on question-dependent representations of the context document. In PIQA, by contrast, phrase representations need to be completely *independent* of the input question. Not only does this make the challenge quite difficult, but it also means that state-of-the-art models cannot be directly used for the task. Hence we have proposed a few reasonable baseline models as a starting point, which can be found in this repository.
25 |
26 | Note that it is also not straightforward to strictly enforce the constraint on an evaluation platform such as CodaLab. For instance, the current SQuAD 1.1 evaluator simply provides the test dataset (both context and question) without answers and asks the model to output predictions, which are then compared against the answers. This setup is not great for PIQA because we cannot know whether the submitted model abides by the independence constraint. To resolve this issue, a submission should consist of the two encoders with explicit independence, and the retrieval is performed on the evaluator side. While this is not as convenient as a vanilla SQuAD submission, it strictly enforces the independence constraint.
27 |
28 | ## Tasks
29 |
30 | - [Phrase-Indexed SQuAD][pi-squad] (PI-SQuAD)
31 |
32 | [paper]: https://arxiv.org/abs/1804.07726
33 | [minjoon]: https://seominjoon.github.io
34 | [minjoon-github]: https://github.com/seominjoon
35 | [jhyuklee-github]: https://github.com/jhyuklee
36 | [squad-train]: https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
37 | [squad-dev]: https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
38 | [squad-context]: https://nlp.cs.washington.edu/piqa/squad/dev-v1.1-context.json
39 | [squad-question]: https://nlp.cs.washington.edu/piqa/squad/dev-v1.1-question.json
40 | [elmo]: https://allennlp.org/elmo
41 | [squad]: https://stanford-qa.com
42 | [mipsqa]: https://github.com/google/mipsqa
43 | [pi-squad]: squad/
44 |
--------------------------------------------------------------------------------
/squad/evaluate.py:
--------------------------------------------------------------------------------
1 |
2 | """ Identical to SQuAD v1.1 evaluation script"""
3 | from __future__ import print_function
4 | from collections import Counter
5 | import string
6 | import re
7 | import argparse
8 | import json
9 | import sys
10 |
11 |
12 | def normalize_answer(s):
13 |     """Lower text and remove punctuation, articles and extra whitespace."""
14 |     def remove_articles(text):
15 |         return re.sub(r'\b(a|an|the)\b', ' ', text)
16 |
17 |     def white_space_fix(text):
18 |         return ' '.join(text.split())
19 |
20 |     def remove_punc(text):
21 |         exclude = set(string.punctuation)
22 |         return ''.join(ch for ch in text if ch not in exclude)
23 |
24 |     def lower(text):
25 |         return text.lower()
26 |
27 |     return white_space_fix(remove_articles(remove_punc(lower(s))))
28 |
29 |
30 | def f1_score(prediction, ground_truth):
31 |     prediction_tokens = normalize_answer(prediction).split()
32 |     ground_truth_tokens = normalize_answer(ground_truth).split()
33 |     common = Counter(prediction_tokens) &
Counter(ground_truth_tokens) 34 | num_same = sum(common.values()) 35 | if num_same == 0: 36 | return 0 37 | precision = 1.0 * num_same / len(prediction_tokens) 38 | recall = 1.0 * num_same / len(ground_truth_tokens) 39 | f1 = (2 * precision * recall) / (precision + recall) 40 | return f1 41 | 42 | 43 | def exact_match_score(prediction, ground_truth): 44 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 45 | 46 | 47 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 48 | scores_for_ground_truths = [] 49 | for ground_truth in ground_truths: 50 | score = metric_fn(prediction, ground_truth) 51 | scores_for_ground_truths.append(score) 52 | return max(scores_for_ground_truths) 53 | 54 | 55 | def evaluate(dataset, predictions): 56 | f1 = exact_match = total = 0 57 | for article in dataset: 58 | for paragraph in article['paragraphs']: 59 | for qa in paragraph['qas']: 60 | total += 1 61 | if qa['id'] not in predictions: 62 | message = 'Unanswered question ' + qa['id'] + \ 63 | ' will receive score 0.' 64 | print(message, file=sys.stderr) 65 | continue 66 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 67 | prediction = predictions[qa['id']] 68 | exact_match += metric_max_over_ground_truths( 69 | exact_match_score, prediction, ground_truths) 70 | f1 += metric_max_over_ground_truths( 71 | f1_score, prediction, ground_truths) 72 | 73 | exact_match = 100.0 * exact_match / total 74 | f1 = 100.0 * f1 / total 75 | 76 | return {'exact_match': exact_match, 'f1': f1} 77 | 78 | 79 | if __name__ == '__main__': 80 | expected_version = '1.1' 81 | parser = argparse.ArgumentParser( 82 | description='Evaluation for SQuAD ' + expected_version) 83 | parser.add_argument('dataset_file', help='Dataset file') 84 | parser.add_argument('prediction_file', help='Prediction File') 85 | args = parser.parse_args() 86 | with open(args.dataset_file) as dataset_file: 87 | dataset_json = json.load(dataset_file) 88 | if (dataset_json['version'] != expected_version): 89 | print('Evaluation expects v-' + expected_version + 90 | ', but got dataset with v-' + dataset_json['version'], 91 | file=sys.stderr) 92 | dataset = dataset_json['data'] 93 | with open(args.prediction_file) as prediction_file: 94 | predictions = json.load(prediction_file) 95 | print(json.dumps(evaluate(dataset, predictions))) 96 | -------------------------------------------------------------------------------- /squad/baseline/file_interface.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | 6 | import base 7 | 8 | 9 | class FileInterface(base.FileInterface): 10 | def __init__(self, glove_dir, glove_size, elmo_options_file, elmo_weights_file, **kwargs): 11 | self._glove_dir = glove_dir 12 | self._glove_size = glove_size 13 | self._elmo_options_file = elmo_options_file 14 | self._elmo_weights_file = elmo_weights_file 15 | super(FileInterface, self).__init__(**kwargs) 16 | 17 | def load_train(self): 18 | return _load_squad(self._train_path, draft=self._draft) 19 | 20 | def load_test(self): 21 | return _load_squad(self._test_path, draft=self._draft) 22 | 23 | def load_metadata(self): 24 | glove_vocab, glove_emb_mat = _load_glove(self._glove_size, glove_dir=self._glove_dir, draft=self._draft) 25 | return {'glove_vocab': glove_vocab, 26 | 'glove_emb_mat': glove_emb_mat, 27 | 'elmo_options_file': self._elmo_options_file, 28 | 'elmo_weights_file': self._elmo_weights_file} 29 | 30 | 31 | def _load_squad(squad_path, 
draft=False): 32 | with open(squad_path, 'r') as fp: 33 | squad = json.load(fp) 34 | examples = [] 35 | for article in squad['data']: 36 | for para_idx, paragraph in enumerate(article['paragraphs']): 37 | cid = '%s_%d' % (article['title'], para_idx) 38 | if 'context' in paragraph: 39 | context = paragraph['context'] 40 | context_example = {'cid': cid, 'context': context} 41 | else: 42 | context_example = {} 43 | 44 | if 'qas' in paragraph: 45 | for question_idx, qa in enumerate(paragraph['qas']): 46 | id_ = qa['id'] 47 | qid = '%s_%d' % (cid, question_idx) 48 | question = qa['question'] 49 | question_example = {'id': id_, 'qid': qid, 'question': question} 50 | if 'answers' in qa: 51 | answers, answer_starts, answer_ends = [], [], [] 52 | for answer in qa['answers']: 53 | answer_start = answer['answer_start'] 54 | answer_end = answer_start + len(answer['text']) 55 | answers.append(answer['text']) 56 | answer_starts.append(answer_start) 57 | answer_ends.append(answer_end) 58 | answer_example = {'answers': answers, 'answer_starts': answer_starts, 59 | 'answer_ends': answer_ends} 60 | question_example.update(answer_example) 61 | 62 | example = {'idx': len(examples)} 63 | example.update(context_example) 64 | example.update(question_example) 65 | examples.append(example) 66 | if draft and len(examples) == 100: 67 | return examples 68 | else: 69 | example = {'idx': len(examples)} 70 | example.update(context_example) 71 | examples.append(example) 72 | if draft and len(examples) == 100: 73 | return examples 74 | return examples 75 | 76 | 77 | def _load_glove(size, glove_dir=None, draft=False): 78 | if glove_dir is None: 79 | glove_url = 'http://nlp.stanford.edu/data/glove.6B.zip -O $GLOVE_DIR/glove.6B.zip' 80 | raise NotImplementedError() 81 | 82 | glove_path = os.path.join(glove_dir, 'glove.6B.%dd.txt' % size) 83 | with open(glove_path, 'rb') as fp: 84 | vocab = [] 85 | vecs = [] 86 | for idx, line in enumerate(fp): 87 | line = line.decode('utf-8') 88 | tokens = line.strip().split(u' ') 89 | word = tokens[0] 90 | vec = list(map(float, tokens[1:])) 91 | vecs.append(vec) 92 | vocab.append(word) 93 | if draft and idx >= 99: 94 | break 95 | emb_mat = np.array(vecs, dtype=np.float32) 96 | return vocab, emb_mat 97 | 98 | -------------------------------------------------------------------------------- /squad/scripts/tfidf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import nltk 5 | 6 | from gensim import corpora, models, similarities 7 | from tqdm import tqdm 8 | 9 | 10 | def load_squad(squad_path, draft=False): 11 | with open(squad_path, 'r') as fp: 12 | squad = json.load(fp) 13 | examples = [] 14 | for article in squad['data']: 15 | for paragraph in article['paragraphs']: 16 | context = paragraph['context'] 17 | for qa in paragraph['qas']: 18 | question = qa['question'] 19 | id_ = qa['id'] 20 | answers, answer_starts, answer_ends = [], [], [] 21 | for answer in qa['answers']: 22 | answer_start = answer['answer_start'] 23 | answer_end = answer_start + len(answer['text']) 24 | answers.append(answer['text']) 25 | answer_starts.append(answer_start) 26 | answer_ends.append(answer_end) 27 | 28 | # to avoid csv compatibility issue 29 | context = context.replace('\n', '\t') 30 | 31 | example = {'id': id_, 32 | 'idx': len(examples), 33 | 'context': context, 34 | 'question': question, 35 | 'answers': answers, 36 | 'answer_starts': answer_starts, 37 | 'answer_ends': answer_ends} 38 | examples.append(example) 39 | if draft and 
len(examples) == 100: 40 | return examples 41 | return examples 42 | 43 | 44 | def tokenize(in_): 45 | in_ = in_.replace('``', '" ').replace("''", '" ').replace('\t', ' ') 46 | words = nltk.word_tokenize(in_) 47 | words = [word.replace('``', '"').replace("''", '"') for word in words] 48 | return words 49 | 50 | 51 | def get_phrases_and_documents(context, nbr_len=7, max_ans_len=7, lower=False): 52 | words = tokenize(context) 53 | doc_words = [word.lower() for word in words] if lower else words 54 | phrases = [] 55 | documents = [] 56 | for i in range(len(words)): 57 | for j in range(i+1, min(len(words), i+max_ans_len)+1): 58 | phrase = ' '.join(words[i:j]) 59 | document = doc_words[max(0, i-nbr_len):i] + doc_words[j:min(len(words), j+nbr_len)] 60 | if len(document) == 0: 61 | continue 62 | phrases.append(phrase) 63 | documents.append(document) 64 | return phrases, documents 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser(description='TF-IDF') 69 | parser.add_argument('data_path') 70 | parser.add_argument('out_path') 71 | parser.add_argument('--draft', default=False, action='store_true') 72 | parser.add_argument('--nbr_len', default=7, type=int) 73 | parser.add_argument('--max_ans_len', default=7, type=int) 74 | parser.add_argument('--lower', default=False, action='store_true') 75 | args = parser.parse_args() 76 | 77 | examples = load_squad(args.data_path, draft=args.draft) 78 | 79 | out_dict = {} 80 | for example in tqdm(examples): 81 | query = tokenize(example['question']) 82 | phrases, documents = get_phrases_and_documents(example['context'], 83 | nbr_len=args.nbr_len, 84 | max_ans_len=args.max_ans_len, 85 | lower=args.lower) 86 | dictionary = corpora.Dictionary(documents) 87 | corpus = [dictionary.doc2bow(doc) for doc in documents] 88 | tfidf = models.TfidfModel(corpus) 89 | index = similarities.MatrixSimilarity(tfidf[corpus]) 90 | sims = index[tfidf[dictionary.doc2bow(query)]] 91 | phrase = phrases[max(enumerate(sims), key=lambda item: item[1])[0]] 92 | out_dict[example['id']] = phrase 93 | if args.draft: 94 | print(example['question']) 95 | break 96 | 97 | with open(args.out_path, 'w') as fp: 98 | json.dump(out_dict, fp) 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /squad/base/argument_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | class ArgumentParser(argparse.ArgumentParser): 6 | def __init__(self, description='base', **kwargs): 7 | super(ArgumentParser, self).__init__(description=description) 8 | 9 | def add_arguments(self): 10 | home = os.path.expanduser('~') 11 | self.add_argument('model', type=str) 12 | 13 | self.add_argument('--mode', type=str, default='train') 14 | self.add_argument('--iteration', type=str, default='0') 15 | self.add_argument('--pause', type=int, default=0) # ignore this argument. 
16 | 17 | # Data (input) paths 18 | self.add_argument('--train_path', type=str, default=os.path.join(home, 'data', 'squad', 'train-v1.1.json'), 19 | help='location of the training data') 20 | self.add_argument('--test_path', type=str, default=os.path.join(home, 'data', 'squad', 'dev-v1.1.json'), 21 | help='location of the test data') 22 | 23 | # Output paths 24 | self.add_argument('--output_dir', type=str, default='/tmp/piqa/squad/', help='Output directory') 25 | self.add_argument('--save_dir', type=str, default=None, help='location for saving the model') 26 | self.add_argument('--load_dir', type=str, default=None, help='location for loading the model') 27 | self.add_argument('--dump_dir', type=str, default=None, help='location for dumping outputs') 28 | self.add_argument('--report_path', type=str, default=None, help='location for report') 29 | self.add_argument('--pred_path', type=str, default=None, help='location for prediction json file during `test`') 30 | self.add_argument('--cache_path', type=str, default=None) 31 | self.add_argument('--question_emb_dir', type=str, default=None) 32 | self.add_argument('--context_emb_dir', type=str, default=None) 33 | 34 | # Training arguments 35 | self.add_argument('--epochs', type=int, default=20) 36 | self.add_argument('--train_steps', type=int, default=0) 37 | self.add_argument('--eval_steps', type=int, default=1000) 38 | self.add_argument('--eval_save_period', type=int, default=500) 39 | self.add_argument('--report_period', type=int, default=100) 40 | 41 | # Similarity search (faiss, pysparnn) arguments 42 | self.add_argument('--metric', type=str, default='ip', help='ip|l2') 43 | self.add_argument('--nlist', type=int, default=1) 44 | self.add_argument('--nprobe', type=int, default=1) 45 | self.add_argument('--bpv', type=int, default=None, help='bytes per vector (e.g. 
8)') 46 | self.add_argument('--num_train_mats', type=int, default=100) 47 | 48 | # Demo arguments 49 | self.add_argument('--port', type=int, default=8080) 50 | 51 | # Other arguments 52 | self.add_argument('--draft', default=False, action='store_true') 53 | self.add_argument('--cuda', default=False, action='store_true') 54 | self.add_argument('--preload', default=False, action='store_true') 55 | self.add_argument('--cache', default=False, action='store_true') 56 | self.add_argument('--archive', default=False, action='store_true') 57 | self.add_argument('--dump_period', type=int, default=20) 58 | self.add_argument('--emb_type', type=str, default='dense', help='dense|sparse') 59 | self.add_argument('--metadata', default=False, action='store_true') 60 | self.add_argument('--mem_info', default=False, action='store_true') 61 | 62 | def parse_args(self, **kwargs): 63 | args = super().parse_args() 64 | if args.draft: 65 | args.batch_size = 2 66 | args.eval_steps = 1 67 | args.eval_save_period = 2 68 | args.train_steps = 2 69 | 70 | if args.save_dir is None: 71 | args.save_dir = os.path.join(args.output_dir, 'save') 72 | if args.load_dir is None: 73 | args.load_dir = os.path.join(args.output_dir, 'save') 74 | if args.dump_dir is None: 75 | args.dump_dir = os.path.join(args.output_dir, 'dump') 76 | if args.question_emb_dir is None: 77 | args.question_emb_dir = os.path.join(args.output_dir, 'question_emb') 78 | if args.context_emb_dir is None: 79 | args.context_emb_dir = os.path.join(args.output_dir, 'context_emb') 80 | if args.report_path is None: 81 | args.report_path = os.path.join(args.output_dir, 'report.csv') 82 | if args.pred_path is None: 83 | args.pred_path = os.path.join(args.output_dir, 'pred.json') 84 | if args.cache_path is None: 85 | args.cache_path = os.path.join(args.output_dir, 'cache.b') 86 | 87 | args.load_dir = os.path.abspath(args.load_dir) 88 | args.context_emb_dir = os.path.abspath(args.context_emb_dir) 89 | args.question_emb_dir = os.path.abspath(args.question_emb_dir) 90 | 91 | return args 92 | -------------------------------------------------------------------------------- /squad/static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | PIQA Demo 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 26 |
[The HTML markup of index.html was stripped in this dump and is not recoverable; the surviving fragments show only the pika logo image and a "Latency:" readout on the demo page.]
56 | 57 | 65 | 66 | 67 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /squad/merge.py: -------------------------------------------------------------------------------- 1 | """Official merge script for PI-SQuAD v0.1""" 2 | from __future__ import print_function 3 | 4 | import os 5 | import argparse 6 | import json 7 | import sys 8 | import shutil 9 | 10 | import scipy.sparse 11 | import scipy.sparse.linalg 12 | import numpy as np 13 | import numpy.linalg 14 | 15 | 16 | def get_q2c(dataset): 17 | q2c = {} 18 | for article in dataset: 19 | for para_idx, paragraph in enumerate(article['paragraphs']): 20 | cid = '%s_%d' % (article['title'], para_idx) 21 | for qa in paragraph['qas']: 22 | q2c[qa['id']] = cid 23 | return q2c 24 | 25 | 26 | def get_predictions(context_emb_path, question_emb_path, q2c, sparse=False, metric='ip', progress=False): 27 | context_emb_dir, context_emb_ext = os.path.splitext(context_emb_path) 28 | question_emb_dir, question_emb_ext = os.path.splitext(question_emb_path) 29 | if context_emb_ext == '.zip': 30 | print('Extracting %s to %s' % (context_emb_path, context_emb_dir)) 31 | shutil.unpack_archive(context_emb_path, context_emb_dir) 32 | if question_emb_ext == '.zip': 33 | print('Extracting %s to %s' % (question_emb_path, question_emb_dir)) 34 | shutil.unpack_archive(question_emb_path, question_emb_dir) 35 | 36 | if progress: 37 | from tqdm import tqdm 38 | else: 39 | tqdm = lambda x: x 40 | predictions = {} 41 | for id_, cid in tqdm(q2c.items()): 42 | q_emb_path = os.path.join(question_emb_dir, '%s.npz' % id_) 43 | c_emb_path = os.path.join(context_emb_dir, '%s.npz' % cid) 44 | c_json_path = os.path.join(context_emb_dir, '%s.json' % cid) 45 | 46 | if not os.path.exists(q_emb_path): 47 | print('Missing %s' % q_emb_path) 48 | continue 49 | if not os.path.exists(c_emb_path): 50 | print('Missing %s' % c_emb_path) 51 | continue 52 | if not os.path.exists(c_json_path): 53 | print('Missing %s' % c_json_path) 54 | continue 55 | 56 | load = scipy.sparse.load_npz if sparse else np.load 57 | q_emb = load(q_emb_path) # shape = [M, d], d is the embedding size. 58 | c_emb = load(c_emb_path) # shape = [N, d], d is the embedding size. 
59 | 60 | with open(c_json_path, 'r') as fp: 61 | phrases = json.load(fp) 62 | 63 | if sparse: 64 | if metric == 'ip': 65 | sim = c_emb * q_emb.T 66 | m = sim.max(1) 67 | m = np.squeeze(np.array(m.todense()), 1) 68 | elif metric == 'l1': 69 | m = scipy.sparse.linalg.norm(c_emb - q_emb, ord=1, axis=1) 70 | elif metric == 'l2': 71 | m = scipy.sparse.linalg.norm(c_emb - q_emb, ord=2, axis=1) 72 | else: 73 | q_emb = q_emb['arr_0'] 74 | c_emb = c_emb['arr_0'] 75 | if metric == 'ip': 76 | sim = np.matmul(c_emb, q_emb.T) 77 | m = sim.max(1) 78 | elif metric == 'l1': 79 | m = numpy.linalg.norm(c_emb - q_emb, ord=1, axis=1) 80 | elif metric == 'l2': 81 | m = numpy.linalg.norm(c_emb - q_emb, ord=2, axis=1) 82 | 83 | argmax = m.argmax(0) 84 | predictions[id_] = phrases[argmax] 85 | 86 | if context_emb_ext == '.zip': 87 | shutil.rmtree(context_emb_dir) 88 | if question_emb_ext == '.zip': 89 | shutil.rmtree(question_emb_dir) 90 | 91 | return predictions 92 | 93 | 94 | if __name__ == '__main__': 95 | squad_expected_version = '1.1' 96 | parser = argparse.ArgumentParser(description='Official merge script for PI-SQuAD v0.1') 97 | parser.add_argument('data_path', help='Dataset file path') 98 | parser.add_argument('context_emb_dir', help='Context embedding directory') 99 | parser.add_argument('question_emb_dir', help='Question embedding directory') 100 | parser.add_argument('pred_path', help='Prediction json file path') 101 | parser.add_argument('--sparse', default=False, action='store_true', 102 | help='Whether the embeddings are scipy.sparse or pure numpy.') 103 | parser.add_argument('--metric', type=str, default='ip', 104 | help='ip|l1|l2 (inner product or L1 or L2 distance)') 105 | parser.add_argument('--progress', default=False, action='store_true', help='Show progress bar. 
Requires `tqdm`.') 106 | args = parser.parse_args() 107 | 108 | with open(args.data_path) as dataset_file: 109 | dataset_json = json.load(dataset_file) 110 | if dataset_json['version'] != squad_expected_version: 111 | print('Evaluation expects v-' + squad_expected_version + 112 | ', but got dataset with v-' + dataset_json['version'], 113 | file=sys.stderr) 114 | dataset = dataset_json['data'] 115 | q2c = get_q2c(dataset) 116 | predictions = get_predictions(args.context_emb_dir, args.question_emb_dir, q2c, sparse=args.sparse, 117 | metric=args.metric, progress=args.progress) 118 | 119 | with open(args.pred_path, 'w') as fp: 120 | json.dump(predictions, fp) 121 | 122 | -------------------------------------------------------------------------------- /squad/piqa_evaluate.py: -------------------------------------------------------------------------------- 1 | """ Official alpha evaluation script for PIQA (inherited from SQuAD v1.1 evaluation script).""" 2 | from __future__ import print_function 3 | 4 | import os 5 | from collections import Counter 6 | import string 7 | import re 8 | import argparse 9 | import json 10 | import sys 11 | import shutil 12 | 13 | import scipy.sparse 14 | import numpy as np 15 | 16 | 17 | def normalize_answer(s): 18 | """Lower text and remove punctuation, articles and extra whitespace.""" 19 | 20 | def remove_articles(text): 21 | return re.sub(r'\b(a|an|the)\b', ' ', text) 22 | 23 | def white_space_fix(text): 24 | return ' '.join(text.split()) 25 | 26 | def remove_punc(text): 27 | exclude = set(string.punctuation) 28 | return ''.join(ch for ch in text if ch not in exclude) 29 | 30 | def lower(text): 31 | return text.lower() 32 | 33 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 34 | 35 | 36 | def f1_score(prediction, ground_truth): 37 | prediction_tokens = normalize_answer(prediction).split() 38 | ground_truth_tokens = normalize_answer(ground_truth).split() 39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 40 | num_same = sum(common.values()) 41 | if num_same == 0: 42 | return 0 43 | precision = 1.0 * num_same / len(prediction_tokens) 44 | recall = 1.0 * num_same / len(ground_truth_tokens) 45 | f1 = (2 * precision * recall) / (precision + recall) 46 | return f1 47 | 48 | 49 | def exact_match_score(prediction, ground_truth): 50 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 51 | 52 | 53 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 54 | scores_for_ground_truths = [] 55 | for ground_truth in ground_truths: 56 | score = metric_fn(prediction, ground_truth) 57 | scores_for_ground_truths.append(score) 58 | return max(scores_for_ground_truths) 59 | 60 | 61 | def evaluate(dataset, predictions): 62 | f1 = exact_match = total = 0 63 | for article in dataset: 64 | for paragraph in article['paragraphs']: 65 | for qa in paragraph['qas']: 66 | total += 1 67 | if qa['id'] not in predictions: 68 | message = 'Unanswered question ' + qa['id'] + \ 69 | ' will receive score 0.' 
70 | print(message, file=sys.stderr) 71 | continue 72 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 73 | prediction = predictions[qa['id']] 74 | exact_match += metric_max_over_ground_truths( 75 | exact_match_score, prediction, ground_truths) 76 | f1 += metric_max_over_ground_truths( 77 | f1_score, prediction, ground_truths) 78 | 79 | exact_match = 100.0 * exact_match / total 80 | f1 = 100.0 * f1 / total 81 | 82 | return {'exact_match': exact_match, 'f1': f1} 83 | 84 | 85 | def get_q2c(dataset): 86 | q2c = {} 87 | for article in dataset: 88 | for para_idx, paragraph in enumerate(article['paragraphs']): 89 | cid = '%s_%d' % (article['title'], para_idx) 90 | for qa in paragraph['qas']: 91 | q2c[qa['id']] = cid 92 | return q2c 93 | 94 | 95 | def get_predictions(context_emb_path, question_emb_path, q2c, sparse=False, progress=False): 96 | context_emb_dir, context_emb_ext = os.path.splitext(context_emb_path) 97 | question_emb_dir, question_emb_ext = os.path.splitext(question_emb_path) 98 | if context_emb_ext == '.zip': 99 | print('Extracting %s to %s' % (context_emb_path, context_emb_dir)) 100 | shutil.unpack_archive(context_emb_path, context_emb_dir) 101 | if question_emb_ext == '.zip': 102 | print('Extracting %s to %s' % (question_emb_path, question_emb_dir)) 103 | shutil.unpack_archive(question_emb_path, question_emb_dir) 104 | 105 | if progress: 106 | from tqdm import tqdm 107 | else: 108 | tqdm = lambda x: x 109 | predictions = {} 110 | for id_, cid in tqdm(q2c.items()): 111 | q_emb_path = os.path.join(question_emb_dir, '%s.npz' % id_) 112 | c_emb_path = os.path.join(context_emb_dir, '%s.npz' % cid) 113 | c_json_path = os.path.join(context_emb_dir, '%s.json' % cid) 114 | 115 | if not os.path.exists(q_emb_path): 116 | print('Missing %s' % q_emb_path) 117 | continue 118 | if not os.path.exists(c_emb_path): 119 | print('Missing %s' % c_emb_path) 120 | continue 121 | if not os.path.exists(c_json_path): 122 | print('Missing %s' % c_json_path) 123 | continue 124 | 125 | load = scipy.sparse.load_npz if sparse else np.load 126 | q_emb = load(q_emb_path) # shape = [M, d], d is the embedding size. 127 | c_emb = load(c_emb_path) # shape = [N, d], d is the embedding size. 
128 | 129 | with open(c_json_path, 'r') as fp: 130 | phrases = json.load(fp) 131 | 132 | if sparse: 133 | sim = c_emb * q_emb.T 134 | m = sim.max(1) 135 | m = np.squeeze(np.array(m.todense()), 1) 136 | else: 137 | q_emb = q_emb['arr_0'] 138 | c_emb = c_emb['arr_0'] 139 | sim = np.matmul(c_emb, q_emb.T) 140 | m = sim.max(1) 141 | 142 | argmax = m.argmax(0) 143 | predictions[id_] = phrases[argmax] 144 | 145 | # Dump piqa_pred 146 | # with open('test/piqa_pred.json', 'w') as f: 147 | # f.write(json.dumps(predictions)) 148 | 149 | if context_emb_ext == '.zip': 150 | shutil.rmtree(context_emb_dir) 151 | if question_emb_ext == '.zip': 152 | shutil.rmtree(question_emb_dir) 153 | 154 | return predictions 155 | 156 | 157 | if __name__ == '__main__': 158 | expected_version = '1.1' 159 | parser = argparse.ArgumentParser( 160 | description='Evaluation for SQuAD ' + expected_version) 161 | parser.add_argument('dataset_file', help='Dataset file') 162 | parser.add_argument('context_emb_dir', help='Context embedding directory') 163 | parser.add_argument('question_emb_dir', help='Question embedding directory') 164 | parser.add_argument('--sparse', default=False, action='store_true', 165 | help='Whether the embeddings are scipy.sparse or pure numpy.') 166 | parser.add_argument('--progress', default=False, action='store_true', help='Show progress bar. Requires `tqdm`.') 167 | args = parser.parse_args() 168 | with open(args.dataset_file) as dataset_file: 169 | dataset_json = json.load(dataset_file) 170 | if (dataset_json['version'] != expected_version): 171 | print('Evaluation expects v-' + expected_version + 172 | ', but got dataset with v-' + dataset_json['version'], 173 | file=sys.stderr) 174 | dataset = dataset_json['data'] 175 | q2c = get_q2c(dataset) 176 | predictions = get_predictions(args.context_emb_dir, args.question_emb_dir, q2c, sparse=args.sparse, 177 | progress=args.progress) 178 | print(json.dumps(evaluate(dataset, predictions))) 179 | -------------------------------------------------------------------------------- /squad/base/file_interface.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import shutil 5 | 6 | import torch 7 | import scipy.sparse 8 | import numpy as np 9 | import csv 10 | 11 | 12 | class FileInterface(object): 13 | def __init__(self, cuda, mode, save_dir, load_dir, report_path, pred_path, question_emb_dir, context_emb_dir, 14 | cache_path, dump_dir, train_path, test_path, draft, **kwargs): 15 | self._cuda = cuda 16 | self._mode = mode 17 | self._train_path = train_path 18 | self._test_path = test_path 19 | self._save_dir = save_dir 20 | self._load_dir = load_dir 21 | self._report_path = report_path 22 | self._dump_dir = dump_dir 23 | self._pred_path = pred_path 24 | self._question_emb_dir = os.path.splitext(question_emb_dir)[0] 25 | self._context_emb_dir = os.path.splitext(context_emb_dir)[0] 26 | self._cache_path = cache_path 27 | self._args_path = os.path.join(save_dir, 'args.json') 28 | self._draft = draft 29 | self._save = None 30 | self._load = None 31 | self._report_header = [] 32 | self._report = [] 33 | self._kwargs = kwargs 34 | 35 | def _bind(self, save=None, load=None): 36 | self._save = save 37 | self._load = load 38 | 39 | def save(self, iteration, save_fn=None): 40 | filename = os.path.join(self._save_dir, str(iteration)) 41 | if not os.path.exists(filename): 42 | os.makedirs(filename) 43 | if save_fn is None: 44 | save_fn = self._save 45 | save_fn(filename) 46 | 47 | def 
save_args(self, args): 48 | if not os.path.exists(self._save_dir): 49 | os.makedirs(self._save_dir) 50 | with open(self._args_path, 'w') as fp: 51 | json.dump(args, fp) 52 | 53 | def load(self, iteration='0', load_fn=None, session=None): 54 | if session is None: 55 | session = self._load_dir 56 | if iteration == '0': 57 | filename = session 58 | else: 59 | filename = os.path.join(session, str(iteration), 'model.pt') 60 | if load_fn is None: 61 | load_fn = self._load 62 | load_fn(filename) 63 | 64 | def pred(self, pred): 65 | if not os.path.exists(os.path.dirname(self._pred_path)): 66 | os.makedirs(os.path.dirname(self._pred_path)) 67 | with open(self._pred_path, 'w') as fp: 68 | json.dump(pred, fp) 69 | print('Prediction saved at %s' % self._pred_path) 70 | 71 | def report(self, summary=False, **kwargs): 72 | if not os.path.exists(os.path.dirname(self._report_path)): 73 | os.makedirs(os.path.dirname(self._report_path)) 74 | if len(self._report) == 0 and os.path.exists(self._report_path): 75 | with open(self._report_path, 'r') as fp: 76 | reader = csv.DictReader(fp, delimiter=',') 77 | rows = list(reader) 78 | for key in rows[0]: 79 | if key not in self._report_header: 80 | self._report_header.append(key) 81 | self._report.extend(rows) 82 | 83 | for key, val in kwargs.items(): 84 | if key not in self._report_header: 85 | self._report_header.append(key) 86 | self._report.append(kwargs) 87 | with open(self._report_path, 'w') as fp: 88 | writer = csv.DictWriter(fp, delimiter=',', fieldnames=self._report_header) 89 | writer.writeheader() 90 | writer.writerows(self._report) 91 | return ', '.join('%s=%.5r' % (s, r) for s, r in kwargs.items()) 92 | 93 | def question_emb(self, id_, emb, emb_type='dense'): 94 | if not os.path.exists(self._question_emb_dir): 95 | os.makedirs(self._question_emb_dir) 96 | savez = scipy.sparse.save_npz if emb_type == 'sparse' else np.savez_compressed 97 | path = os.path.join(self._question_emb_dir, '%s.npz' % id_) 98 | savez(path, emb) 99 | 100 | def context_emb(self, id_, phrases, emb, metadata=None, emb_type='dense'): 101 | if not os.path.exists(self._context_emb_dir): 102 | os.makedirs(self._context_emb_dir) 103 | savez = scipy.sparse.save_npz if emb_type == 'sparse' else np.savez_compressed 104 | emb_path = os.path.join(self._context_emb_dir, '%s.npz' % id_) 105 | json_path = os.path.join(self._context_emb_dir, '%s.json' % id_) 106 | 107 | if os.path.exists(emb_path): 108 | print('Skipping %s; already exists' % emb_path) 109 | else: 110 | savez(emb_path, emb) 111 | if os.path.exists(json_path): 112 | print('Skipping %s; already exists' % json_path) 113 | else: 114 | with open(json_path, 'w') as fp: 115 | json.dump(phrases, fp) 116 | 117 | if metadata is not None: 118 | metadata_path = os.path.join(self._context_emb_dir, '%s.metadata' % id_) 119 | with open(metadata_path, 'w') as fp: 120 | json.dump(metadata, fp) 121 | 122 | def context_load(self, metadata=False, emb_type='dense', shuffle=True): 123 | paths = os.listdir(self._context_emb_dir) 124 | if shuffle: 125 | random.shuffle(paths) 126 | json_paths = tuple(os.path.join(self._context_emb_dir, path) 127 | for path in paths if os.path.splitext(path)[1] == '.json') 128 | npz_paths = tuple('%s.npz' % os.path.splitext(path)[0] for path in json_paths) 129 | metadata_paths = tuple('%s.metadata' % os.path.splitext(path)[0] for path in json_paths) 130 | for json_path, npz_path, metadata_path in zip(json_paths, npz_paths, metadata_paths): 131 | with open(json_path, 'r') as fp: 132 | phrases = json.load(fp) 133 | if 
emb_type == 'dense': 134 | emb = np.load(npz_path)['arr_0'] 135 | else: 136 | emb = scipy.sparse.load_npz(npz_path) 137 | if metadata: 138 | with open(metadata_path, 'r') as fp: 139 | metadata = json.load(fp) 140 | yield phrases, emb, metadata 141 | else: 142 | yield phrases, emb 143 | 144 | def archive(self): 145 | if self._mode == 'embed' or self._mode == 'embed_context': 146 | shutil.make_archive(self._context_emb_dir, 'zip', self._context_emb_dir) 147 | shutil.rmtree(self._context_emb_dir) 148 | 149 | if self._mode == 'embed' or self._mode == 'embed_question': 150 | shutil.make_archive(self._question_emb_dir, 'zip', self._question_emb_dir) 151 | shutil.rmtree(self._question_emb_dir) 152 | 153 | def cache(self, preprocess, args): 154 | if os.path.exists(self._cache_path): 155 | return torch.load(self._cache_path) 156 | out = preprocess(self, args) 157 | torch.save(out, self._cache_path) 158 | return out 159 | 160 | def dump(self, batch_idx, item): 161 | filename = os.path.join(self._dump_dir, '%s.pt' % str(batch_idx).zfill(6)) 162 | dirname = os.path.dirname(filename) 163 | if not os.path.exists(dirname): 164 | os.makedirs(dirname) 165 | torch.save(item, filename) 166 | 167 | def bind(self, processor, model, optimizer=None): 168 | def load(filename, **kwargs): 169 | # filename = os.path.join(filename, 'model.pt') 170 | state = torch.load(filename, map_location=None if self._cuda else 'cpu') 171 | processor.load_state_dict(state['preprocessor']) 172 | model.load_state_dict(state['model']) 173 | if 'optimizer' in state and optimizer: 174 | optimizer.load_state_dict(state['optimizer']) 175 | print('Model loaded from %s' % filename) 176 | 177 | def save(filename, **kwargs): 178 | state = { 179 | 'preprocessor': processor.state_dict(), 180 | 'model': model.state_dict(), 181 | 'optimizer': optimizer.state_dict() 182 | } 183 | filename = os.path.join(filename, 'model.pt') 184 | torch.save(state, filename) 185 | print('Model saved at %s' % filename) 186 | 187 | def infer(input, top_k=100): 188 | # input = {'id': '', 'question': '', 'context': ''} 189 | model.eval() 190 | 191 | self._bind(save=save, load=load) 192 | 193 | def load_train(self): 194 | raise NotImplementedError() 195 | 196 | def load_test(self): 197 | raise NotImplementedError() 198 | 199 | def load_metadata(self): 200 | raise NotImplementedError() 201 | -------------------------------------------------------------------------------- /squad/README.md: -------------------------------------------------------------------------------- 1 | # Phrase-Indexed SQuAD (PI-SQuAD) 2 | 3 | ## Baseline Models 4 | 5 | ### 0. Download requirements 6 | Make sure you have Python 3.6 or later. Download and install all requirements by: 7 | 8 | ```bash 9 | chmod +x download.sh; ./download.sh 10 | ``` 11 | 12 | This will install following python packages: 13 | 14 | - `numpy==1.15.2`, `scipy==1.1.0`, `torch==0.4.1`, `nltk==3.3`: essential packages 15 | - `allenlp==0.6.1`: only if you want to try using [ELMo][elmo]; the installation takes some time. 16 | - `tqdm`: optional, progressbar tool. 17 | - `gensim`: optional, running tf-idf experiments in `scripts`. 18 | 19 | This will also download several things at `$HOME/data` (consider changing it to your favorite location): 20 | 21 | - nltk word tokenizer 22 | - GloVe 6B 23 | - ELMo options and weights files 24 | 25 | You might also want to consider using docker image `minjoon/research:180908` that meets all of these requirements. 26 | The [sample CodaLab submission][worksheet-elmo] also uses this. 
27 |
28 | Download the original SQuAD v1.1 train and dev sets at [`$SQUAD_TRAIN_PATH`][squad-train] and [`$SQUAD_DEV_PATH`][squad-dev], respectively. The default directory that our baseline model searches in is `$HOME/data/squad`.
29 |
30 | #### Demo
31 | In order to run the demo, you will also need:
32 |
33 | - `faiss`, `pysparnn`: similarity search packages for dense and sparse vectors, respectively.
34 | - `tornado`, `flask`: to serve the demo on the web.
35 |
36 | The easiest way to install `faiss` is via `conda`: `conda install faiss -c pytorch`. So you might want to consider using `conda` before installing the requirements above.
37 |
38 |
39 | ### 1. Training
40 | In our [paper][paper], we have introduced three baseline models:
41 |
42 | For the LSTM model:
43 |
44 | ```bash
45 | python main.py baseline --cuda --train_path $SQUAD_TRAIN_PATH --test_path $SQUAD_DEV_PATH
46 | ```
47 |
48 | For the LSTM+SA model:
49 |
50 | ```bash
51 | python main.py baseline --cuda --num_heads 2 --train_path $SQUAD_TRAIN_PATH --test_path $SQUAD_DEV_PATH
52 | ```
53 |
54 | For the LSTM+SA+ELMo model:
55 |
56 | ```bash
57 | python main.py baseline --cuda --num_heads 2 --elmo --batch_size 32 --train_path $SQUAD_TRAIN_PATH --test_path $SQUAD_DEV_PATH
58 | ```
59 |
60 | Note that the first positional argument, `baseline`, indicates that we are using the Python modules in the `./baseline/` directory.
61 | In the future, you can easily add a new model by creating a new module (e.g. `./my_model/`) and passing its name as the positional argument (`my_model`).
62 |
63 | By default, these commands will output all interesting files (save, report, etc.) to `/tmp/piqa/squad`. You can change the directory with the `--output_dir` argument. Let `$OUTPUT_DIR` denote this directory.
64 |
65 |
66 | ### 2. Easy Evaluation
67 | Assuming you trust us, let's just try to output the prediction file from a full (context+question) dataset, and evaluate it with the SQuAD v1.1 evaluator. To do this with the LSTM model, simply run:
68 |
69 | ```bash
70 | python main.py baseline --cuda --mode test --load_dir $OUTPUT_DIR/save/####/model.pt --test_path $SQUAD_DEV_PATH
71 | ```
72 |
73 | where `####` indicates the step at which the model of interest was saved (e.g. `7001`). Take a look at the standard output during training and pick the step that gives the best performance (which is automatically tracked). Technically speaking, this is *cheating* on the dev set, but we are going to evaluate on the test set at the end, so we are okay.
74 |
75 | This will output the prediction file at `$OUTPUT_DIR/pred.json`. Finally, simply feed it to the SQuAD v1.1 evaluator (renamed from `evaluate-v1.1.py` to `evaluate.py`):
76 |
77 | ```bash
78 | python evaluate.py $SQUAD_DEV_PATH $OUTPUT_DIR/pred.json
79 | ```
80 |
81 | That was easy! But why is this not an *official evaluation*? Because it rests on a big assumption: that you trust us that our encoders are independent. But who knows?
82 |
83 |
84 | ### 3. Official Evaluation
85 |
86 | #### How?
87 |
88 | We need a strict evaluation method that enforces the independence between the encoders. We require a 'split-encode-merge' pipeline to ensure this:
89 |
90 | ![piqa_eval](figures/piqa_eval.jpg)
91 |
92 | where 'dev-v1.1.json*' does not contain the answers. A regular SQuAD v1.1 submission corresponds to uploading a model that replaces the black dotted box. A PI-SQuAD submission instead requires one to upload two encoders, a document encoder and a question encoder, that replace the orange boxes.
They **must** be preceded by `split.py` and followed by `merge.py`. We describe the expected formats of `dev-v1.1-c.json`, `dev-v1.1-q.json`, `context_emb/` and `question_emb/` below.
93 | 
94 | `dev-v1.1-c.json` and `dev-v1.1-q.json`: `split.py` simply splits `dev-v1.1.json` into context-only and question-only json files.
95 | 
96 | 
97 | `context_emb/`: the directory should contain a numpy file (`.npz`) and a list of phrases (`.json`) for each context (paragraph).
98 | 
99 | `question_emb/`: the directory should contain a numpy file for each question.
100 | 
101 | The directories will look like the following:
102 | 
103 | ```
104 | $OUTPUT_DIR
105 | +-- context_emb/
106 | | +-- Super_Bowl_50_0.npz
107 | | +-- Super_Bowl_50_0.json
108 | | +-- Super_Bowl_50_1.npz
109 | | +-- Super_Bowl_50_1.json
110 | | ...
111 | +-- question_emb/
112 | | +-- 56be4eafacb8001400a50302.npz
113 | | +-- 56d204ade7d4791d00902603.npz
114 | | ...
115 | ```
116 | 
117 | This looks quite complicated! Let's take a look at them one by one.
118 | 
119 | 1. **`.npz` is a numpy/scipy matrix dump**: Each `.npz` file corresponds to an *N*-by-*d* matrix. If it is a dense matrix, it needs to be saved via the `numpy.savez()` method, and if it is a sparse matrix (depending on your need), it needs to be saved via the `scipy.sparse.save_npz()` method. Note that `scipy.sparse.save_npz()` is relatively new and old scipy versions do not support it.
120 | 2. **each `.npz` in `context_emb` is named after the paragraph id**: Here, the paragraph id is `'%s_%d' % (article_title, para_idx)`, where `para_idx` indicates the index of the paragraph within the article (starting at `0`). For instance, if the article `Super_Bowl_50` has 35 paragraphs, then it will have `.npz` files up to `Super_Bowl_50_34.npz`.
121 | 3. **each `.npz` in `context_emb` holds *N* phrase vectors of dimension *d***: It is up to the submitted model to decide *N* and *d*. For instance, if the paragraph length is 100 words and we enumerate all possible phrases with length <= 7, then we will approximately have *N* = 700. While we will limit the size of `.npz` per word during the submission so one cannot have a very large dense matrix, we will allow sparse matrices, so *d* can be very large in some cases.
122 | 4. **`.json` is a list of *N* phrases**: each phrase corresponds to a phrase vector in its corresponding `.npz` file. Of course, one can have duplicate phrases (i.e. several vectors per phrase).
123 | 5. **each `.npz` in `question_emb` is named after the question id**: Here, the question id is the official id in the original SQuAD v1.1.
124 | 6. **each `.npz` in `question_emb` must be a *1*-by-*d* matrix**: Since each question has a single embedding, *N* = 1. Hence the matrix corresponds to the question representation.
125 | 
126 | Following these rules, one should confirm that `context_emb` contains 4134 files (2067 `.npz` files and 2067 `.json` files, i.e. 2067 paragraphs) and `question_emb` contains 10570 files (one file for each question) for the SQuAD v1.1 dev dataset. Hint: `ls context_emb/ | wc -l` gives you the count in the `context_emb` folder.
127 | 
128 | #### Running the baseline
129 | 
130 | To split `dev-v1.1.json`:
131 | ```
132 | python split.py $SQUAD_DEV_PATH $SQUAD_DEV_CONTEXT_PATH
133 | ```
134 | 
135 | Now, for the document and question encoders, we run `main.py` with two different arguments for each.
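If you are writing your own encoders rather than running the baseline commands below, the dump format described above can be produced with plain `numpy`, `scipy`, and `json` calls. The following is only a minimal sketch, not part of this repository: the vectors are random placeholders and the file names are simply the ones from the example listing above.

```python
import json
import os

import numpy as np
import scipy.sparse

os.makedirs('context_emb', exist_ok=True)
os.makedirs('question_emb', exist_ok=True)

d = 512  # embedding dimension; the submitted model is free to choose d
phrases = ['Denver Broncos', 'the Denver Broncos', 'Carolina Panthers']  # N example phrases
vecs = np.random.randn(len(phrases), d).astype(np.float32)  # placeholder N-by-d phrase matrix

# Dense context dump: an N-by-d matrix saved with numpy.savez, named after the paragraph id,
# plus the matching list of N phrases as .json.
np.savez('context_emb/Super_Bowl_50_0.npz', vecs)
with open('context_emb/Super_Bowl_50_0.json', 'w') as fp:
    json.dump(phrases, fp)

# Sparse alternative (for use with --sparse): same file name, saved with scipy.sparse.save_npz.
# scipy.sparse.save_npz('context_emb/Super_Bowl_50_0.npz', scipy.sparse.csc_matrix(vecs))

# Question dump: a 1-by-d matrix named after the official SQuAD question id.
qvec = np.random.randn(1, d).astype(np.float32)
np.savez('question_emb/56be4eafacb8001400a50302.npz', qvec)

# Reading a dense dump back: numpy.savez stores the positional array under 'arr_0'.
loaded = np.load('context_emb/Super_Bowl_50_0.npz')['arr_0']
assert loaded.shape == (len(phrases), d)
```

A real submission writes one such `.npz`/`.json` pair per paragraph and one `.npz` per question, following the naming rules above.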
136 | 
137 | For the document encoder:
138 | 
139 | ```bash
140 | python main.py baseline --cuda --mode embed_context --load_dir $OUTPUT_DIR/save/XXXX/model.pt --test_path $SQUAD_DEV_CONTEXT_PATH --context_emb_dir $CONTEXT_EMB_DIR
141 | ```
142 | 
143 | For the question encoder:
144 | 
145 | ```bash
146 | python main.py baseline --cuda --mode embed_question --load_dir $OUTPUT_DIR/save/XXXX/model.pt --test_path $SQUAD_DEV_QUESTION_PATH --question_emb_dir $QUESTION_EMB_DIR
147 | ```
148 | 
149 | The encoders will output the embeddings to `$CONTEXT_EMB_DIR` and `$QUESTION_EMB_DIR`, respectively. Using a compressed dump for the LSTM model, these directories take about 500 MB of disk space.
150 | 
151 | To merge:
152 | 
153 | ```bash
154 | python merge.py $SQUAD_DEV_PATH $CONTEXT_EMB_DIR $QUESTION_EMB_DIR $PRED_PATH
155 | ```
156 | 
157 | where `$PRED_PATH` is the prediction path.
158 | For our baselines, this takes ~4 minutes on a typical consumer-grade CPU, though keep in mind that the duration will depend on the size of *N* and *d*.
159 | By default, `merge.py` assumes the dumped matrices are dense, but you can also work with sparse matrices by giving the `--sparse` argument.
160 | If you have `tqdm`, you can display progress with the `--progress` argument.
161 | `merge.py` does not require `torch` or `nltk`, but it needs `numpy` and `scipy`.
162 | 
163 | Note that we currently only support *inner product* for the nearest neighbor search (our baseline model uses inner product as well). We will support L1/L2 distances when the submission opens. Please let us know (create an issue) if you think other measures should also be supported. Note, however, that we try to limit the options to those that are commonly used for approximate search (so it is unlikely that we will support a multilayer perceptron, because it simply does not scale).
164 | 
165 | Lastly, to evaluate, use the official evaluator script:
166 | 
167 | ```
168 | python evaluate.py $SQUAD_DEV_PATH $PRED_PATH
169 | ```
170 | 
171 | ### 4. Demo
172 | The demo uses the dumped context embeddings. We need to make one change: when encoding the document (context), give the `--metadata` flag to output additional data needed by the demo:
173 | ```bash
174 | python main.py baseline --cuda --mode embed_context --load_dir $OUTPUT_DIR/save/XXXX/model.pt --test_path $SQUAD_DEV_CONTEXT_PATH --context_emb_dir $CONTEXT_EMB_DIR --metadata
175 | ```
176 | 
177 | Then run the demo by:
178 | ```bash
179 | python main.py baseline --mode serve_demo --load_dir $OUTPUT_DIR/save/XXXX/model.pt --context_emb_dir $CONTEXT_EMB_DIR --port 8080
180 | ```
181 | 
182 | This will serve the demo on localhost at port 8080. Note that this is *exact search without compression*
183 | (i.e. it can take a lot of time and use a lot of RAM).
184 | For time- and memory-efficient similarity search, additionally give the following flags:
185 | ```bash
186 | --nlist 100 --nprobe 10 --bpv 128
187 | ```
188 | `--nlist` is the number of clusters, and `--nprobe` is the number of clusters you peek into (`nlist` > `nprobe` makes the search faster but approximate).
189 | `--bpv` is the number of bytes per vector. A 512-D float vector takes 2 KB, so `--bpv 128` compresses it 16-fold.
190 | 
191 | 
192 | ## Submission
193 | Please see the [sample submission worksheet][worksheet-elmo] for instructions on how to submit a PI-SQuAD model for the test set. In short, it should be a valid SQuAD submission, but in addition, it should meet the requirements described above.
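As a footnote to the demo flags above: `--nlist`, `--nprobe`, and `--bpv` map onto a faiss IVF (optionally product-quantized) index. Below is a rough, self-contained sketch with random placeholder vectors; it mirrors the kind of index `main.py` builds for `serve_demo`, but it is not the repository code itself.

```python
import numpy as np
import faiss

d = 512                      # vector dimension (4 * hidden_size * num_heads in the baseline)
nlist, nprobe, bpv = 100, 10, 128

xb = np.random.randn(10000, d).astype(np.float32)  # placeholder phrase vectors to index
xq = np.random.randn(1, d).astype(np.float32)      # one placeholder question vector

# Coarse quantizer that assigns vectors to one of nlist clusters (exact L2).
quantizer = faiss.IndexFlatL2(d)

# IVF + product quantization: bpv sub-quantizers of 8 bits each, i.e. bpv bytes per
# stored vector (128 B here vs. 2 KB for a raw 512-d float32 vector -> 16x compression).
index = faiss.IndexIVFPQ(quantizer, d, nlist, bpv, 8)
index.train(xb)              # learn cluster centroids and PQ codebooks
index.add(xb)

index.nprobe = nprobe        # search only 10 of the 100 clusters: faster, approximate
D, I = index.search(xq, 5)   # distances and indices of the top-5 nearest phrase vectors
print(I[0], D[0])
```

When `--nlist` equals `--nprobe`, `main.py` falls back to the exact flat index, which is why the default demo performs exact search.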
194 | 195 | 196 | 197 | 198 | [paper]: https://arxiv.org/abs/1804.07726 199 | [minjoon]: https://seominjoon.github.io 200 | [minjoon-github]: https://github.com/seominjoon 201 | [squad-train]: https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json 202 | [squad-dev]: https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json 203 | [squad-context]: https://uwnlp.github.io/piqa/data/squad/dev-v1.1-context.json 204 | [squad-question]: https://uwnlp.github.io/piqa/data/squad/dev-v1.1-question.json 205 | [elmo]: https://allennlp.org/elmo 206 | [squad]: https://stanford-qa.com 207 | [mipsqa]: https://github.com/google/mipsqa 208 | [worksheet-elmo]: https://worksheets.codalab.org/worksheets/0x58f20753fb784ffaa37877f777057b17/ 209 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018 University of Washington 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /squad/baseline/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | import base 5 | 6 | 7 | class CharEmbedding(nn.Module): 8 | def __init__(self, char_vocab_size, embed_dim): 9 | super(CharEmbedding, self).__init__() 10 | self.char_vocab_size = char_vocab_size 11 | self.embed_dim = embed_dim 12 | 13 | self.embedding = nn.Embedding(char_vocab_size, embed_dim) 14 | 15 | def forward(self, x): 16 | flat_x = x.view(-1, x.size()[-1]) 17 | flat_out = self.embedding(flat_x) 18 | out = flat_out.view(x.size() + (flat_out.size()[-1],)) 19 | out, _ = torch.max(out, -2) 20 | return out 21 | 22 | 23 | class WordEmbedding(nn.Module): 24 | def __init__(self, word_vocab_size=None, embed_dim=None, requires_grad=True, cpu=False): 25 | super(WordEmbedding, self).__init__() 26 | self.embedding = nn.Embedding(word_vocab_size, embed_dim) 27 | self.embedding.weight.requires_grad = requires_grad 28 | self._cpu = cpu 29 | 30 | def forward(self, x): 31 | device = x.device 32 | weight_device = self.embedding.weight.device 33 | x = x.to(weight_device) 34 | flat_x = x.view(-1, x.size()[-1]) 35 | flat_out = self.embedding(flat_x) 36 | out = flat_out.view(x.size() + (flat_out.size()[-1],)) 37 | out = out.to(device) 38 | return out 39 | 40 | def to(self, device): 41 | return self if self._cpu else super().to(device) 42 | 43 | 44 | class Highway(nn.Module): 45 | def __init__(self, input_dim, dropout): 46 | super(Highway, self).__init__() 47 | self.input_linear = nn.Linear(input_dim, input_dim) 48 | self.relu = nn.ReLU() 49 | self.gate_linear = nn.Linear(input_dim, input_dim) 50 | self.sigmoid = nn.Sigmoid() 51 | self.dropout = nn.Dropout(dropout) 52 | 53 | def forward(self, input_): 54 | input_ = self.dropout(input_) 55 | output = self.relu(self.input_linear(input_)) 56 | gate = self.sigmoid(self.gate_linear(input_)) 57 | output = input_ * gate + output * (1.0 - gate) 58 | return output 59 | 60 | 61 | class Embedding(nn.Module): 62 | def __init__(self, char_vocab_size, glove_vocab_size, word_vocab_size, embed_dim, dropout, elmo=False, 63 | glove_cpu=False): 64 | super(Embedding, self).__init__() 65 | self.word_embedding = WordEmbedding(word_vocab_size, embed_dim) 66 | self.char_embedding = CharEmbedding(char_vocab_size, embed_dim) 67 | self.glove_embedding = WordEmbedding(glove_vocab_size, embed_dim, requires_grad=False, cpu=glove_cpu) 68 | self.output_size = 2 * embed_dim 69 | self.highway1 = Highway(self.output_size, dropout) 70 | self.highway2 = Highway(self.output_size, dropout) 71 | self.use_elmo = elmo 72 | self.elmo = None 73 | if self.use_elmo: 74 | self.output_size += 1024 75 | 76 | def load_glove(self, glove_emb_mat): 77 | device = self.glove_embedding.embedding.weight.device 78 | glove_emb_mat = glove_emb_mat.to(device) 79 | glove_emb_mat = torch.cat([torch.zeros(2, glove_emb_mat.size()[-1]).to(device), glove_emb_mat], dim=0) 80 | self.glove_embedding.embedding.weight = torch.nn.Parameter(glove_emb_mat, requires_grad=False) 81 | 82 | def load_elmo(self, elmo_options_file, elmo_weights_file): 83 | device = self.word_embedding.embedding.weight.device 84 | from allennlp.modules.elmo import Elmo 85 | self.elmo = Elmo(elmo_options_file, elmo_weights_file, 1, dropout=0).to(device) 86 | 87 | def init(self, processed_metadata): 88 | self.load_glove(processed_metadata['glove_emb_mat']) 89 | if self.use_elmo: 90 | 
self.load_elmo(processed_metadata['elmo_options_file'], processed_metadata['elmo_weights_file']) 91 | 92 | def forward(self, cx, gx, x, ex=None): 93 | cx = self.char_embedding(cx) 94 | gx = self.glove_embedding(gx) 95 | output = torch.cat([cx, gx], -1) 96 | output = self.highway2(self.highway1(output)) 97 | if self.use_elmo: 98 | elmo, = self.elmo(ex)['elmo_representations'] 99 | output = torch.cat([output, elmo], 2) 100 | return output 101 | 102 | 103 | class SelfSeqAtt(nn.Module): 104 | def __init__(self, input_size, hidden_size, dropout): 105 | super(SelfSeqAtt, self).__init__() 106 | self.dropout = torch.nn.Dropout(p=dropout) 107 | self.query_lstm = nn.LSTM(input_size=input_size, 108 | hidden_size=hidden_size, 109 | batch_first=True, 110 | bidirectional=True) 111 | self.key_lstm = nn.LSTM(input_size=input_size, 112 | hidden_size=hidden_size, 113 | batch_first=True, 114 | bidirectional=True) 115 | self.softmax = nn.Softmax(dim=2) 116 | 117 | def forward(self, input_, mask): 118 | input_ = self.dropout(input_) 119 | key_input = input_ 120 | query_input = input_ 121 | key, _ = self.key_lstm(key_input) 122 | query, _ = self.key_lstm(query_input) 123 | att = query.matmul(key.transpose(1, 2)) + mask.unsqueeze(1) 124 | att = self.softmax(att) 125 | output = att.matmul(input_) 126 | return {'value': output, 'key': key, 'query': query} 127 | 128 | 129 | class ContextBoundary(nn.Module): 130 | def __init__(self, input_size, hidden_size, dropout, num_heads, identity=True, num_layers=1): 131 | super(ContextBoundary, self).__init__() 132 | assert num_heads >= 1, num_heads 133 | self.dropout = torch.nn.Dropout(p=dropout) 134 | self.num_layers = num_layers 135 | for i in range(self.num_layers): 136 | self.add_module('lstm%d' % i, torch.nn.LSTM(input_size=input_size, 137 | hidden_size=hidden_size, 138 | batch_first=True, 139 | bidirectional=True)) 140 | self.num_heads = num_heads 141 | self.identity = identity 142 | self.att_num_heads = num_heads - 1 if identity else num_heads 143 | for i in range(self.att_num_heads): 144 | self.add_module('self_att%d' % i, 145 | SelfSeqAtt(hidden_size * 2, hidden_size, dropout)) 146 | 147 | def forward(self, x, m): 148 | modules = dict(self.named_children()) 149 | x = self.dropout(x) 150 | for i in range(self.num_layers): 151 | x, _ = modules['lstm%d' % i](x) 152 | atts = [x] if self.identity else [] 153 | for i in range(self.att_num_heads): 154 | a = modules['self_att%d' % i](x, m) 155 | atts.append(a['value']) 156 | 157 | dense = torch.cat(atts, 2) 158 | return {'dense': dense} 159 | 160 | 161 | class QuestionBoundary(ContextBoundary): 162 | def __init__(self, input_size, hidden_size, dropout, num_heads, max_pool=False): 163 | super(QuestionBoundary, self).__init__(input_size, hidden_size, dropout, num_heads, identity=False) 164 | self.max_pool = max_pool 165 | 166 | def forward(self, x, m): 167 | d = super().forward(x, m) 168 | if self.max_pool: 169 | dense = d['dense'].max(1)[0] 170 | else: 171 | dense = d['dense'][:, 0, :] 172 | return {'dense': dense} 173 | 174 | 175 | class Model(base.Model): 176 | def __init__(self, 177 | char_vocab_size, 178 | glove_vocab_size, 179 | word_vocab_size, 180 | hidden_size, 181 | embed_size, 182 | dropout, 183 | num_heads, 184 | max_ans_len=7, 185 | elmo=False, 186 | max_pool=False, 187 | num_layers=1, 188 | glove_cpu=False, 189 | metric='ip', 190 | **kwargs): 191 | super(Model, self).__init__() 192 | self.embedding = Embedding(char_vocab_size, glove_vocab_size, word_vocab_size, embed_size, dropout, 193 | elmo=elmo, 
glove_cpu=glove_cpu) 194 | self.context_embedding = self.embedding 195 | self.question_embedding = self.embedding 196 | word_size = self.embedding.output_size 197 | context_input_size = word_size 198 | question_input_size = word_size 199 | self.context_start = ContextBoundary(context_input_size, hidden_size, dropout, num_heads, num_layers=num_layers) 200 | self.context_end = ContextBoundary(context_input_size, hidden_size, dropout, num_heads, num_layers=num_layers) 201 | self.question_start = QuestionBoundary(question_input_size, hidden_size, dropout, num_heads, max_pool=max_pool) 202 | self.question_end = QuestionBoundary(question_input_size, hidden_size, dropout, num_heads, max_pool=max_pool) 203 | self.softmax = nn.Softmax(dim=1) 204 | self.max_ans_len = max_ans_len 205 | self.linear = nn.Linear(word_size, 1) 206 | self.metric = metric 207 | 208 | def forward(self, 209 | context_char_idxs, 210 | context_glove_idxs, 211 | context_word_idxs, 212 | question_char_idxs, 213 | question_glove_idxs, 214 | question_word_idxs, 215 | context_elmo_idxs=None, 216 | question_elmo_idxs=None, 217 | num_samples=None, 218 | **kwargs): 219 | q = self.question_embedding(question_char_idxs, question_glove_idxs, question_word_idxs, ex=question_elmo_idxs) 220 | x = self.context_embedding(context_char_idxs, context_glove_idxs, context_word_idxs, ex=context_elmo_idxs) 221 | 222 | mq = ((question_glove_idxs == 0).float() * -1e9) 223 | qd1 = self.question_start(q, mq) 224 | qd2 = self.question_end(q, mq) 225 | q1 = qd1['dense'] 226 | q2 = qd2['dense'] 227 | # print(qs1[0, question_word_idxs[0] > 0]) 228 | 229 | mx = (context_glove_idxs == 0).float() * -1e9 230 | 231 | hd1 = self.context_start(x, mx) 232 | hd2 = self.context_end(x, mx) 233 | x1 = hd1['dense'] 234 | x2 = hd2['dense'] 235 | 236 | logits1 = torch.sum(x1 * q1.unsqueeze(1), 2) + mx 237 | logits2 = torch.sum(x2 * q2.unsqueeze(1), 2) + mx 238 | 239 | if self.metric == 'l2': 240 | logits1 += -0.5 * (torch.sum(x1 * x1, 2) + torch.sum(q1 * q1, 1).unsqueeze(1)) 241 | logits2 += -0.5 * (torch.sum(x2 * x2, 2) + torch.sum(q2 * q2, 1).unsqueeze(1)) 242 | 243 | prob1 = self.softmax(logits1) 244 | prob2 = self.softmax(logits2) 245 | prob = prob1.unsqueeze(2) * prob2.unsqueeze(1) 246 | mask = (torch.ones(*prob.size()[1:]).triu() - torch.ones(*prob.size()[1:]).triu(self.max_ans_len)).to( 247 | prob.device) 248 | prob *= mask 249 | _, yp1 = prob.max(2)[0].max(1) 250 | _, yp2 = prob.max(1)[0].max(1) 251 | 252 | return {'logits1': logits1, 253 | 'logits2': logits2, 254 | 'yp1': yp1, 255 | 'yp2': yp2, 256 | 'x1': x1, 257 | 'x2': x2, 258 | 'q1': q1, 259 | 'q2': q2} 260 | 261 | def init(self, processed_metadata): 262 | self.embedding.init(processed_metadata) 263 | 264 | def get_context(self, context_char_idxs, context_glove_idxs, context_word_idxs, context_elmo_idxs=None, **kwargs): 265 | l = (context_glove_idxs > 0).sum(1) 266 | mx = (context_glove_idxs == 0).float() * -1e9 267 | x = self.context_embedding(context_char_idxs, context_glove_idxs, context_word_idxs, ex=context_elmo_idxs) 268 | xd1 = self.context_start(x, mx) 269 | x1 = xd1['dense'] 270 | xd2 = self.context_end(x, mx) 271 | x2 = xd2['dense'] 272 | out = [] 273 | for k, (lb, x1b, x2b) in enumerate(zip(l, x1, x2)): 274 | pos_list = [] 275 | vec_list = [] 276 | for i in range(lb): 277 | for j in range(i, min(i + self.max_ans_len, lb)): 278 | vec = torch.cat([x1b[i], x2b[j]], 0) 279 | pos_list.append((i, j)) 280 | vec_list.append(vec) 281 | 282 | dense = torch.stack(vec_list, 0) 283 | 
out.append((tuple(pos_list), dense)) 284 | return tuple(out) 285 | 286 | def get_question(self, question_char_idxs, question_glove_idxs, question_word_idxs, question_elmo_idxs=None, 287 | **kwargs): 288 | mq = ((question_glove_idxs == 0).float() * -1e9) 289 | q = self.question_embedding(question_char_idxs, question_glove_idxs, question_word_idxs, ex=question_elmo_idxs) 290 | qd1 = self.question_start(q, mq) 291 | q1 = qd1['dense'] 292 | qd2 = self.question_end(q, mq) 293 | q2 = qd2['dense'] 294 | out = list(torch.cat([q1, q2], 1).unsqueeze(1)) 295 | return out 296 | 297 | 298 | class Loss(base.Loss): 299 | def __init__(self, **kwargs): 300 | super(Loss, self).__init__() 301 | self.cel = nn.CrossEntropyLoss() 302 | 303 | def forward(self, logits1, logits2, answer_word_starts, answer_word_ends, **kwargs): 304 | answer_word_starts -= 1 305 | answer_word_ends -= 1 306 | loss1 = self.cel(logits1, answer_word_starts[:, 0]) 307 | loss2 = self.cel(logits2, answer_word_ends[:, 0]) 308 | loss = loss1 + loss2 309 | return loss 310 | -------------------------------------------------------------------------------- /squad/baseline/processor.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import string 4 | from collections import Counter 5 | 6 | import nltk 7 | import torch 8 | import torch.utils.data 9 | from scipy.sparse import csc_matrix 10 | import numpy as np 11 | 12 | import base 13 | 14 | 15 | class Tokenizer(object): 16 | def tokenize(self, in_): 17 | raise NotImplementedError() 18 | 19 | 20 | class PTBSentTokenizer(Tokenizer): 21 | def tokenize(self, in_): 22 | sents = nltk.sent_tokenize(in_) 23 | return _get_spans(in_, sents) 24 | 25 | 26 | class PTBWordTokenizer(Tokenizer): 27 | def tokenize(self, in_): 28 | in_ = in_.replace('``', '" ').replace("''", '" ').replace('\t', ' ') 29 | words = nltk.word_tokenize(in_) 30 | words = tuple(word.replace('``', '"').replace("''", '"') for word in words) 31 | return _get_spans(in_, words) 32 | 33 | 34 | class Processor(base.Processor): 35 | keys = {'context_word_idxs', 36 | 'context_glove_idxs', 37 | 'context_char_idxs', 38 | 'question_word_idxs', 39 | 'question_glove_idxs', 40 | 'question_char_idxs', 41 | 'answer_word_starts', 42 | 'answer_word_ends', 43 | 'idx'} 44 | depths = {'context_word_idxs': 1, 45 | 'context_glove_idxs': 1, 46 | 'context_char_idxs': 2, 47 | 'question_word_idxs': 1, 48 | 'question_glove_idxs': 1, 49 | 'question_char_idxs': 2, 50 | 'answer_word_starts': 1, 51 | 'answer_word_ends': 1, 52 | 'idx': 0} 53 | pad = '' 54 | unk = '' 55 | 56 | def __init__(self, char_vocab_size=None, glove_vocab_size=None, word_vocab_size=None, elmo=False, draft=False, 57 | emb_type=None, **kwargs): 58 | self._word_tokenizer = PTBWordTokenizer() 59 | self._sent_tokenizer = PTBSentTokenizer() 60 | self._char_vocab_size = char_vocab_size 61 | self._glove_vocab_size = glove_vocab_size 62 | self._word_vocab_size = word_vocab_size 63 | self._elmo = elmo 64 | if elmo: 65 | from allennlp.modules.elmo import batch_to_ids 66 | self._batch_to_ids = batch_to_ids 67 | self._draft = draft 68 | self._emb_type = emb_type 69 | self._glove = None 70 | 71 | self._word_cache = {} 72 | self._sent_cache = {} 73 | self._word2idx_dict = {} 74 | self._word2idx_ext_dict = {} 75 | self._char2idx_dict = {} 76 | 77 | def construct(self, examples, metadata): 78 | assert metadata is not None 79 | glove_vocab = metadata['glove_vocab'] 80 | word_counter, lower_word_counter, char_counter = Counter(), Counter(), 
Counter() 81 | for example in examples: 82 | for text in (example['context'], example['question']): 83 | for span in self._word_tokenize(example['context']): 84 | word = text[span[0]:span[1]] 85 | word_counter[word] += 1 86 | lower_word_counter[word] += 1 87 | for char in word: 88 | char_counter[char] += 1 89 | 90 | word_vocab = tuple(item[0] for item in sorted(word_counter.items(), key=lambda item: -item[1])) 91 | word_vocab = (Processor.pad, Processor.unk) + word_vocab 92 | word_vocab = word_vocab[:self._word_vocab_size] if len(word_vocab) > self._word_vocab_size else word_vocab 93 | self._word2idx_dict = {word: idx for idx, word in enumerate(word_vocab)} 94 | 95 | char_vocab = tuple(item[0] for item in sorted(char_counter.items(), key=lambda item: -item[1])) 96 | char_vocab = (Processor.pad, Processor.unk) + char_vocab 97 | char_vocab = char_vocab[:self._char_vocab_size] if len(char_vocab) > self._char_vocab_size else char_vocab 98 | self._char2idx_dict = {char: idx for idx, char in enumerate(char_vocab)} 99 | 100 | ext_vocab = (Processor.pad, Processor.unk) + tuple(glove_vocab) 101 | if len(ext_vocab) > self._glove_vocab_size: 102 | ext_vocab = ext_vocab[:self._glove_vocab_size] 103 | self._word2idx_ext_dict = {ext: idx for idx, ext in enumerate(ext_vocab)} 104 | # assert max(self._word2idx_ext.values()) + 1 == self._glove_vocab_size, max(self._word2idx_ext.values()) + 1 105 | 106 | def state_dict(self): 107 | out = {'word2idx': self._word2idx_dict, 108 | 'word2idx_ext': self._word2idx_ext_dict, 109 | 'char2idx': self._char2idx_dict} 110 | return out 111 | 112 | def load_state_dict(self, in_): 113 | self._word2idx_dict = in_['word2idx'] 114 | self._word2idx_ext_dict = in_['word2idx_ext'] 115 | self._char2idx_dict = in_['char2idx'] 116 | 117 | def preprocess(self, example): 118 | prepro_example = {'idx': example['idx']} 119 | 120 | if 'context' in example: 121 | context = example['context'] 122 | context_spans = self._word_tokenize(context) 123 | context_words = tuple(context[span[0]:span[1]] for span in context_spans) 124 | context_word_idxs = tuple(map(self._word2idx, context_words)) 125 | context_glove_idxs = tuple(map(self._word2idx_ext, context_words)) 126 | context_char_idxs = tuple(tuple(map(self._char2idx, word)) for word in context_words) 127 | prepro_example['context_spans'] = context_spans 128 | prepro_example['context_word_idxs'] = context_word_idxs 129 | prepro_example['context_glove_idxs'] = context_glove_idxs 130 | prepro_example['context_char_idxs'] = context_char_idxs 131 | 132 | if 'question' in example: 133 | question = example['question'] 134 | question_spans = self._word_tokenize(example['question']) 135 | question_words = tuple(question[span[0]:span[1]] for span in question_spans) 136 | question_word_idxs = tuple(map(self._word2idx, question_words)) 137 | question_glove_idxs = tuple(map(self._word2idx_ext, question_words)) 138 | question_char_idxs = tuple(tuple(map(self._char2idx, word)) for word in question_words) 139 | prepro_example['question_spans'] = question_spans 140 | prepro_example['question_word_idxs'] = question_word_idxs 141 | prepro_example['question_glove_idxs'] = question_glove_idxs 142 | prepro_example['question_char_idxs'] = question_char_idxs 143 | 144 | if 'answer_starts' in example: 145 | answer_word_start, answer_word_end = 0, 0 146 | answer_word_starts, answer_word_ends = [], [] 147 | for answer_start in example['answer_starts']: 148 | for word_idx, span in enumerate(context_spans): 149 | if span[0] <= answer_start: 150 | answer_word_start = 
word_idx + 1 151 | answer_word_starts.append(answer_word_start) 152 | for answer_end in example['answer_ends']: 153 | for word_idx, span in enumerate(context_spans): 154 | if span[0] <= answer_end: 155 | answer_word_end = word_idx + 1 156 | answer_word_ends.append(answer_word_end) 157 | prepro_example['answer_word_starts'] = answer_word_starts 158 | prepro_example['answer_word_ends'] = answer_word_ends 159 | 160 | output = dict(tuple(example.items()) + tuple(prepro_example.items())) 161 | return output 162 | 163 | def postprocess(self, example, model_output): 164 | yp1 = model_output['yp1'].item() 165 | yp2 = model_output['yp2'].item() 166 | context = example['context'] 167 | context_spans = example['context_spans'] 168 | pred = _get_pred(context, context_spans, yp1, yp2) 169 | out = {'pred': pred, 'id': example['id']} 170 | if 'answer_starts' in example: 171 | y1 = example['answer_starts'] 172 | y2 = example['answer_ends'] 173 | gt = [context[s:e] for s, e in zip(y1, y2)] 174 | f1 = max(_f1_score(pred, gt_each) for gt_each in gt) 175 | em = max(_exact_match_score(pred, gt_each) for gt_each in gt) 176 | out['gt'] = gt 177 | out['f1'] = f1 178 | out['em'] = em 179 | return out 180 | 181 | def postprocess_batch(self, dataset, model_input, model_output): 182 | results = tuple(self.postprocess(dataset[idx], 183 | {key: val[i] if val is not None else None for key, val in 184 | model_output.items()}) 185 | for i, idx in enumerate(model_input['idx'])) 186 | return results 187 | 188 | def postprocess_context(self, example, context_output): 189 | pos_tuple, dense = context_output 190 | out = dense.cpu().numpy() 191 | context = example['context'] 192 | context_spans = example['context_spans'] 193 | phrases = tuple(_get_pred(context, context_spans, yp1, yp2) for yp1, yp2 in pos_tuple) 194 | if self._emb_type == 'sparse': 195 | out = csc_matrix(out) 196 | metadata = {'context': context, 197 | 'answer_spans': tuple((context_spans[yp1][0], context_spans[yp2][1]) for yp1, yp2 in pos_tuple)} 198 | return example['cid'], phrases, out, metadata 199 | 200 | def postprocess_context_batch(self, dataset, model_input, context_output): 201 | results = tuple(self.postprocess_context(dataset[idx], context_output[i]) 202 | for i, idx in enumerate(model_input['idx'])) 203 | return results 204 | 205 | def postprocess_question(self, example, question_output): 206 | dense = question_output 207 | out = dense.cpu().numpy() 208 | if self._emb_type == 'sparse': 209 | out = csc_matrix(out) 210 | return example['id'], out 211 | 212 | def postprocess_question_batch(self, dataset, model_input, question_output): 213 | results = tuple(self.postprocess_question(dataset[idx], question_output[i]) 214 | for i, idx in enumerate(model_input['idx'])) 215 | return results 216 | 217 | def collate(self, examples): 218 | tensors = {} 219 | for key in self.keys: 220 | if key not in examples[0]: 221 | continue 222 | val = tuple(example[key] for example in examples) 223 | depth = self.depths[key] + 1 224 | shape = _get_shape(val, depth) 225 | tensor = torch.zeros(shape, dtype=torch.int64) 226 | _fill_tensor(tensor, val) 227 | tensors[key] = tensor 228 | if self._elmo: 229 | if 'context' in examples[0]: 230 | sentences = [[example['context'][span[0]:span[1]] for span in example['context_spans']] 231 | for example in examples] 232 | character_ids = self._batch_to_ids(sentences) 233 | tensors['context_elmo_idxs'] = character_ids 234 | if 'question' in examples[0]: 235 | sentences = [[example['question'][span[0]:span[1]] for span in 
example['question_spans']] 236 | for example in examples] 237 | character_ids = self._batch_to_ids(sentences) 238 | tensors['question_elmo_idxs'] = character_ids 239 | return tensors 240 | 241 | def process_metadata(self, metadata): 242 | return {'glove_emb_mat': torch.tensor(metadata['glove_emb_mat']), 243 | 'elmo_options_file': metadata['elmo_options_file'], 244 | 'elmo_weights_file': metadata['elmo_weights_file']} 245 | 246 | def get_dump(self, dataset, input_, output, results): 247 | dump = [] 248 | for i, idx in enumerate(input_['idx']): 249 | example = dataset[idx] 250 | each = {'id': example['id'], 251 | 'context': example['context'], 252 | 'question': example['question'], 253 | 'answer_starts': example['answer_starts'], 254 | 'answer_ends': example['answer_ends'], 255 | 'context_spans': example['context_spans'], 256 | 'yp1': output['yp1'][i].cpu().numpy(), 257 | 'yp2': output['yp2'][i].cpu().numpy(), 258 | } 259 | dump.append(each) 260 | return dump 261 | 262 | # private methods below 263 | def _word_tokenize(self, string): 264 | if string in self._word_cache: 265 | return self._word_cache[string] 266 | spans = self._word_tokenizer.tokenize(string) 267 | self._word_cache[string] = spans 268 | return spans 269 | 270 | def _sent_tokenize(self, string): 271 | if string in self._sent_cache: 272 | return self._sent_cache[string] 273 | spans = self._sent_tokenizer.tokenize(string) 274 | self._sent_cache[string] = spans 275 | return spans 276 | 277 | def _word2idx(self, word): 278 | return self._word2idx_dict[word] if word in self._word2idx_dict else 1 279 | 280 | def _word2idx_ext(self, word): 281 | word = word.lower() 282 | return self._word2idx_ext_dict[word] if word in self._word2idx_ext_dict else 1 283 | 284 | def _char2idx(self, char): 285 | return self._char2idx_dict[char] if char in self._char2idx_dict else 1 286 | 287 | 288 | class Sampler(base.Sampler): 289 | def __init__(self, dataset, data_type, max_context_size=None, max_question_size=None, bucket=False, shuffle=False, 290 | **kwargs): 291 | super(Sampler, self).__init__(dataset, data_type) 292 | if data_type == 'dev' or data_type == 'test': 293 | max_context_size = None 294 | max_question_size = None 295 | self.shuffle = False 296 | 297 | self.max_context_size = max_context_size 298 | self.max_question_size = max_question_size 299 | self.shuffle = shuffle 300 | self.bucket = bucket 301 | 302 | idxs = tuple(idx for idx in range(len(dataset)) 303 | if (max_context_size is None or len(dataset[idx]['context_spans']) <= max_context_size) and 304 | (max_question_size is None or len(dataset[idx]['question_spans']) <= max_question_size)) 305 | 306 | if shuffle: 307 | idxs = random.sample(idxs, len(idxs)) 308 | 309 | if bucket: 310 | if 'context_spans' in dataset[0]: 311 | idxs = sorted(idxs, key=lambda idx: len(dataset[idx]['context_spans'])) 312 | else: 313 | assert 'question_spans' in dataset[0] 314 | idxs = sorted(idxs, key=lambda idx: len(dataset[idx]['question_spans'])) 315 | self._idxs = idxs 316 | 317 | def __iter__(self): 318 | return iter(self._idxs) 319 | 320 | def __len__(self): 321 | return len(self._idxs) 322 | 323 | 324 | class SparseTensor(object): 325 | def __init__(self, idx, val, max_=None): 326 | self.idx = idx 327 | self.val = val 328 | self.max = max_ 329 | 330 | def scipy(self): 331 | col = self.idx.flatten() 332 | row = np.tile(np.expand_dims(range(self.idx.shape[0]), 1), [1, self.idx.shape[1]]).flatten() 333 | data = self.val.flatten() 334 | shape = None if self.max is None else [self.idx.shape[0], self.max] 
335 | return csc_matrix((data, (row, col)), shape=shape) 336 | 337 | 338 | # SquadProcessor-specific helpers 339 | 340 | def _get_pred(context, spans, yp1, yp2): 341 | if yp1 >= len(spans): 342 | print('warning: yp1 is set to 0') 343 | yp1 = 0 344 | if yp2 >= len(spans): 345 | print('warning: yp1 is set to 0') 346 | yp2 = 0 347 | yp1c = spans[yp1][0] 348 | yp2c = spans[yp2][1] 349 | return context[yp1c:yp2c] 350 | 351 | 352 | def _get_spans(in_, tokens): 353 | pairs = [] 354 | i = 0 355 | for token in tokens: 356 | i = in_.find(token, i) 357 | assert i >= 0, 'token `%s` not found starting from %d: `%s`' % (token, i, in_[i:]) 358 | pair = (i, i + len(token)) 359 | pairs.append(pair) 360 | i += len(token) 361 | return tuple(pairs) 362 | 363 | 364 | def _get_shape(nested_list, depth): 365 | if depth > 0: 366 | return (len(nested_list),) + tuple(map(max, zip(*[_get_shape(each, depth - 1) for each in nested_list]))) 367 | return () 368 | 369 | 370 | def _fill_tensor(tensor, nested_list): 371 | if tensor.dim() == 1: 372 | tensor[:len(nested_list)] = torch.tensor(nested_list) 373 | elif tensor.dim() == 2: 374 | for i, each in enumerate(nested_list): 375 | tensor[i, :len(each)] = torch.tensor(each) 376 | elif tensor.dim() == 3: 377 | for i1, each1 in enumerate(nested_list): 378 | for i2, each2 in enumerate(each1): 379 | tensor[i1, i2, :len(each2)] = torch.tensor(each2) 380 | else: 381 | for tensor_child, nested_list_child in zip(tensor, nested_list): 382 | _fill_tensor(tensor_child, nested_list_child) 383 | 384 | 385 | # SQuAD official evaluation helpers 386 | 387 | def _normalize_answer(s): 388 | """Lower text and remove punctuation, articles and extra whitespace. 389 | 390 | Directly copied from official SQuAD eval script, SHOULD NOT BE MODIFIED. 391 | 392 | Args: 393 | s: Input text. 394 | Returns: 395 | Normalized text. 
396 | """ 397 | 398 | def remove_articles(text): 399 | return re.sub(r'\b(a|an|the)\b', ' ', text) 400 | 401 | def white_space_fix(text): 402 | return ' '.join(text.split()) 403 | 404 | def remove_punc(text): 405 | exclude = set(string.punctuation) 406 | return ''.join(ch for ch in text if ch not in exclude) 407 | 408 | def lower(text): 409 | return text.lower() 410 | 411 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 412 | 413 | 414 | def _f1_score(prediction, ground_truth): 415 | """Directly copied from official SQuAD eval script, SHOULD NOT BE MODIFIED.""" 416 | prediction_tokens = _normalize_answer(prediction).split() 417 | ground_truth_tokens = _normalize_answer(ground_truth).split() 418 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 419 | num_same = sum(common.values()) 420 | if num_same == 0: 421 | return 0 422 | precision = 1.0 * num_same / len(prediction_tokens) 423 | recall = 1.0 * num_same / len(ground_truth_tokens) 424 | f1 = (2 * precision * recall) / (precision + recall) 425 | return f1 426 | 427 | 428 | def _exact_match_score(prediction, ground_truth): 429 | """Directly copied from official SQuAD eval script, SHOULD NOT BE MODIFIED.""" 430 | return _normalize_answer(prediction) == _normalize_answer(ground_truth) 431 | -------------------------------------------------------------------------------- /squad/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from collections import OrderedDict 4 | from pprint import pprint 5 | import importlib 6 | 7 | import scipy.sparse 8 | import torch 9 | import numpy as np 10 | from torch.utils.data import DataLoader 11 | 12 | import base 13 | 14 | 15 | def preprocess(interface, args): 16 | """Helper function for caching preprocessed data 17 | """ 18 | print('Loading train and dev data') 19 | train_examples = interface.load_train() 20 | dev_examples = interface.load_test() 21 | 22 | # load metadata, such as GloVe 23 | print('Loading metadata') 24 | metadata = interface.load_metadata() 25 | 26 | print('Constructing processor') 27 | processor = Processor(**args.__dict__) 28 | processor.construct(train_examples, metadata) 29 | 30 | # data loader 31 | print('Preprocessing datasets and metadata') 32 | train_dataset = tuple(processor.preprocess(example) for example in train_examples) 33 | dev_dataset = tuple(processor.preprocess(example) for example in dev_examples) 34 | processed_metadata = processor.process_metadata(metadata) 35 | 36 | print('Creating data loaders') 37 | train_sampler = Sampler(train_dataset, 'train', **args.__dict__) 38 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, 39 | collate_fn=processor.collate, sampler=train_sampler) 40 | 41 | dev_sampler = Sampler(dev_dataset, 'dev', **args.__dict__) 42 | dev_loader = DataLoader(dev_dataset, batch_size=args.batch_size, 43 | collate_fn=processor.collate, sampler=dev_sampler) 44 | 45 | if args.preload: 46 | train_loader = tuple(train_loader) 47 | dev_loader = tuple(dev_loader) 48 | 49 | out = {'processor': processor, 50 | 'train_dataset': train_dataset, 51 | 'dev_dataset': dev_dataset, 52 | 'processed_metadata': processed_metadata, 53 | 'train_loader': train_loader, 54 | 'dev_loader': dev_loader} 55 | 56 | return out 57 | 58 | 59 | def train(args): 60 | start_time = time.time() 61 | device = torch.device('cuda' if args.cuda else 'cpu') 62 | 63 | pprint(args.__dict__) 64 | interface = FileInterface(**args.__dict__) 65 | out = interface.cache(preprocess, 
args) if args.cache else preprocess(interface, args) 66 | processor = out['processor'] 67 | processed_metadata = out['processed_metadata'] 68 | train_dataset = out['train_dataset'] 69 | dev_dataset = out['dev_dataset'] 70 | train_loader = out['train_loader'] 71 | dev_loader = out['dev_loader'] 72 | 73 | model = Model(**args.__dict__).to(device) 74 | model.init(processed_metadata) 75 | 76 | loss_model = Loss().to(device) 77 | optimizer = torch.optim.Adam(p for p in model.parameters() if p.requires_grad) 78 | 79 | interface.bind(processor, model, optimizer=optimizer) 80 | 81 | step = 0 82 | train_report, dev_report = None, None 83 | 84 | print('Training') 85 | interface.save_args(args.__dict__) 86 | model.train() 87 | for epoch_idx in range(args.epochs): 88 | for i, train_batch in enumerate(train_loader): 89 | train_batch = {key: val.to(device) for key, val in train_batch.items()} 90 | model_output = model(step=step, **train_batch) 91 | train_results = processor.postprocess_batch(train_dataset, train_batch, model_output) 92 | train_loss = loss_model(step=step, **model_output, **train_batch) 93 | train_f1 = float(np.mean([result['f1'] for result in train_results])) 94 | train_em = float(np.mean([result['em'] for result in train_results])) 95 | 96 | # optimize 97 | optimizer.zero_grad() 98 | train_loss.backward() 99 | optimizer.step() 100 | step += 1 101 | 102 | # report & eval & save 103 | if step % args.report_period == 1: 104 | train_report = OrderedDict(step=step, train_loss=train_loss.item(), train_f1=train_f1, 105 | train_em=train_em, time=time.time() - start_time) 106 | print(interface.report(**train_report)) 107 | 108 | if step % args.eval_save_period == 1: 109 | with torch.no_grad(): 110 | model.eval() 111 | loss_model.eval() 112 | pred = {} 113 | dev_losses, dev_results = [], [] 114 | for dev_batch, _ in zip(dev_loader, range(args.eval_steps)): 115 | dev_batch = {key: val.to(device) for key, val in dev_batch.items()} 116 | model_output = model(**dev_batch) 117 | results = processor.postprocess_batch(dev_dataset, dev_batch, model_output) 118 | 119 | dev_loss = loss_model(step=step, **dev_batch, **model_output) 120 | 121 | for result in results: 122 | pred[result['id']] = result['pred'] 123 | dev_results.extend(results) 124 | dev_losses.append(dev_loss.item()) 125 | 126 | dev_loss = float(np.mean(dev_losses)) 127 | dev_f1 = float(np.mean([result['f1'] for result in dev_results])) 128 | dev_em = float(np.mean([result['em'] for result in dev_results])) 129 | dev_f1_best = dev_f1 if dev_report is None else max(dev_f1, dev_report['dev_f1_best']) 130 | dev_f1_best_step = step if dev_report is None or dev_f1 > dev_report['dev_f1_best'] else dev_report[ 131 | 'dev_f1_best_step'] 132 | 133 | dev_report = OrderedDict(step=step, dev_loss=dev_loss, dev_f1=dev_f1, dev_em=dev_em, 134 | time=time.time() - start_time, dev_f1_best=dev_f1_best, 135 | dev_f1_best_step=dev_f1_best_step) 136 | 137 | summary = False 138 | if dev_report['dev_f1_best_step'] == step: 139 | summary = True 140 | interface.save(iteration=step) 141 | interface.pred(pred) 142 | print(interface.report(summary=summary, **dev_report)) 143 | model.train() 144 | loss_model.train() 145 | 146 | if step == args.train_steps: 147 | break 148 | if step == args.train_steps: 149 | break 150 | 151 | 152 | def test(args): 153 | device = torch.device('cuda' if args.cuda else 'cpu') 154 | pprint(args.__dict__) 155 | 156 | interface = FileInterface(**args.__dict__) 157 | # use cache for metadata 158 | if args.cache: 159 | out = 
interface.cache(preprocess, args) 160 | processor = out['processor'] 161 | processed_metadata = out['processed_metadata'] 162 | else: 163 | processor = Processor(**args.__dict__) 164 | metadata = interface.load_metadata() 165 | processed_metadata = processor.process_metadata(metadata) 166 | 167 | model = Model(**args.__dict__).to(device) 168 | model.init(processed_metadata) 169 | interface.bind(processor, model) 170 | 171 | interface.load(args.iteration, session=args.load_dir) 172 | 173 | test_examples = interface.load_test() 174 | test_dataset = tuple(processor.preprocess(example) for example in test_examples) 175 | 176 | test_sampler = Sampler(test_dataset, 'test', **args.__dict__) 177 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, sampler=test_sampler, 178 | collate_fn=processor.collate) 179 | 180 | print('Inferencing') 181 | with torch.no_grad(): 182 | model.eval() 183 | pred = {} 184 | for batch_idx, (test_batch, _) in enumerate(zip(test_loader, range(args.eval_steps))): 185 | test_batch = {key: val.to(device) for key, val in test_batch.items()} 186 | model_output = model(**test_batch) 187 | results = processor.postprocess_batch(test_dataset, test_batch, model_output) 188 | if batch_idx % args.dump_period == 0: 189 | dump = processor.get_dump(test_dataset, test_batch, model_output, results) 190 | interface.dump(batch_idx, dump) 191 | for result in results: 192 | pred[result['id']] = result['pred'] 193 | 194 | print('[%d/%d]' % (batch_idx + 1, len(test_loader))) 195 | interface.pred(pred) 196 | 197 | 198 | def embed(args): 199 | device = torch.device('cuda' if args.cuda else 'cpu') 200 | pprint(args.__dict__) 201 | 202 | interface = FileInterface(**args.__dict__) 203 | # use cache for metadata 204 | if args.cache: 205 | out = interface.cache(preprocess, args) 206 | processor = out['processor'] 207 | processed_metadata = out['processed_metadata'] 208 | else: 209 | processor = Processor(**args.__dict__) 210 | metadata = interface.load_metadata() 211 | processed_metadata = processor.process_metadata(metadata) 212 | 213 | model = Model(**args.__dict__).to(device) 214 | model.init(processed_metadata) 215 | interface.bind(processor, model) 216 | 217 | interface.load(args.iteration, session=args.load_dir) 218 | 219 | test_examples = interface.load_test() 220 | test_dataset = tuple(processor.preprocess(example) for example in test_examples) 221 | 222 | test_sampler = Sampler(test_dataset, 'test', **args.__dict__) 223 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, sampler=test_sampler, 224 | collate_fn=processor.collate) 225 | 226 | print('Saving embeddings') 227 | with torch.no_grad(): 228 | model.eval() 229 | for batch_idx, (test_batch, _) in enumerate(zip(test_loader, range(args.eval_steps))): 230 | test_batch = {key: val.to(device) for key, val in test_batch.items()} 231 | 232 | if args.mode == 'embed' or args.mode == 'embed_context': 233 | 234 | context_output = model.get_context(**test_batch) 235 | context_results = processor.postprocess_context_batch(test_dataset, test_batch, context_output) 236 | 237 | for id_, phrases, matrix, metadata in context_results: 238 | if not args.metadata: 239 | metadata = None 240 | interface.context_emb(id_, phrases, matrix, metadata=metadata, emb_type=args.emb_type) 241 | 242 | if args.mode == 'embed' or args.mode == 'embed_question': 243 | 244 | question_output = model.get_question(**test_batch) 245 | question_results = processor.postprocess_question_batch(test_dataset, test_batch, question_output) 246 | 247 | 
for id_, emb in question_results: 248 | interface.question_emb(id_, emb, emb_type=args.emb_type) 249 | 250 | print('[%d/%d]' % (batch_idx + 1, len(test_loader))) 251 | 252 | if args.archive: 253 | print('Archiving') 254 | interface.archive() 255 | 256 | 257 | def serve(args): 258 | # serve_demo: Load saved embeddings, serve question model. question in, results out. 259 | # serve_question: only serve question model. question in, vector out. 260 | # serve_context: only serve context model. context in, phrase-vector pairs out. 261 | # serve: serve all three. 262 | device = torch.device('cuda' if args.cuda else 'cpu') 263 | pprint(args.__dict__) 264 | 265 | interface = FileInterface(**args.__dict__) 266 | # use cache for metadata 267 | if args.cache: 268 | out = interface.cache(preprocess, args) 269 | processor = out['processor'] 270 | processed_metadata = out['processed_metadata'] 271 | else: 272 | processor = Processor(**args.__dict__) 273 | metadata = interface.load_metadata() 274 | processed_metadata = processor.process_metadata(metadata) 275 | 276 | model = Model(**args.__dict__).to(device) 277 | model.init(processed_metadata) 278 | interface.bind(processor, model) 279 | 280 | interface.load(args.iteration, session=args.load_dir) 281 | 282 | with torch.no_grad(): 283 | model.eval() 284 | 285 | if args.mode == 'serve_demo': 286 | phrases = [] 287 | paras = [] 288 | results = [] 289 | embs = [] 290 | idxs = [] 291 | iterator = interface.context_load(metadata=True, emb_type=args.emb_type) 292 | for _, (cur_phrases, each_emb, metadata) in zip(range(args.num_train_mats), iterator): 293 | embs.append(each_emb) 294 | phrases.extend(cur_phrases) 295 | for span in metadata['answer_spans']: 296 | results.append([len(paras), span[0], span[1]]) 297 | idxs.append(len(idxs)) 298 | paras.append(metadata['context']) 299 | if args.emb_type == 'dense': 300 | import faiss 301 | emb = np.concatenate(embs, 0) 302 | 303 | d = 4 * args.hidden_size * args.num_heads 304 | if args.metric == 'ip': 305 | quantizer = faiss.IndexFlatIP(d) # Exact Search 306 | elif args.metric == 'l2': 307 | quantizer = faiss.IndexFlatL2(d) 308 | else: 309 | raise ValueError() 310 | 311 | if args.nlist != args.nprobe: 312 | # Approximate Search. 
nlist > nprobe makes it faster and less accurate 313 | if args.bpv is None: 314 | if args.metric == 'ip': 315 | search_index = faiss.IndexIVFFlat(quantizer, d, args.nlist, faiss.METRIC_INNER_PRODUCT) 316 | elif args.metric == 'l2': 317 | search_index = faiss.IndexIVFFlat(quantizer, d, args.nlist) 318 | else: 319 | raise ValueError() 320 | else: 321 | assert args.metric == 'l2' # only l2 is supported for product quantization 322 | search_index = faiss.IndexIVFPQ(quantizer, d, args.nlist, args.bpv, 8) 323 | search_index.train(emb) 324 | else: 325 | search_index = quantizer 326 | 327 | search_index.add(emb) 328 | for cur_phrases, each_emb, metadata in iterator: 329 | phrases.extend(cur_phrases) 330 | for span in metadata['answer_spans']: 331 | results.append([len(paras), span[0], span[1]]) 332 | paras.append(metadata['context']) 333 | search_index.add(each_emb) 334 | 335 | if args.nlist != args.nprobe: 336 | search_index.nprobe = args.nprobe 337 | 338 | def search(emb, k): 339 | D, I = search_index.search(emb, k) 340 | return D[0], I[0] 341 | 342 | elif args.emb_type == 'sparse': 343 | assert args.metric == 'l2' # currently only l2 is supported (couldn't find a good ip library) 344 | import pysparnn.cluster_index as ci 345 | 346 | cp = ci.MultiClusterIndex(embs, idxs) 347 | 348 | for cur_phrases, each_emb, metadata in iterator: 349 | phrases.extend(cur_phrases) 350 | for span in metadata['answer_spans']: 351 | results.append([len(paras), span[0], span[1]]) 352 | paras.append(metadata['context']) 353 | for each_vec in each_emb: 354 | cp.insert(each_vec, len(idxs)) 355 | idxs.append(len(idxs)) 356 | 357 | def search(emb, k): 358 | return zip(*[each[0] for each in cp.search(emb, k=k)]) 359 | 360 | else: 361 | raise ValueError() 362 | 363 | def retrieve(question, k): 364 | example = {'question': question, 'id': 'real', 'idx': 0} 365 | dataset = (processor.preprocess(example), ) 366 | loader = DataLoader(dataset, batch_size=1, collate_fn=processor.collate) 367 | batch = next(iter(loader)) 368 | question_output = model.get_question(**batch) 369 | question_results = processor.postprocess_question_batch(dataset, batch, question_output) 370 | id_, emb = question_results[0] 371 | D, I = search(emb, k) 372 | out = [(paras[results[i][0]], results[i][1], results[i][2], '%.4r' % d.item(),) 373 | for d, i in zip(D, I)] 374 | return out 375 | 376 | if args.mem_info: 377 | import psutil 378 | import os 379 | pid = os.getpid() 380 | py = psutil.Process(pid) 381 | info = py.memory_info()[0] / 2. ** 30 382 | print('Memory Use: %.2f GB' % info) 383 | 384 | # Demo server. 
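# --- Editor's sketch (illustrative; not part of main.py) -----------------------
# Bookkeeping behind retrieve() above: row i of the search index corresponds to
# results[i] = [paragraph index, span start, span end], so a search hit can be
# mapped back to its answer text. Treating the spans as character offsets is an
# assumption about the metadata format.
def hit_to_answer(i, results, paras):
    para_idx, start, end = results[i]
    return paras[para_idx][start:end]

# e.g. paras = ['Pikachu is an Electric-type Pokemon.'], results = [[0, 0, 7]]
# hit_to_answer(0, results, paras) -> 'Pikachu'
# --------------------------------------------------------------------------------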
Requires flask and tornado 385 | from flask import Flask, request, jsonify 386 | from flask_cors import CORS 387 | 388 | from tornado.wsgi import WSGIContainer 389 | from tornado.httpserver import HTTPServer 390 | from tornado.ioloop import IOLoop 391 | 392 | app = Flask(__name__, static_url_path='/static') 393 | 394 | app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False 395 | CORS(app) 396 | 397 | @app.route('/') 398 | def index(): 399 | return app.send_static_file('index.html') 400 | 401 | @app.route('/files/<path:path>') 402 | def static_files(path): 403 | return app.send_static_file('files/' + path) 404 | 405 | @app.route('/api', methods=['GET']) 406 | def api(): 407 | query = request.args['query'] 408 | out = retrieve(query, 5) 409 | return jsonify(out) 410 | 411 | print('Starting server at %d' % args.port) 412 | http_server = HTTPServer(WSGIContainer(app)) 413 | http_server.listen(args.port) 414 | IOLoop.instance().start() 415 | 416 | 417 | def main(): 418 | argument_parser = ArgumentParser() 419 | argument_parser.add_arguments() 420 | args = argument_parser.parse_args() 421 | if args.mode == 'train': 422 | train(args) 423 | elif args.mode == 'test': 424 | test(args) 425 | elif args.mode == 'embed' or args.mode == 'embed_context' or args.mode == 'embed_question': 426 | embed(args) 427 | elif args.mode.startswith('serve'): 428 | serve(args) 429 | else: 430 | raise Exception() 431 | 432 | 433 | if __name__ == "__main__": 434 | from_ = importlib.import_module(sys.argv[1]) 435 | ArgumentParser = from_.ArgumentParser 436 | FileInterface = from_.FileInterface 437 | Processor = from_.Processor 438 | Sampler = from_.Sampler 439 | Model = from_.Model 440 | Loss = from_.Loss 441 | assert issubclass(ArgumentParser, base.ArgumentParser) 442 | assert issubclass(FileInterface, base.FileInterface) 443 | assert issubclass(Processor, base.Processor) 444 | assert issubclass(Sampler, base.Sampler) 445 | assert issubclass(Model, base.Model) 446 | assert issubclass(Loss, base.Loss) 447 | main() 448 | --------------------------------------------------------------------------------
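# --- Editor's sketch (illustrative; not part of main.py) -----------------------
# main.py takes the implementation package as its first positional argument, e.g.
#
#     python main.py baseline --mode train
#
# (flag names other than --mode come from the package's ArgumentParser and may
# differ). The __main__ block above then resolves baseline.Model, baseline.Processor,
# etc. and checks each against its abstract counterpart in `base`. A stripped-down
# version of the same plug-in pattern:
import importlib

def load_component(package, name, abstract_base):
    cls = getattr(importlib.import_module(package), name)
    assert issubclass(cls, abstract_base)     # e.g. baseline.Model must extend base.Model
    return cls
# --------------------------------------------------------------------------------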
/squad/static/files/popper.min.js: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) Federico Zivolo 2018 3 | Distributed under the MIT License (license terms are at http://opensource.org/licenses/MIT). 4 | */ [remainder of popper.min.js omitted: minified third-party Popper.js library body] -------------------------------------------------------------------------------- /squad/static/files/bootstrap.min.js: -------------------------------------------------------------------------------- [bootstrap.min.js omitted: minified third-party Bootstrap v4.1.3 JavaScript bundle, MIT License]