├── multi-step-reasoner.png ├── scripts └── reader │ ├── word.py │ ├── interactive.py │ ├── predict.py │ ├── preprocess.py │ └── README.md ├── setup.sh ├── msr ├── reader │ ├── __pycache__ │ │ ├── config.cpython-36.pyc │ │ ├── data.cpython-36.pyc │ │ ├── layers.cpython-36.pyc │ │ ├── model.cpython-36.pyc │ │ ├── utils.cpython-36.pyc │ │ ├── vector.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── predictor.cpython-36.pyc │ │ ├── rnn_reader.cpython-36.pyc │ │ └── model.cpython-36 (blake2.cs.umass.edu's conflicted copy 2018-07-30).pyc │ ├── __init__.py │ ├── data.py │ ├── predictor.py │ ├── vector.py │ ├── config.py │ ├── rnn_reader.py │ ├── utils.py │ └── layers.py └── retriever │ ├── __pycache__ │ ├── utils.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ ├── doc_db.cpython-36.pyc │ ├── tfidf_doc_ranker.cpython-36.pyc │ └── trained_retriever.cpython-36.pyc │ ├── __init__.py │ ├── doc_db.py │ ├── utils.py │ ├── tfidf_doc_ranker.py │ └── trained_retriever.py ├── paragraph_encoder ├── README.md ├── model │ ├── retriever.py │ ├── vector.py │ ├── retriever_module.py │ ├── layers.py │ ├── data.py │ └── utils.py ├── multi_corpus.py ├── config.py └── train_para_encoder.py ├── run_pretrained_models.sh ├── requirements.txt ├── README.md └── LICENSE /multi-step-reasoner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/multi-step-reasoner.png -------------------------------------------------------------------------------- /scripts/reader/word.py: -------------------------------------------------------------------------------- 1 | 2 | # Python code for representing a word 3 | str = "word" # There, I just represented a word 4 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # set pythonpath 4 | export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/paragraph_encoder 5 | -------------------------------------------------------------------------------- /msr/reader/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /msr/reader/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /msr/reader/__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /msr/reader/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /msr/reader/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /msr/reader/__pycache__/vector.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/vector.cpython-36.pyc -------------------------------------------------------------------------------- /msr/reader/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /msr/retriever/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/retriever/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /msr/reader/__pycache__/predictor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/predictor.cpython-36.pyc -------------------------------------------------------------------------------- /msr/reader/__pycache__/rnn_reader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/rnn_reader.cpython-36.pyc -------------------------------------------------------------------------------- /msr/retriever/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/retriever/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /msr/retriever/__pycache__/doc_db.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/retriever/__pycache__/doc_db.cpython-36.pyc -------------------------------------------------------------------------------- /msr/retriever/__pycache__/tfidf_doc_ranker.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/retriever/__pycache__/tfidf_doc_ranker.cpython-36.pyc -------------------------------------------------------------------------------- /msr/retriever/__pycache__/trained_retriever.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/retriever/__pycache__/trained_retriever.cpython-36.pyc -------------------------------------------------------------------------------- /msr/reader/__pycache__/model.cpython-36 (blake2.cs.umass.edu's conflicted copy 2018-07-30).pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajarshd/Multi-Step-Reasoning/HEAD/msr/reader/__pycache__/model.cpython-36 (blake2.cs.umass.edu's conflicted copy 2018-07-30).pyc 
-------------------------------------------------------------------------------- /msr/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /msr/reader/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /paragraph_encoder/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Paragraph encoder 3 | 4 | * If you just want to use the paragraph representation used in the paper, please download the pretrained vectors. Refer to the data section of this [README](https://github.com/rajarshd/multi-step-for-multi-hop#data) for more details. 5 | 6 | 7 | Please run the following commands from the top-level directory. 8 | 9 | ## Training 10 | ``` 11 | python paragraph_encoder/train_para_encoder.py --data_dir data/ --src quasart|searchqa|triviaqa --embed_dir data/embeddings --model_dir model_save_dir 12 | ``` 13 | 14 | ### To save the vectors using the pretrained models 15 | ``` 16 | python paragraph_encoder/train_para_encoder.py --eval_only 1 --pretrained /path/to/model_save_dir/model.mdl --src quasart|searchqa|triviaqa --save_dir data/ 17 | ``` 18 | 19 | 20 | -------------------------------------------------------------------------------- /paragraph_encoder/model/retriever.py: -------------------------------------------------------------------------------- 1 | from model.retriever_module import LSTMParagraphScorer 2 | import logging 3 | 4 | 5 | logger = logging.getLogger() 6 | 7 | 8 | class LSTMRetriever(): 9 | def __init__(self, args, word_dict, feature_dict): 10 | 11 | self.args = args 12 | self.word_dict = word_dict 13 | self.feature_dict = feature_dict 14 | self.model = LSTMParagraphScorer(args, word_dict, feature_dict) 15 | if self.args.cuda: 16 | self.model = self.model.cuda() 17 | 18 | def get_trainable_params(self): 19 | return [p for p in self.model.parameters() if p.requires_grad] 20 | 21 | def score_paras(self, paras, para_mask, query, query_mask): 22 | 23 | scores, doc, ques = self.model(paras, para_mask, query, query_mask) 24 | return scores, doc.cpu().data.numpy(), ques.cpu().data.numpy() -------------------------------------------------------------------------------- /run_pretrained_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | #set -x 4 | #sh run_pretrained_models.sh quasart /mnt/nfs/scratch1/rajarshi/data/ICLR_code_release/data/ models/ /mnt/nfs/scratch1/rajarshi/ 5 | if [ "$#" -ne 4 ]; then 6 | echo "Usage: /bin/bash run_pretrained_models.sh dataset_name data_dir model_dir out_dir" 7 | echo "dataset_name -- one of triviaqa|searchqa|quasart" 8 | echo "data_dir -- top level dir path to the downloaded and unzipped data dir" 9 | echo "model_dir -- top level dir path to the downloaded and unzipped pretrained model dir" 10 | echo "out_dir -- a directory to 
write logs" 11 | exit 1 12 | fi 13 | dataset_name=$1 14 | data_dir=$2 15 | model_dir=$3 16 | out_dir=$4 17 | 18 | echo "Evaluating for $dataset_name..." 19 | 20 | if [ $dataset_name = "triviaqa" ]; then 21 | num_paras_test=10 22 | multi_step_reasoning_steps=3 23 | test_batch_size=10 24 | elif [ $dataset_name = "searchqa" ]; then 25 | num_paras_test=10 26 | multi_step_reasoning_steps=7 27 | test_batch_size=32 28 | elif [ $dataset_name = "quasart" ]; then 29 | num_paras_test=25 30 | multi_step_reasoning_steps=5 31 | test_batch_size=32 32 | fi 33 | 34 | python scripts/reader/train.py --domain web-open --num_paras_test $num_paras_test \ 35 | --multi_step_reasoning_steps $multi_step_reasoning_steps \ 36 | --dataset_name $dataset_name --top-spans 10 --eval_only 1 \ 37 | --pretrained $model_dir/$dataset_name/model.mdl \ 38 | --model_dir $out_dir \ 39 | --data_dir $data_dir \ 40 | --test_batch_size $test_batch_size \ 41 | --saved_para_vector $data_dir/$dataset_name/paragraph_vectors/web-open/ 42 | -------------------------------------------------------------------------------- /msr/retriever/doc_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Documents, in a sqlite database.""" 8 | 9 | import sqlite3 10 | from . import utils 11 | from . import DEFAULTS 12 | 13 | 14 | class DocDB(object): 15 | """Sqlite backed document storage. 16 | 17 | Implements get_doc_text(doc_id). 18 | """ 19 | 20 | def __init__(self, db_path=None): 21 | self.path = db_path or DEFAULTS['db_path'] 22 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 23 | 24 | def __enter__(self): 25 | return self 26 | 27 | def __exit__(self, *args): 28 | self.close() 29 | 30 | def path(self): 31 | """Return the path to the file that backs this database.""" 32 | return self.path 33 | 34 | def close(self): 35 | """Close the connection to the database.""" 36 | self.connection.close() 37 | 38 | def get_doc_ids(self): 39 | """Fetch all ids of docs stored in the db.""" 40 | cursor = self.connection.cursor() 41 | cursor.execute("SELECT id FROM documents") 42 | results = [r[0] for r in cursor.fetchall()] 43 | cursor.close() 44 | return results 45 | 46 | def get_doc_text(self, doc_id): 47 | """Fetch the raw text of the doc for 'doc_id'.""" 48 | cursor = self.connection.cursor() 49 | cursor.execute( 50 | "SELECT text FROM documents WHERE id = ?", 51 | (utils.normalize(doc_id),) 52 | ) 53 | result = cursor.fetchone() 54 | cursor.close() 55 | return result if result is None else result[0] 56 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements automatically generated by pigar. 
2 | # https://github.com/damnever/pigar 3 | 4 | # create_small.py: 5 5 | # drqa/reader/data.py: 9 6 | # drqa/reader/model.py: 12 7 | # drqa/retriever/tfidf_doc_ranker.py: 10 8 | # drqa/retriever/trained_retriever.py: 4 9 | # drqa/retriever/utils.py: 11 10 | # paragraph_encoder/model/data.py: 11 11 | # paragraph_encoder/model/vector.py: 10 12 | # paragraph_encoder/multi_corpus.py: 1 13 | # paragraph_encoder/train_para_encoder.py: 2 14 | # scripts/reader/train.py: 12 15 | numpy == 1.15.0 16 | 17 | # drqa/tokenizers/corenlp_tokenizer.py: 14 18 | pexpect == 4.2.1 19 | 20 | # drqa/reader/utils.py: 13 21 | # drqa/retriever/utils.py: 9 22 | # drqa/tokenizers/regexp_tokenizer.py: 13 23 | # drqa/tokenizers/simple_tokenizer.py: 11 24 | regex == 2017.4.5 25 | 26 | # drqa/retriever/tfidf_doc_ranker.py: 11 27 | # drqa/retriever/utils.py: 12 28 | scipy == 1.1.0 29 | 30 | # drqa/retriever/utils.py: 13 31 | # paragraph_encoder/multi_corpus.py: 4,5 32 | sklearn == 0.0 33 | 34 | # paragraph_encoder/config.py: 4 35 | # paragraph_encoder/model/data.py: 14 36 | # paragraph_encoder/model/util.py: 18 37 | # paragraph_encoder/model/utils.py: 18 38 | smart_open == 1.5.5 39 | 40 | # drqa/tokenizers/spacy_tokenizer.py: 12 41 | spacy == 2.0.12 42 | 43 | # create_small.py: 6 44 | # drqa/reader/data.py: 13,14 45 | # drqa/reader/layers.py: 9,10,11,12 46 | # drqa/reader/model.py: 9,10,11,17 47 | # drqa/reader/rnn_reader.py: 9,10,11 48 | # drqa/reader/utils.py: 14 49 | # drqa/reader/vector.py: 10 50 | # drqa/retriever/trained_retriever.py: 6,7 51 | # paragraph_encoder/model/data.py: 19,20 52 | # paragraph_encoder/model/layers.py: 9,10,11,12 53 | # paragraph_encoder/model/retriever_module.py: 8,9,10 54 | # paragraph_encoder/model/util.py: 9 55 | # paragraph_encoder/model/utils.py: 9 56 | # paragraph_encoder/model/vector.py: 9 57 | # paragraph_encoder/train_para_encoder.py: 21 58 | # scripts/reader/interactive.py: 9 59 | # scripts/reader/predict.py: 11 60 | # scripts/reader/train.py: 28 61 | torch == 0.4.0 62 | 63 | # drqa/reader/utils.py: 16 64 | # drqa/retriever/trained_retriever.py: 5 65 | # paragraph_encoder/model/util.py: 13 66 | # paragraph_encoder/model/utils.py: 13 67 | # paragraph_encoder/multi_corpus.py: 6 68 | # paragraph_encoder/train_para_encoder.py: 9 69 | # scripts/reader/predict.py: 16 70 | # scripts/reader/train.py: 18 71 | tqdm == 4.24.0 72 | 73 | # paragraph_encoder/model/util.py: 10 74 | # paragraph_encoder/model/utils.py: 10 75 | ujson == 1.35 76 | -------------------------------------------------------------------------------- /paragraph_encoder/model/vector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # Few methods have been adapted from https://github.com/facebookresearch/DrQA 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
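# vectorize() below turns one raw example into tensors; batchify()/batchify_()
# pad a list of vectorized examples into batch tensors with masks.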
7 | 8 | """Functions for putting examples into torch format.""" 9 | import torch 10 | import numpy as np 11 | 12 | 13 | def vectorize(args, ex): 14 | """Torchify a single example.""" 15 | word_dict = args.word_dict 16 | 17 | # Index words 18 | if len(ex['document']) == 0: 19 | if args.train_time: 20 | return 21 | else: 22 | return 23 | if len(ex['question']) == 0: 24 | if args.train_time: 25 | return 26 | else: 27 | return 28 | 29 | document = torch.LongTensor([word_dict[w] for w in ex['document']]) 30 | question = torch.LongTensor([word_dict[w] for w in ex['question']]) 31 | if args.train_time: 32 | if ex['ans_occurance'] == 0: 33 | if np.random.binomial(1, args.neg_sample) == 0: 34 | return 35 | return document, question, ex['ans_occurance'], ex['id'] 36 | 37 | 38 | def batchify(args, para_mode, train_time): 39 | return lambda x: batchify_(args, x, para_mode, train_time) 40 | 41 | 42 | def batchify_(args, batch, para_mode, train_time): 43 | """Gather a batch of individual examples into one batch.""" 44 | 45 | new_batch = [] 46 | for d in batch: 47 | if d is not None: 48 | new_batch.append(d) 49 | batch = new_batch 50 | if len(batch) == 0: 51 | return None 52 | ids = [ex[-1] for ex in batch] 53 | docs = [ex[0] for ex in batch] 54 | questions = [ex[1] for ex in batch] 55 | num_occurances = [ex[-2] for ex in batch] 56 | num_occurances = torch.LongTensor(num_occurances) 57 | # Batch documents and features 58 | max_length = max([d.size(0) for d in docs]) 59 | x1 = torch.LongTensor(len(docs), max_length).zero_() 60 | x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1) 61 | 62 | for i, d in enumerate(docs): 63 | x1[i, :d.size(0)].copy_(d) 64 | x1_mask[i, :d.size(0)].fill_(0) 65 | 66 | # Batch questions 67 | max_length = max([q.size(0) for q in questions]) 68 | x2 = torch.LongTensor(len(questions), max_length).zero_() 69 | x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1) 70 | for i, q in enumerate(questions): 71 | x2[i, :q.size(0)].copy_(q) 72 | x2_mask[i, :q.size(0)].fill_(0) 73 | 74 | return x1, x1_mask, x2, x2_mask, num_occurances, ids -------------------------------------------------------------------------------- /scripts/reader/interactive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to run the DrQA reader model interactively.""" 8 | 9 | import torch 10 | import code 11 | import argparse 12 | import logging 13 | import prettytable 14 | import time 15 | 16 | from drqa.reader import Predictor 17 | 18 | logger = logging.getLogger() 19 | logger.setLevel(logging.INFO) 20 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 21 | console = logging.StreamHandler() 22 | console.setFormatter(fmt) 23 | logger.addHandler(console) 24 | 25 | 26 | # ------------------------------------------------------------------------------ 27 | # Commandline arguments & init 28 | # ------------------------------------------------------------------------------ 29 | 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--model', type=str, default=None, 33 | help='Path to model to use') 34 | parser.add_argument('--tokenizer', type=str, default=None, 35 | help=("String option specifying tokenizer type to use " 36 | "(e.g. 
'corenlp')")) 37 | parser.add_argument('--no-cuda', action='store_true', 38 | help='Use CPU only') 39 | parser.add_argument('--gpu', type=int, default=-1, 40 | help='Specify GPU device id to use') 41 | parser.add_argument('--no-normalize', action='store_true', 42 | help='Do not softmax normalize output scores.') 43 | args = parser.parse_args() 44 | 45 | args.cuda = not args.no_cuda and torch.cuda.is_available() 46 | if args.cuda: 47 | torch.cuda.set_device(args.gpu) 48 | logger.info('CUDA enabled (GPU %d)' % args.gpu) 49 | else: 50 | logger.info('Running on CPU only.') 51 | 52 | predictor = Predictor(args.model, args.tokenizer, num_workers=0, 53 | normalize=not args.no_normalize) 54 | if args.cuda: 55 | predictor.cuda() 56 | 57 | 58 | # ------------------------------------------------------------------------------ 59 | # Drop in to interactive mode 60 | # ------------------------------------------------------------------------------ 61 | 62 | 63 | def process(document, question, candidates=None, top_n=1): 64 | t0 = time.time() 65 | predictions = predictor.predict(document, question, candidates, top_n) 66 | table = prettytable.PrettyTable(['Rank', 'Span', 'Score']) 67 | for i, p in enumerate(predictions, 1): 68 | table.add_row([i, p[0], p[1]]) 69 | print(table) 70 | print('Time: %.4f' % (time.time() - t0)) 71 | 72 | 73 | banner = """ 74 | DrQA Interactive Document Reader Module 75 | >> process(document, question, candidates=None, top_n=1) 76 | >> usage() 77 | """ 78 | 79 | 80 | def usage(): 81 | print(banner) 82 | 83 | 84 | code.interact(banner=banner, local=locals()) 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi Step Reasoning for Open Domain Question Answering 2 | 3 | 4 | ![gif](multi-step-reasoner.png) 5 | Code for the paper [Multi-step Retriever-Reader Interaction for Scalable Open-domain Question Answering](https://openreview.net/forum?id=HkfPSh05K7) 6 | 7 | *Acknowledgement*: This codebase started from the awesome [Dr.QA repository](https://github.com/facebookresearch/DrQA) created and maintained by [Adam Fisch](https://people.csail.mit.edu/fisch/). Thanks Adam! 8 | 9 | ## Setup 10 | The requirements are in the [requirements file](requirements.txt). In my env, I also needed to set PYTHONPATH (as in the [setup.sh](setup.sh)) 11 | ``` 12 | pip install -r requirements.txt 13 | source setup.sh 14 | ``` 15 | 16 | ## Data 17 | We are making the pre-processed data and paragraph vectors available so that is is easier to get started. They can downloaded from [here](http://iesl.cs.umass.edu/downloads/multi-step-reasoning-iclr19/data.tar.gz). (41GB compressed, 56GB decompressed; user/pass: guest/guest). If you need the pretrained paragraph encoder used to generate the vectors, feel free to get in touch with me. 18 | After un-taring, you will find a directory corresponding to each dataset. Each directory further contains: 19 | ``` 20 | data/ -- Processed data (*.pkl files) 21 | paragraph_vectors/ -- Saved paragraph vectors of context for each dataset used for nearest-neighbor search 22 | vocab/ -- int2str mapping 23 | embeddings/ -- Saved lookup table for faster initialization. The embeddings are essentially saved fast-text embeddings. 
25 | 
26 | ## Paragraph encoder
27 | If you want to train new paragraph embeddings instead of using the ones we used, please refer to this [readme](paragraph_encoder/README.md)
28 | 
29 | 
30 | ## Training
31 | ```
32 | python scripts/reader/train.py --data_dir <data_dir> --model_dir <model_dir> --dataset_name searchqa|triviaqa|quasart --saved_para_vectors_dir <data_dir>/dataset_name/paragraph_vectors/web-open
33 | ```
34 | Some important command line args
35 | ```
36 | dataset_name -- searchqa|triviaqa|quasart
37 | data_dir -- path to the dataset that you downloaded
38 | model_dir -- path where the model will be checkpointed
39 | saved_para_vectors_dir -- path to cached paragraph and query representations on disk. It should be in the data you have downloaded
40 | multi_step_reasoning_steps -- Number of steps of interaction between retriever and reader
41 | num_positive_paras -- (Relevant during training) -- Number of "positive" (wrt distant supervision) paragraphs fed to the reader model during training.
42 | num_paras_test -- (Relevant during inference time) -- Number of paragraphs to be sent to the reader by the retriever.
43 | freeze_reader -- when set to 1, the reader parameters are fixed and only the parameters of the GRU (multi-step-reasoner) are trained.
44 | fine_tune_RL -- fine-tune the GRU (multi-step-reasoner) with reward (F1) from the fixed reader
45 | ```
46 | Training details:
47 | 1. During training, we first train the reader model by setting ```multi_step_reasoning_steps = 1```
48 | 2. After the reader has been trained, we fix the reader and just pretrain the ```multi-step-reasoner``` (```freeze_reader 1```)
49 | 3. Next, we fine-tune the reasoner with reinforcement learning (```freeze_reader = 1, fine_tune_RL = 1```)
50 | 
51 | In our experiments on searchqa and quasart, we found that step 2 (pretraining the GRU) was not important and the reasoner was able to learn directly via RL. However, pretraining never hurt performance either.
52 | 
53 | ## Pretrained models
54 | 
55 | We are also providing pretrained models for download and scripts to run them directly. Download the pretrained models from [here](http://iesl.cs.umass.edu/downloads/multi-step-reasoning-iclr19/models.tar.gz).
56 | ```
57 | Usage: /bin/bash run_pretrained_models.sh dataset_name data_dir model_dir out_dir
58 | dataset_name -- searchqa|triviaqa|quasart
59 | data_dir -- path to the dataset that you downloaded
60 | model_dir -- path to the pretrained model that you downloaded
61 | out_dir -- directory for logging
62 | ```
63 | ## To-do
64 | - [ ] Integrate with code for SGTree
65 | ## Citation
66 | ```
67 | @inproceedings{
68 | das2018multistep,
69 | title={Multi-step Retriever-Reader Interaction for Scalable Open-domain Question Answering},
70 | author={Rajarshi Das and Shehzaad Dhuliawala and Manzil Zaheer and Andrew McCallum},
71 | booktitle={ICLR},
72 | year={2019},
73 | }
74 | ```
75 | 
76 | 
--------------------------------------------------------------------------------
/msr/retriever/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Various retriever utilities.""" 8 | 9 | import regex 10 | import unicodedata 11 | import numpy as np 12 | import scipy.sparse as sp 13 | from sklearn.utils import murmurhash3_32 14 | 15 | 16 | # ------------------------------------------------------------------------------ 17 | # Sparse matrix saving/loading helpers. 18 | # ------------------------------------------------------------------------------ 19 | 20 | 21 | def save_sparse_csr(filename, matrix, metadata=None): 22 | data = { 23 | 'data': matrix.data, 24 | 'indices': matrix.indices, 25 | 'indptr': matrix.indptr, 26 | 'shape': matrix.shape, 27 | 'metadata': metadata, 28 | } 29 | np.savez(filename, **data) 30 | 31 | 32 | def load_sparse_csr(filename): 33 | loader = np.load(filename) 34 | matrix = sp.csr_matrix((loader['data'], loader['indices'], 35 | loader['indptr']), shape=loader['shape']) 36 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None 37 | 38 | 39 | # ------------------------------------------------------------------------------ 40 | # Token hashing. 41 | # ------------------------------------------------------------------------------ 42 | 43 | 44 | def hash(token, num_buckets): 45 | """Unsigned 32 bit murmurhash for feature hashing.""" 46 | return murmurhash3_32(token, positive=True) % num_buckets 47 | 48 | 49 | # ------------------------------------------------------------------------------ 50 | # Text cleaning. 51 | # ------------------------------------------------------------------------------ 52 | 53 | 54 | STOPWORDS = { 55 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 56 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 57 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 58 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 59 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 60 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 61 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 62 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 63 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 64 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 65 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 66 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 67 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 68 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 69 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 70 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 71 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 72 | } 73 | 74 | 75 | def normalize(text): 76 | """Resolve different type of unicode encodings.""" 77 | return unicodedata.normalize('NFD', text) 78 | 79 | 80 | def filter_word(text): 81 | """Take out english stopwords, punctuation, and compound endings.""" 82 | text = normalize(text) 83 | if regex.match(r'^\p{P}+$', text): 84 | return True 85 | if text.lower() in STOPWORDS: 86 | return True 87 | return False 88 | 89 | 90 | def filter_ngram(gram, mode='any'): 91 | """Decide whether to keep or discard an n-gram. 
92 | 
93 |     Args:
94 |         gram: list of tokens (length N)
95 |         mode: Option to throw out ngram if
96 |           'any': any single token passes filter_word
97 |           'all': all tokens pass filter_word
98 |           'ends': book-ended by filterable tokens
99 |     """
100 |     filtered = [filter_word(w) for w in gram]
101 |     if mode == 'any':
102 |         return any(filtered)
103 |     elif mode == 'all':
104 |         return all(filtered)
105 |     elif mode == 'ends':
106 |         return filtered[0] or filtered[-1]
107 |     else:
108 |         raise ValueError('Invalid mode: %s' % mode)
109 | 
--------------------------------------------------------------------------------
/scripts/reader/predict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """A script to make and save model predictions on an input dataset."""
8 | 
9 | import os
10 | import time
11 | import torch
12 | import argparse
13 | import logging
14 | import json
15 | 
16 | from tqdm import tqdm
17 | from drqa.reader import Predictor
18 | 
19 | logger = logging.getLogger()
20 | logger.setLevel(logging.INFO)
21 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
22 | console = logging.StreamHandler()
23 | console.setFormatter(fmt)
24 | logger.addHandler(console)
25 | 
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('dataset', type=str, default=None,
28 |                     help='SQuAD-like dataset to evaluate on')
29 | parser.add_argument('--model', type=str, default=None,
30 |                     help='Path to model to use')
31 | parser.add_argument('--embedding-file', type=str, default=None,
32 |                     help=('Expand dictionary to use all pretrained '
33 |                           'embeddings in this file.'))
34 | parser.add_argument('--out-dir', type=str, default='/tmp',
35 |                     help=('Directory to write prediction file to '
36 |                           '(<dataset>-<model>.preds)'))
37 | parser.add_argument('--tokenizer', type=str, default=None,
38 |                     help=("String option specifying tokenizer type to use "
39 |                           "(e.g. 'corenlp')"))
40 | parser.add_argument('--num-workers', type=int, default=None,
41 |                     help='Number of CPU processes (for tokenizing, etc)')
42 | parser.add_argument('--no-cuda', action='store_true',
43 |                     help='Use CPU only')
44 | parser.add_argument('--gpu', type=int, default=-1,
45 |                     help='Specify GPU device id to use')
46 | parser.add_argument('--batch-size', type=int, default=128,
47 |                     help='Example batching size')
48 | parser.add_argument('--top-n', type=int, default=1,
49 |                     help='Store top N predicted spans per example')
50 | parser.add_argument('--official', action='store_true',
51 |                     help='Only store single top span instead of top N list')
52 | args = parser.parse_args()
53 | t0 = time.time()
54 | 
55 | args.cuda = not args.no_cuda and torch.cuda.is_available()
56 | if args.cuda:
57 |     torch.cuda.set_device(args.gpu)
58 |     logger.info('CUDA enabled (GPU %d)' % args.gpu)
59 | else:
60 |     logger.info('Running on CPU only.')
61 | 
62 | predictor = Predictor(
63 |     model=args.model,
64 |     tokenizer=args.tokenizer,
65 |     embedding_file=args.embedding_file,
66 |     num_workers=args.num_workers,
67 | )
68 | if args.cuda:
69 |     predictor.cuda()
70 | 
71 | 
72 | # ------------------------------------------------------------------------------
73 | # Read in dataset and make predictions.
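# (A sketch of the expected input, per the loading code below: the dataset file
#  is SQuAD-format JSON, i.e. {"data": [{"paragraphs": [{"context": "...",
#  "qas": [{"id": ..., "question": "..."}, ...]}, ...]}, ...]}.)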
74 | # ------------------------------------------------------------------------------ 75 | 76 | 77 | examples = [] 78 | qids = [] 79 | with open(args.dataset) as f: 80 | data = json.load(f)['data'] 81 | for article in data: 82 | for paragraph in article['paragraphs']: 83 | context = paragraph['context'] 84 | for qa in paragraph['qas']: 85 | qids.append(qa['id']) 86 | examples.append((context, qa['question'])) 87 | 88 | results = {} 89 | for i in tqdm(range(0, len(examples), args.batch_size)): 90 | predictions = predictor.predict_batch( 91 | examples[i:i + args.batch_size], top_n=args.top_n 92 | ) 93 | for j in range(len(predictions)): 94 | # Official eval expects just a qid --> span 95 | if args.official: 96 | results[qids[i + j]] = predictions[j][0][0] 97 | 98 | # Otherwise we store top N and scores for debugging. 99 | else: 100 | results[qids[i + j]] = [(p[0], float(p[1])) for p in predictions[j]] 101 | 102 | model = os.path.splitext(os.path.basename(args.model or 'default'))[0] 103 | basename = os.path.splitext(os.path.basename(args.dataset))[0] 104 | outfile = os.path.join(args.out_dir, basename + '-' + model + '.preds') 105 | 106 | logger.info('Writing results to %s' % outfile) 107 | with open(outfile, 'w') as f: 108 | json.dump(results, f) 109 | 110 | logger.info('Total time: %.2f' % (time.time() - t0)) 111 | -------------------------------------------------------------------------------- /msr/retriever/tfidf_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Rank documents with TF-IDF scores""" 8 | 9 | import logging 10 | import numpy as np 11 | import scipy.sparse as sp 12 | 13 | from multiprocessing.pool import ThreadPool 14 | from functools import partial 15 | 16 | from . import utils 17 | from . import DEFAULTS 18 | from .. import tokenizers 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class TfidfDocRanker(object): 24 | """Loads a pre-weighted inverted index of token/document terms. 25 | Scores new queries by taking sparse dot products. 26 | """ 27 | 28 | def __init__(self, tfidf_path=None, strict=True): 29 | """ 30 | Args: 31 | tfidf_path: path to saved model file 32 | strict: fail on empty queries or continue (and return empty result) 33 | """ 34 | # Load from disk 35 | tfidf_path = tfidf_path or DEFAULTS['tfidf_path'] 36 | logger.info('Loading %s' % tfidf_path) 37 | matrix, metadata = utils.load_sparse_csr(tfidf_path) 38 | self.doc_mat = matrix 39 | self.ngrams = metadata['ngram'] 40 | self.hash_size = metadata['hash_size'] 41 | self.tokenizer = tokenizers.get_class(metadata['tokenizer'])() 42 | self.doc_freqs = metadata['doc_freqs'].squeeze() 43 | self.doc_dict = metadata['doc_dict'] 44 | self.num_docs = len(self.doc_dict[0]) 45 | self.strict = strict 46 | 47 | def get_doc_index(self, doc_id): 48 | """Convert doc_id --> doc_index""" 49 | return self.doc_dict[0][doc_id] 50 | 51 | def get_doc_id(self, doc_index): 52 | """Convert doc_index --> doc_id""" 53 | return self.doc_dict[1][doc_index] 54 | 55 | def closest_docs(self, query, k=1): 56 | """Closest docs by dot product between query and documents 57 | in tfidf weighted word vector space. 
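
        Example (hypothetical index path):
            ranker = TfidfDocRanker(tfidf_path='docs-tfidf.npz')
            doc_ids, doc_scores = ranker.closest_docs('who wrote hamlet', k=5)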
58 | """ 59 | spvec = self.text2spvec(query) 60 | res = spvec * self.doc_mat 61 | 62 | if len(res.data) <= k: 63 | o_sort = np.argsort(-res.data) 64 | else: 65 | o = np.argpartition(-res.data, k)[0:k] 66 | o_sort = o[np.argsort(-res.data[o])] 67 | 68 | doc_scores = res.data[o_sort] 69 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] 70 | return doc_ids, doc_scores 71 | 72 | def batch_closest_docs(self, queries, k=1, num_workers=None): 73 | """Process a batch of closest_docs requests multithreaded. 74 | Note: we can use plain threads here as scipy is outside of the GIL. 75 | """ 76 | with ThreadPool(num_workers) as threads: 77 | closest_docs = partial(self.closest_docs, k=k) 78 | results = threads.map(closest_docs, queries) 79 | return results 80 | 81 | def parse(self, query): 82 | """Parse the query into tokens (either ngrams or tokens).""" 83 | tokens = self.tokenizer.tokenize(query) 84 | return tokens.ngrams(n=self.ngrams, uncased=True, 85 | filter_fn=utils.filter_ngram) 86 | 87 | def text2spvec(self, query): 88 | """Create a sparse tfidf-weighted word vector from query. 89 | 90 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 91 | """ 92 | # Get hashed ngrams 93 | words = self.parse(utils.normalize(query)) 94 | wids = [utils.hash(w, self.hash_size) for w in words] 95 | 96 | if len(wids) == 0: 97 | if self.strict: 98 | raise RuntimeError('No valid word in: %s' % query) 99 | else: 100 | logger.warning('No valid word in: %s' % query) 101 | return sp.csr_matrix((1, self.hash_size)) 102 | 103 | # Count TF 104 | wids_unique, wids_counts = np.unique(wids, return_counts=True) 105 | tfs = np.log1p(wids_counts) 106 | 107 | # Count IDF 108 | Ns = self.doc_freqs[wids_unique] 109 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) 110 | idfs[idfs < 0] = 0 111 | 112 | # TF-IDF 113 | data = np.multiply(tfs, idfs) 114 | 115 | # One row, sparse csr matrix 116 | indptr = np.array([0, len(wids_unique)]) 117 | spvec = sp.csr_matrix( 118 | (data, wids_unique, indptr), shape=(1, self.hash_size) 119 | ) 120 | 121 | return spvec 122 | -------------------------------------------------------------------------------- /paragraph_encoder/model/retriever_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # Few methods have been adapted from https://github.com/facebookresearch/DrQA 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | import logging 12 | from model.utils import load_embeddings 13 | from .layers import StackedBRNN 14 | 15 | logger = logging.getLogger() 16 | 17 | 18 | def weighted_avg(x, weights): 19 | """Return a weighted average of x (a sequence of vectors). 20 | Args: 21 | x: batch * len * hdim 22 | weights: batch * len, sum(dim = 1) = 1 23 | Output: 24 | x_avg: batch * hdim 25 | """ 26 | return weights.unsqueeze(1).bmm(x).squeeze(1) 27 | 28 | 29 | class LinearSeqAttn(nn.Module): 30 | """Self attention over a sequence: 31 | * o_i = softmax(Wx_i) for x_i in X. 
32 | """ 33 | 34 | def __init__(self, input_size): 35 | super(LinearSeqAttn, self).__init__() 36 | self.linear = nn.Linear(input_size, 1) 37 | 38 | def forward(self, x, x_mask): 39 | """ 40 | Args: 41 | x: batch * len * hdim 42 | x_mask: batch * len (1 for padding, 0 for true) 43 | Output: 44 | alpha: batch * len 45 | """ 46 | x = x.contiguous() 47 | 48 | x_flat = x.view(-1, x.size(-1)) 49 | scores = self.linear(x_flat).view(x.size(0), x.size(1)) 50 | scores.data.masked_fill_(x_mask.data, -float('inf')) 51 | alpha = F.softmax(scores) 52 | return alpha 53 | 54 | class LSTMParagraphScorer(nn.Module): 55 | 56 | def __init__(self, args, word_dict, feature_dict): 57 | super(LSTMParagraphScorer, self).__init__() 58 | self.args = args 59 | self.word_dict = word_dict 60 | self.embedding = nn.Embedding(args.vocab_size, args.embedding_dim, padding_idx=0) 61 | 62 | if args.pretrained_words: 63 | self._set_embeddings() 64 | if args.fix_embeddings: 65 | for p in self.embedding.parameters(): 66 | p.requires_grad = False 67 | 68 | self.document_lstm = StackedBRNN(self.args.embedding_dim, self.args.paraclf_hidden_size, 3, dropout_rate=0.2, 69 | concat_layers=True) 70 | self.question_lstm = StackedBRNN(self.args.embedding_dim, self.args.paraclf_hidden_size, 3, dropout_rate=0.2, 71 | concat_layers=True) 72 | self.para_selfaatn = LinearSeqAttn(args.paraclf_hidden_size * 6) 73 | self.query_selfaatn = LinearSeqAttn(args.paraclf_hidden_size * 6) 74 | self.bilinear = nn.Linear(args.paraclf_hidden_size * 6, args.paraclf_hidden_size * 6, bias=False) 75 | 76 | def _set_embeddings(self): 77 | # Read word embeddings. 78 | if not self.args.embedding_file or not self.args.pretrained_words: 79 | logger.warn('[ WARNING: No embeddings provided. ' 80 | 'Keeping random initialization. 
]') 81 | return 82 | logger.info('[ Loading pre-trained embeddings from {} ]'.format(self.args.embedding_file)) 83 | 84 | embeddings = load_embeddings(self.args, self.word_dict) 85 | logger.info('[ Num embeddings = %d ]' % embeddings.size(0)) 86 | 87 | # Sanity check dimensions 88 | new_size = embeddings.size() 89 | old_size = self.embedding.weight.size() 90 | 91 | assert (new_size[1] == old_size[1]) 92 | if new_size[0] != old_size[0]: 93 | logger.warn('[ WARNING: Number of embeddings changed (%d->%d) ]' % 94 | (old_size[0], new_size[0])) 95 | 96 | self.embedding.weight.data = embeddings 97 | 98 | def forward(self, x1, x1_mask, x2, x2_mask): 99 | 100 | """ 101 | x1 = document word indices [sum_paras * len_p] 102 | x1_mask = document padding mask [sum_paras * len_p] 103 | x2 = question word indices [batch * len_q] 104 | x2_mask = question padding mask [batch * len_q] 105 | """ 106 | # Embed both document and question 107 | x1_emb = self.embedding(x1) 108 | x2_emb = self.embedding(x2) 109 | 110 | # Dropout on embeddings 111 | if self.args.dropout_emb > 0 and not self.training: 112 | x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb, training=self.training) 113 | x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb, training=self.training) 114 | 115 | o1 = self.document_lstm(x1_emb, x1_mask) 116 | o2 = self.question_lstm(x2_emb, x2_mask) 117 | 118 | doc_attn = self.para_selfaatn(o1, x1_mask) 119 | ques_attn = self.query_selfaatn(o2, x2_mask) 120 | 121 | doc = weighted_avg(o1, doc_attn) 122 | ques = weighted_avg(o2, ques_attn) 123 | 124 | doc = self.bilinear(doc) 125 | scores = ques * doc 126 | scores = torch.sum(scores, -1, keepdim=True) 127 | 128 | return scores, doc, ques -------------------------------------------------------------------------------- /msr/reader/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Data processing/loading helpers.""" 8 | 9 | import numpy as np 10 | import logging 11 | import unicodedata 12 | 13 | from torch.utils.data import Dataset 14 | from torch.utils.data.sampler import Sampler 15 | from .vector import vectorize 16 | import json 17 | import os 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | # ------------------------------------------------------------------------------ 22 | # Dictionary class for tokens. 
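# (Usage sketch: d = Dictionary(args); d.add('hello'); then d['hello'] -> 2 and
#  d[2] -> 'hello'; indices 0 and 1 are reserved for <NULL> and <UNK>.)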
23 | # ------------------------------------------------------------------------------
24 | 
25 | 
26 | class Dictionary(object):
27 |     NULL = '<NULL>'
28 |     UNK = '<UNK>'
29 |     START = 2
30 | 
31 |     @staticmethod
32 |     def normalize(token):
33 |         return unicodedata.normalize('NFD', token)
34 | 
35 |     def __init__(self, args):
36 |         self.args = args
37 |         self.tok2ind = {self.NULL: 0, self.UNK: 1}
38 |         self.ind2tok = {0: self.NULL, 1: self.UNK}
39 | 
40 |     def __len__(self):
41 |         return len(self.tok2ind)
42 | 
43 |     def __iter__(self):
44 |         return iter(self.tok2ind)
45 | 
46 |     def __contains__(self, key):
47 |         if type(key) == int:
48 |             return key in self.ind2tok
49 |         elif type(key) == str:
50 |             return self.normalize(key) in self.tok2ind
51 | 
52 |     def __getitem__(self, key):
53 |         if type(key) == int:
54 |             return self.ind2tok.get(key, self.UNK)
55 |         if type(key) == str:
56 |             return self.tok2ind.get(self.normalize(key),
57 |                                     self.tok2ind.get(self.UNK))
58 | 
59 |     def __setitem__(self, key, item):
60 |         if type(key) == int and type(item) == str:
61 |             self.ind2tok[key] = item
62 |         elif type(key) == str and type(item) == int:
63 |             self.tok2ind[key] = item
64 |         else:
65 |             raise RuntimeError('Invalid (key, item) types.')
66 | 
67 |     def add(self, token):
68 |         token = self.normalize(token)
69 |         if token not in self.tok2ind:
70 |             index = len(self.tok2ind)
71 |             self.tok2ind[token] = index
72 |             self.ind2tok[index] = token
73 | 
74 |     def save(self):
75 |         fout = open(os.path.join(self.args.vocab_dir, "ind2tok.json"), "w")
76 |         json.dump(self.ind2tok, fout)
77 |         fout.close()
78 |         fout = open(os.path.join(self.args.vocab_dir, "tok2ind.json"), "w")
79 |         json.dump(self.tok2ind, fout)
80 |         fout.close()
81 |         logger.info("Dictionary saved at {}".format(self.args.vocab_dir))
82 | 
83 |     def tokens(self):
84 |         """Get dictionary tokens.
85 | 
86 |         Return all the words indexed by this dictionary, except for special
87 |         tokens.
88 |         """
89 |         tokens = [k for k in self.tok2ind.keys()
90 |                   if k not in {'<NULL>', '<UNK>'}]
91 |         return tokens
92 | 
93 | 
94 | # ------------------------------------------------------------------------------
95 | # PyTorch dataset class for SQuAD (and SQuAD-like) data.
96 | # ------------------------------------------------------------------------------
97 | 
98 | class ReaderDataset(Dataset):
99 | 
100 |     def __init__(self, args, examples, word_dict, feature_dict, single_answer=False, train_time=False):
101 |         self.args = args
102 |         self.word_dict = word_dict
103 |         self.feature_dict = feature_dict
104 |         self.examples = examples
105 |         # make a list of qids, so that we can iterate over efficiently
106 |         self.qids = list(examples.questions.keys())
107 |         self.single_answer = single_answer
108 |         self.train_time = train_time
109 | 
110 | 
111 |     def __len__(self):
112 |         return len(self.examples.questions)
113 | 
114 |     def __getitem__(self, index):
115 | 
116 |         question = self.examples.questions[self.qids[index]]
117 |         paragraphs = [self.examples.paragraphs[pid] for pid in question.pids]
118 | 
119 |         return vectorize(self.args, question, paragraphs, self.word_dict, self.feature_dict, self.single_answer,
120 |                          train_time=self.train_time)
121 | 
122 |     def lengths(self):
123 |         return [(len(ex['document']), len(ex['question']))
124 |                 for ex in self.examples]
125 | 
126 | 
127 | # ------------------------------------------------------------------------------
128 | # PyTorch sampler returning batches of sorted lengths (by doc and question).
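# (Assumed call pattern, following DrQA: pass dataset.lengths() to
#  SortedBatchSampler and hand the sampler to a DataLoader so that each batch
#  groups examples of similar document/question length.)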
129 | # ------------------------------------------------------------------------------ 130 | 131 | 132 | class SortedBatchSampler(Sampler): 133 | 134 | def __init__(self, lengths, batch_size, shuffle=True): 135 | self.lengths = lengths 136 | self.batch_size = batch_size 137 | self.shuffle = shuffle 138 | 139 | def __iter__(self): 140 | lengths = np.array( 141 | [(-l[0], -l[1], np.random.random()) for l in self.lengths], 142 | dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)] 143 | ) 144 | indices = np.argsort(lengths, order=('l1', 'l2', 'rand')) 145 | batches = [indices[i:i + self.batch_size] 146 | for i in range(0, len(indices), self.batch_size)] 147 | if self.shuffle: 148 | np.random.shuffle(batches) 149 | return iter([i for batch in batches for i in batch]) 150 | 151 | def __len__(self): 152 | return len(self.lengths) 153 | -------------------------------------------------------------------------------- /msr/reader/predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """DrQA Document Reader predictor""" 8 | 9 | import logging 10 | 11 | from multiprocessing import Pool as ProcessPool 12 | from multiprocessing.util import Finalize 13 | 14 | from .vector import vectorize, batchify 15 | from .model import Model 16 | from . import DEFAULTS, utils 17 | from .. import tokenizers 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | # ------------------------------------------------------------------------------ 23 | # Tokenize + annotate 24 | # ------------------------------------------------------------------------------ 25 | 26 | PROCESS_TOK = None 27 | 28 | 29 | def init(tokenizer_class, annotators): 30 | global PROCESS_TOK 31 | PROCESS_TOK = tokenizer_class(annotators=annotators) 32 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 33 | 34 | 35 | def tokenize(text): 36 | global PROCESS_TOK 37 | return PROCESS_TOK.tokenize(text) 38 | 39 | 40 | # ------------------------------------------------------------------------------ 41 | # Predictor class. 42 | # ------------------------------------------------------------------------------ 43 | 44 | 45 | class Predictor(object): 46 | """Load a pretrained DocReader model and predict inputs on the fly.""" 47 | 48 | def __init__(self, model=None, tokenizer=None, normalize=True, 49 | embedding_file=None, num_workers=None): 50 | """ 51 | Args: 52 | model: path to saved model file. 53 | tokenizer: option string to select tokenizer class. 54 | normalize: squash output score to 0-1 probabilities with a softmax. 55 | embedding_file: if provided, will expand dictionary to use all 56 | available pretrained vectors in this file. 57 | num_workers: number of CPU processes to use to preprocess batches. 
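
        Example (hypothetical paths/options):
            predictor = Predictor(model='reader.mdl', tokenizer='corenlp')
            predictor.predict('Paris is the capital of France.',
                              'What is the capital of France?', top_n=1)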
58 | """ 59 | logger.info('Initializing model...') 60 | self.model = Model.load(model or DEFAULTS['model'], 61 | normalize=normalize) 62 | 63 | if embedding_file: 64 | logger.info('Expanding dictionary...') 65 | words = utils.index_embedding_words(embedding_file) 66 | added = self.model.expand_dictionary(words) 67 | self.model.load_embeddings(added, embedding_file) 68 | 69 | logger.info('Initializing tokenizer...') 70 | annotators = tokenizers.get_annotators_for_model(self.model) 71 | if not tokenizer: 72 | tokenizer_class = DEFAULTS['tokenizer'] 73 | else: 74 | tokenizer_class = tokenizers.get_class(tokenizer) 75 | 76 | if num_workers is None or num_workers > 0: 77 | self.workers = ProcessPool( 78 | num_workers, 79 | initializer=init, 80 | initargs=(tokenizer_class, annotators), 81 | ) 82 | else: 83 | self.workers = None 84 | self.tokenizer = tokenizer_class(annotators=annotators) 85 | 86 | def predict(self, document, question, candidates=None, top_n=1): 87 | """Predict a single document - question pair.""" 88 | results = self.predict_batch([(document, question, candidates,)], top_n) 89 | return results[0] 90 | 91 | def predict_batch(self, batch, top_n=1): 92 | """Predict a batch of document - question pairs.""" 93 | documents, questions, candidates = [], [], [] 94 | for b in batch: 95 | documents.append(b[0]) 96 | questions.append(b[1]) 97 | candidates.append(b[2] if len(b) == 3 else None) 98 | candidates = candidates if any(candidates) else None 99 | 100 | # Tokenize the inputs, perhaps multi-processed. 101 | if self.workers: 102 | q_tokens = self.workers.map_async(tokenize, questions) 103 | d_tokens = self.workers.map_async(tokenize, documents) 104 | q_tokens = list(q_tokens.get()) 105 | d_tokens = list(d_tokens.get()) 106 | else: 107 | q_tokens = list(map(self.tokenizer.tokenize, questions)) 108 | d_tokens = list(map(self.tokenizer.tokenize, documents)) 109 | 110 | examples = [] 111 | for i in range(len(questions)): 112 | examples.append({ 113 | 'id': i, 114 | 'question': q_tokens[i].words(), 115 | 'qlemma': q_tokens[i].lemmas(), 116 | 'document': d_tokens[i].words(), 117 | 'lemma': d_tokens[i].lemmas(), 118 | 'pos': d_tokens[i].pos(), 119 | 'ner': d_tokens[i].entities(), 120 | }) 121 | 122 | # Stick document tokens in candidates for decoding 123 | if candidates: 124 | candidates = [{'input': d_tokens[i], 'cands': candidates[i]} 125 | for i in range(len(candidates))] 126 | 127 | # Build the batch and run it through the model 128 | batch_exs = batchify([vectorize(e, self.model) for e in examples]) 129 | s, e, score = self.model.predict(batch_exs, candidates, top_n) 130 | 131 | # Retrieve the predicted spans 132 | results = [] 133 | for i in range(len(s)): 134 | predictions = [] 135 | for j in range(len(s[i])): 136 | span = d_tokens[i].slice(s[i][j], e[i][j] + 1).untokenize() 137 | predictions.append((span, score[i][j])) 138 | results.append(predictions) 139 | return results 140 | 141 | def cuda(self): 142 | self.model.cuda() 143 | 144 | def cpu(self): 145 | self.model.cpu() 146 | -------------------------------------------------------------------------------- /paragraph_encoder/multi_corpus.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | from sklearn.feature_extraction.text import TfidfVectorizer 5 | from sklearn.metrics import pairwise_distances 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | import argparse 9 | import time 10 | 11 | class MultiCorpus: 12 | class Paragraph: 13 | 
        def __init__(self, args, pid, text, answer_span, qid, tfidf):
14 |             """
15 |             :param args:
16 |             :param pid:
17 |             :param text:
18 |             :param answer_span: numpy array of size num_occ X 2
19 |             :param qid:
20 |             :param tfidf:
21 |             """
22 |             self.args = args
23 |             self.pid = pid
24 |             self.text = text
25 |             self.answer_span = answer_span
26 |             self.ans_occurance = answer_span.shape[0]
27 |             self.qid = qid
28 |             self.tfidf_score = tfidf
29 |             self.model_score = None
30 |     class Question:
31 |         def __init__(self, args, qid, text, pids):
32 |             self.args = args
33 |             self.qid = qid
34 |             self.text = text
35 |             self.pids = pids
36 | 
37 |     def __init__(self, args):
38 | 
39 |         self.args = args
40 |         self.tfidf = TfidfVectorizer(strip_accents="unicode", stop_words="english")
41 |         self.questions = {}
42 |         self.paragraphs = {}
43 | 
44 |     def dists(self, question, paragraphs):
45 | 
46 |         text = []
47 |         for para in paragraphs:
48 |             text.append(" ".join("".join(s) for s in para.text))
49 |         try:
50 |             para_features = self.tfidf.fit_transform(text)
51 |             q_features = self.tfidf.transform([" ".join(question)])
52 |         except Exception:
53 |             print("tfidf fit_transform threw an exception")
54 |             return [(para, float('inf')) for para in paragraphs]
55 | 
56 |         dists = pairwise_distances(q_features, para_features, "cosine").ravel()
57 |         sorted_ix = np.lexsort(([x.start for x in paragraphs], dists))  # in case of ties, use the earlier paragraph
58 |         return [(paragraphs[i], dists[i]) for i in sorted_ix]
59 | 
60 | 
61 |     def dists_text(self, question, paragraph_texts):
62 |         """
63 |         variant of dists() that takes raw paragraph texts instead of Paragraph objects
64 |         :param question:
65 |         :param paragraph_texts:
66 |         :return:
67 |         """
68 |         text = []
69 |         for para in paragraph_texts:
70 |             text.append(" ".join(para))
71 | 
72 |         try:
73 |             para_features = self.tfidf.fit_transform(text)
74 |             q_features = self.tfidf.transform([question])
75 |         except Exception:
76 |             print("tfidf fit_transform threw an exception")
77 |             return [(t, float('inf')) for t in paragraph_texts]
78 | 
79 |         dists = pairwise_distances(q_features, para_features, "cosine").ravel()
80 |         sorted_ix = np.argsort(dists)
81 |         return [(paragraph_texts[i], dists[i]) for i in sorted_ix]
82 | 
83 |     def addQuestionParasFromObjects(self, qid, qtext, paragraphs):
84 |         # variant that takes Paragraph-like objects (with .text and .answer_spans);
85 |         # the text-based addQuestionParas below expects raw token lists instead
86 | 
87 |         scores = None
88 |         if self.args.calculate_tfidf:
89 |             scores = self.dists(qtext, paragraphs)
90 | 
91 |         para_ids = []
92 |         for p_counter, p in enumerate(paragraphs):
93 |             tfidf_score = float('inf')
94 |             if scores is not None:
95 |                 _, tfidf_score = scores[p_counter]
96 | 
97 |             pid = qid + "_para_" + str(p_counter)
98 |             para_ids.append(pid)
99 |             paragraph = self.Paragraph(self.args, pid, p.text, p.answer_spans, qid, tfidf_score)
100 |             self.paragraphs[pid] = paragraph
101 | 
102 |         question = self.Question(self.args, qid, qtext, para_ids)
103 | 
104 |         self.questions[qid] = question
105 | 
106 |     def addQuestionParas(self, qid, qtext, paragraph_texts, paragraph_answer_spans):
107 | 
108 |         # same flow as above, but paragraphs arrive as raw token lists
109 | 
110 |         scores = None
111 |         if self.args.calculate_tfidf:
112 |             scores = self.dists_text(" ".join(qtext), paragraph_texts)
113 | 
114 |         para_ids = []
115 |         for p_counter, p_text in enumerate(paragraph_texts):
116 |             tfidf_score = float('inf')
117 |             if scores is not None:
118 |                 _, tfidf_score = scores[p_counter]
119 | 
120 |             pid = qid + "_para_" + str(p_counter)
121 |             para_ids.append(pid)
122 | paragraph = self.Paragraph(self.args, pid, p_text, paragraph_answer_spans[p_counter], qid, tfidf_score) 123 | self.paragraphs[pid] = paragraph 124 | 125 | question = self.Question(self.args, qid, qtext, para_ids) 126 | 127 | self.questions[qid] = question 128 | 129 | 130 | def get_topk_tfidf(corpus): 131 | top1 = 0 132 | top3 = 0 133 | top5 = 0 134 | for qid in corpus.questions: 135 | 136 | 137 | para_scores = [(corpus.paragraphs[pid].tfidf_score, corpus.paragraphs[pid].ans_occurance) for pid in 138 | corpus.questions[qid].pids] 139 | sorted_para_scores = sorted(para_scores, key=lambda x: x[0]) 140 | # import pdb 141 | # pdb.set_trace() 142 | if sorted_para_scores[0][1] > 0: 143 | top1 += 1 144 | if sum([ans[1] for ans in sorted_para_scores[:3]]) > 0: 145 | top3 += 1 146 | if sum([ans[1] for ans in sorted_para_scores[:5]]) > 0: 147 | top5 += 1 148 | 149 | print( 150 | 'top1 = {}, top3 = {}, top5 = {} '.format(top1 / len(corpus.questions), top3 / len(corpus.questions), 151 | top5 / len(corpus.questions))) -------------------------------------------------------------------------------- /scripts/reader/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Preprocess the SQuAD dataset for training.""" 8 | 9 | import argparse 10 | import os 11 | import sys 12 | import json 13 | import time 14 | 15 | from multiprocessing import Pool 16 | from multiprocessing.util import Finalize 17 | from functools import partial 18 | from drqa import tokenizers 19 | 20 | # ------------------------------------------------------------------------------ 21 | # Tokenize + annotate. 
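#
# NOTE: illustrative sketch (not part of the original repo): the TF-IDF ranking
# in multi_corpus.py above (dists / dists_text) reduces to fitting a vectorizer
# on the paragraph texts, projecting the question into the same space, and
# sorting by cosine distance:
#
#     from sklearn.feature_extraction.text import TfidfVectorizer
#     from sklearn.metrics.pairwise import pairwise_distances
#     import numpy as np
#
#     paras = ["the eiffel tower is in paris", "obama was born in hawaii"]
#     tfidf = TfidfVectorizer(strip_accents="unicode", stop_words="english")
#     para_features = tfidf.fit_transform(paras)
#     q_features = tfidf.transform(["where was obama born"])
#     dists = pairwise_distances(q_features, para_features, metric="cosine").ravel()
#     print([paras[i] for i in np.argsort(dists)])  # closest paragraph first
#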
22 | # ------------------------------------------------------------------------------ 23 | 24 | TOK = None 25 | 26 | 27 | def init(tokenizer_class, options): 28 | global TOK 29 | TOK = tokenizer_class(**options) 30 | Finalize(TOK, TOK.shutdown, exitpriority=100) 31 | 32 | 33 | def tokenize(text): 34 | """Call the global process tokenizer on the input text.""" 35 | global TOK 36 | tokens = TOK.tokenize(text) 37 | output = { 38 | 'words': tokens.words(), 39 | 'offsets': tokens.offsets(), 40 | 'pos': tokens.pos(), 41 | 'lemma': tokens.lemmas(), 42 | 'ner': tokens.entities(), 43 | } 44 | return output 45 | 46 | 47 | # ------------------------------------------------------------------------------ 48 | # Process dataset examples 49 | # ------------------------------------------------------------------------------ 50 | 51 | 52 | def load_dataset(path): 53 | """Load json file and store fields separately.""" 54 | with open(path) as f: 55 | data = json.load(f)['data'] 56 | output = {'qids': [], 'questions': [], 'answers': [], 57 | 'contexts': [], 'qid2cid': []} 58 | for article in data: 59 | for paragraph in article['paragraphs']: 60 | output['contexts'].append(paragraph['context']) 61 | for qa in paragraph['qas']: 62 | output['qids'].append(qa['id']) 63 | output['questions'].append(qa['question']) 64 | output['qid2cid'].append(len(output['contexts']) - 1) 65 | if 'answers' in qa: 66 | output['answers'].append(qa['answers']) 67 | return output 68 | 69 | 70 | def find_answer(offsets, begin_offset, end_offset): 71 | """Match token offsets with the char begin/end offsets of the answer.""" 72 | start = [i for i, tok in enumerate(offsets) if tok[0] == begin_offset] 73 | end = [i for i, tok in enumerate(offsets) if tok[1] == end_offset] 74 | assert(len(start) <= 1) 75 | assert(len(end) <= 1) 76 | if len(start) == 1 and len(end) == 1: 77 | return start[0], end[0] 78 | 79 | 80 | def process_dataset(data, tokenizer, workers=None): 81 | """Iterate processing (tokenize, parse, etc) dataset multithreaded.""" 82 | tokenizer_class = tokenizers.get_class(tokenizer) 83 | make_pool = partial(Pool, workers, initializer=init) 84 | workers = make_pool(initargs=(tokenizer_class, {'annotators': {'lemma'}})) 85 | q_tokens = workers.map(tokenize, data['questions']) 86 | workers.close() 87 | workers.join() 88 | 89 | workers = make_pool( 90 | initargs=(tokenizer_class, {'annotators': {'lemma', 'pos', 'ner'}}) 91 | ) 92 | c_tokens = workers.map(tokenize, data['contexts']) 93 | workers.close() 94 | workers.join() 95 | 96 | for idx in range(len(data['qids'])): 97 | question = q_tokens[idx]['words'] 98 | qlemma = q_tokens[idx]['lemma'] 99 | document = c_tokens[data['qid2cid'][idx]]['words'] 100 | offsets = c_tokens[data['qid2cid'][idx]]['offsets'] 101 | lemma = c_tokens[data['qid2cid'][idx]]['lemma'] 102 | pos = c_tokens[data['qid2cid'][idx]]['pos'] 103 | ner = c_tokens[data['qid2cid'][idx]]['ner'] 104 | ans_tokens = [] 105 | if len(data['answers']) > 0: 106 | for ans in data['answers'][idx]: 107 | found = find_answer(offsets, 108 | ans['answer_start'], 109 | ans['answer_start'] + len(ans['text'])) 110 | if found: 111 | ans_tokens.append(found) 112 | yield { 113 | 'id': data['qids'][idx], 114 | 'question': question, 115 | 'document': document, 116 | 'offsets': offsets, 117 | 'answers': ans_tokens, 118 | 'qlemma': qlemma, 119 | 'lemma': lemma, 120 | 'pos': pos, 121 | 'ner': ner, 122 | } 123 | 124 | 125 | # ----------------------------------------------------------------------------- 126 | # Commandline options 127 | # 
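#
# NOTE: illustrative worked example (not part of the original repo) for
# find_answer above: token offsets are (start_char, end_char) pairs, and an
# answer survives only when its character span aligns exactly with token
# boundaries.
#
#     offsets = [(0, 4), (5, 8), (9, 14)]   # tokens of "Mary had lambs"
#     find_answer(offsets, 9, 14)           # -> (2, 2): answer is token 2
#     find_answer(offsets, 9, 13)           # -> None: no token ends at char 13
#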
----------------------------------------------------------------------------- 128 | 129 | 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('data_dir', type=str, help='Path to SQuAD data directory') 132 | parser.add_argument('out_dir', type=str, help='Path to output file dir') 133 | parser.add_argument('--split', type=str, help='Filename for train/dev split', 134 | default='SQuAD-v1.1-train') 135 | parser.add_argument('--workers', type=int, default=None) 136 | parser.add_argument('--tokenizer', type=str, default='corenlp') 137 | args = parser.parse_args() 138 | 139 | t0 = time.time() 140 | 141 | in_file = os.path.join(args.data_dir, args.split + '.json') 142 | print('Loading dataset %s' % in_file, file=sys.stderr) 143 | dataset = load_dataset(in_file) 144 | 145 | out_file = os.path.join( 146 | args.out_dir, '%s-processed-%s.txt' % (args.split, args.tokenizer) 147 | ) 148 | print('Will write to file %s' % out_file, file=sys.stderr) 149 | with open(out_file, 'w') as f: 150 | for ex in process_dataset(dataset, args.tokenizer, args.workers): 151 | f.write(json.dumps(ex) + '\n') 152 | print('Total time: %.4f (s)' % (time.time() - t0)) 153 | -------------------------------------------------------------------------------- /msr/reader/vector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Functions for putting examples into torch format.""" 8 | 9 | from collections import Counter 10 | import torch 11 | 12 | 13 | def vectorize(args, one_ex_q, one_ex_paras, word_dict, feature_dict, single_answer=False, train_time=False): 14 | """Torchify a single example.""" 15 | 16 | # Index words 17 | # ex is an instance of docqa.data_processing.multi_paragraph_qa.MultiParagraphQuestion 18 | ex = {} 19 | # doc is a list of paragraphs 20 | # last entry denotes the special token "ANS_NOT_IN_PARA" 21 | ex['document'] = [one_ex_paras[i].text for i in range(len(one_ex_paras))] 22 | ex['question'] = one_ex_q.text 23 | ex['answers'] = [] 24 | for j in range(len(one_ex_paras)): 25 | answers_in_para = [] 26 | for i in range(one_ex_paras[j].answer_span.shape[0]): 27 | answers_in_para.append((int(one_ex_paras[j].answer_span[i][0]), int(one_ex_paras[j].answer_span[i][1]))) 28 | ex['answers'].append(answers_in_para) 29 | 30 | ex['id'] = one_ex_q.qid 31 | document = [torch.LongTensor([word_dict[w] for w in ex['document'][i]]) for i in range(len(one_ex_paras))] 32 | question = torch.LongTensor([word_dict[w] for w in ex['question']]) 33 | 34 | # Create extra features vector 35 | if len(feature_dict) > 0: 36 | features = [torch.zeros(len(ex['document'][i]), len(feature_dict)) for i in range(len(one_ex_paras))] 37 | else: 38 | features = None 39 | 40 | # f_{exact_match} 41 | if args.use_in_question: 42 | q_words_cased = {w for w in ex['question']} 43 | q_words_uncased = {w.lower() for w in ex['question']} 44 | # q_lemma = {w for w in ex['qlemma']} if args.use_lemma else None 45 | for i in range(len(ex['document'])): 46 | for j in range(len(ex['document'][i])): 47 | if ex['document'][i][j] in q_words_cased: 48 | features[i][j][feature_dict['in_question']] = 1.0 49 | if ex['document'][i][j].lower() in q_words_uncased: 50 | features[i][j][feature_dict['in_question_uncased']] = 1.0 51 | # if q_lemma and ex['lemma'][i] in q_lemma: 52 | # 
features[i][feature_dict['in_question_lemma']] = 1.0 53 | 54 | # # f_{token} (POS) 55 | # if args.use_pos: 56 | # for i, w in enumerate(ex['pos']): 57 | # f = 'pos=%s' % w 58 | # if f in feature_dict: 59 | # features[i][feature_dict[f]] = 1.0 60 | # 61 | # # f_{token} (NER) 62 | # if args.use_ner: 63 | # for i, w in enumerate(ex['ner']): 64 | # f = 'ner=%s' % w 65 | # if f in feature_dict: 66 | 67 | # features[i][feature_dict[f]] = 1.0 68 | # 69 | 70 | # # f_{token} (TF) 71 | # if args.use_tf: 72 | # counter = Counter([w.lower() for w in ex['document']]) 73 | # l = len(ex['document']) 74 | # for i, w in enumerate(ex['document']): 75 | # features[i][feature_dict['tf']] = counter[w.lower()] * 1.0 / l 76 | 77 | # Maybe return without target 78 | # if not train_time: 79 | # return document, features, question, ex['id'] 80 | 81 | # ...or with target(s) (might still be empty if answers is empty) 82 | start, end = [], [] 83 | for ans_spans_para in ex['answers']: 84 | start_para = [span[0] for span in ans_spans_para] 85 | end_para = [span[1] for span in ans_spans_para] 86 | start.append(start_para) 87 | end.append(end_para) 88 | 89 | return document, features, question, start, end, ex['id'] 90 | 91 | 92 | def batchify(batch): 93 | """Gather a batch of individual examples into one batch.""" 94 | NUM_INPUTS = 3 95 | NUM_TARGETS = 2 96 | NUM_EXTRA = 1 97 | 98 | ids = [ex[-1] for ex in batch] 99 | docs = [ex[0] for ex in batch] 100 | features = [ex[1] for ex in batch] 101 | questions = [ex[2] for ex in batch] 102 | 103 | # Batch documents and features 104 | max_length = max([para.size(0) for d in docs for para in d]) 105 | max_num_paras = max([len(d) for d in docs]) 106 | 107 | x1 = torch.LongTensor(len(docs), max_num_paras, max_length).zero_() 108 | x1_mask = torch.ByteTensor(len(docs), max_num_paras, max_length).fill_(1) 109 | if features[0] is None: 110 | x1_f = None 111 | else: 112 | x1_f = torch.zeros(len(docs), max_num_paras, max_length, features[0][0].size(1)) 113 | for i, d in enumerate(docs): 114 | for j, p in enumerate(d): 115 | x1[i, j, :p.size(0)].copy_(p) 116 | x1_mask[i, j, :p.size(0)].fill_(0) 117 | if x1_f is not None: 118 | x1_f[i, j, :p.size(0), :].copy_(features[i][j]) 119 | 120 | # Batch questions 121 | max_length = max([q.size(0) for q in questions]) 122 | x2 = torch.LongTensor(len(questions), max_length).zero_() 123 | x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1) 124 | for i, q in enumerate(questions): 125 | x2[i, :q.size(0)].copy_(q) 126 | x2_mask[i, :q.size(0)].fill_(0) 127 | 128 | # Maybe return without targets 129 | if len(batch[0]) == NUM_INPUTS + NUM_EXTRA: 130 | return x1, x1_f, x1_mask, x2, x2_mask, ids 131 | 132 | elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS: 133 | # ...Otherwise add targets 134 | if torch.is_tensor(batch[0][3]): 135 | y_s = torch.cat([ex[3] for ex in batch]) 136 | y_e = torch.cat([ex[4] for ex in batch]) 137 | else: 138 | y_s, y_e = [], [] 139 | for ex in batch: 140 | start_span_pos = [torch.LongTensor(start_pos) for start_pos in ex[3]] 141 | end_span_pos = [torch.LongTensor(end_pos) for end_pos in ex[4]] 142 | y_s.append(start_span_pos) 143 | y_e.append(end_span_pos) 144 | else: 145 | raise RuntimeError('Incorrect number of inputs per example.') 146 | 147 | return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids 148 | -------------------------------------------------------------------------------- /msr/reader/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Model architecture/optimization options for DrQA document reader.""" 8 | 9 | import argparse 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | # Index of arguments concerning the core model architecture 15 | MODEL_ARCHITECTURE = { 16 | 'model_type', 'embedding_dim', 'hidden_size', 'doc_layers', 17 | 'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge', 18 | 'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf' 19 | } 20 | 21 | # Index of arguments concerning the model optimizer/training 22 | MODEL_OPTIMIZER = { 23 | 'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay', 24 | 'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb', 25 | 'max_len', 'grad_clipping', 'tune_partial' 26 | } 27 | 28 | 29 | def str2bool(v): 30 | return v.lower() in ('yes', 'true', 't', '1', 'y') 31 | 32 | 33 | def add_model_args(parser): 34 | parser.register('type', 'bool', str2bool) 35 | 36 | # Model architecture 37 | model = parser.add_argument_group('DrQA Reader Model Architecture') 38 | model.add_argument('--model-type', type=str, default='rnn', 39 | help='Model architecture type') 40 | model.add_argument('--embedding-dim', type=int, default=300, 41 | help='Embedding size if embedding_file is not given') 42 | model.add_argument('--hidden-size', type=int, default=128, 43 | help='Hidden size of RNN units') 44 | model.add_argument('--doc-layers', type=int, default=3, 45 | help='Number of encoding layers for document') 46 | model.add_argument('--question-layers', type=int, default=3, 47 | help='Number of encoding layers for question') 48 | model.add_argument('--rnn-type', type=str, default='lstm', 49 | help='RNN type: LSTM, GRU, or RNN') 50 | model.add_argument('--top-spans', type=int, default=10, 51 | help='aggregate ascores over spans') 52 | 53 | # Model specific details 54 | detail = parser.add_argument_group('DrQA Reader Model Details') 55 | detail.add_argument('--concat-rnn-layers', type='bool', default=True, 56 | help='Combine hidden states from each encoding layer') 57 | detail.add_argument('--question-merge', type=str, default='self_attn', 58 | help='The way of computing the question representation') 59 | detail.add_argument('--use-qemb', type='bool', default=True, 60 | help='Whether to use weighted question embeddings') 61 | detail.add_argument('--use-in-question', type='bool', default=True, 62 | help='Whether to use in_question_* features') 63 | detail.add_argument('--use-pos', type='bool', default=True, 64 | help='Whether to use pos features') 65 | detail.add_argument('--use-ner', type='bool', default=True, 66 | help='Whether to use ner features') 67 | detail.add_argument('--use-lemma', type='bool', default=True, 68 | help='Whether to use lemma features') 69 | detail.add_argument('--use-tf', type='bool', default=True, 70 | help='Whether to use term frequency features') 71 | 72 | # Optimization details 73 | optim = parser.add_argument_group('DrQA Reader Optimization') 74 | optim.add_argument('--dropout-emb', type=float, default=0.4, 75 | help='Dropout rate for word embeddings') 76 | optim.add_argument('--dropout-rnn', type=float, default=0.4, 77 | help='Dropout rate for RNN states') 78 | optim.add_argument('--dropout-rnn-output', type='bool', default=True, 79 | help='Whether to dropout the RNN 
output')
80 | optim.add_argument('--optimizer', type=str, default='adamax',
81 | help='Optimizer: sgd or adamax')
82 | optim.add_argument('--learning-rate', type=float, default=0.1,
83 | help='Learning rate for SGD only')
84 | optim.add_argument('--grad-clipping', type=float, default=10,
85 | help='Gradient clipping')
86 | optim.add_argument('--weight-decay', type=float, default=0,
87 | help='Weight decay factor')
88 | optim.add_argument('--momentum', type=float, default=0,
89 | help='Momentum factor')
90 | optim.add_argument('--fix-embeddings', type='bool', default=True,
91 | help='Keep word embeddings fixed (use pretrained)')
92 | optim.add_argument('--tune-partial', type=int, default=0,
93 | help='Backprop through only the top N question words')
94 | optim.add_argument('--rnn-padding', type='bool', default=False,
95 | help='Explicitly account for padding in RNN encoding')
96 | optim.add_argument('--max-len', type=int, default=15,
97 | help='The max span allowed during decoding')
98 | 
99 | 
100 | def get_model_args(args):
101 | """Filter args for model ones.
102 | 
103 | From an args Namespace, return a new Namespace with *only* the args specific
104 | to the model architecture or optimization (i.e. the ones defined here).
105 | """
106 | global MODEL_ARCHITECTURE, MODEL_OPTIMIZER
107 | required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER
108 | arg_values = {k: v for k, v in vars(args).items() if k in required_args}
109 | return argparse.Namespace(**arg_values)
110 | 
111 | 
112 | def override_model_args(old_args, new_args):
113 | """Set args to new parameters.
114 | 
115 | Decide which model args to keep and which to override when resolving a set
116 | of saved args and new args.
117 | 
118 | We keep the new optimization settings, but leave the model architecture alone. (Note: the MODEL_OPTIMIZER check below is commented out, so in practice every differing argument is overridden.)
119 | """ 120 | global MODEL_OPTIMIZER 121 | old_args, new_args = vars(old_args), vars(new_args) 122 | for k in new_args.keys(): 123 | if k in old_args and old_args[k] != new_args[k]: 124 | # if k in MODEL_OPTIMIZER: 125 | logger.info('Overriding saved %s: %s --> %s' % 126 | (k, old_args[k], new_args[k])) 127 | old_args[k] = new_args[k] 128 | # else: 129 | # logger.info('Keeping saved %s: %s' % (k, old_args[k])) 130 | elif k not in old_args: 131 | logger.info("Adding new argument {}".format(k)) 132 | old_args[k] = new_args[k] 133 | return argparse.Namespace(**old_args) 134 | -------------------------------------------------------------------------------- /msr/retriever/trained_retriever.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import logging 4 | import numpy as np 5 | from tqdm import tqdm 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | class Retriever(object): 13 | def __init__(self, args, read_dir): 14 | self.args = args 15 | self.read_dir = read_dir 16 | self.qid2filemap = json.load(open(os.path.join(read_dir, "map.json"))) 17 | self.reverse_qid2filemap = {v: k for k, v in self.qid2filemap.items()} 18 | self.qid2correctparamap = json.load(open(os.path.join(read_dir, "correct_paras.json"))) 19 | self.qids = None 20 | self.current_query_vectors = None # stores the current value of query vectors for the batch 21 | self.current_para_vectors = None # all para vectors 22 | self.all_query_vectors = None 23 | self.all_para_vectors = None 24 | self.cum_num_paras = [] 25 | self.all_cum_num_paras = [] 26 | self.embedding_dim = 128 27 | logger.info("Reading saved paragraph and query vectors from disk...{}".format(self.read_dir)) 28 | self.all_query_vectors = np.load(os.path.join(read_dir, "question.npy")) 29 | self.all_query_vectors = torch.FloatTensor(self.all_query_vectors) 30 | self.all_para_vectors = np.load(os.path.join(read_dir, "document.npy")) 31 | self.all_para_vectors = torch.FloatTensor(self.all_para_vectors) 32 | self.all_cum_num_paras = np.load(os.path.join(read_dir, "all_cumlen.npy")) 33 | self.all_cum_num_paras = torch.LongTensor(self.all_cum_num_paras) 34 | # self.qid2indexmap = torch.load(os.path.join(read_dir, "qid2indexmap.pkl")) 35 | 36 | # test cases 37 | assert self.all_cum_num_paras.size(0) == self.all_query_vectors.size(0) 38 | assert self.all_cum_num_paras[-1] == self.all_para_vectors.size(0) 39 | logger.info("Done Reading!") 40 | 41 | def reset(self): 42 | """ 43 | resets for a new batch of queries 44 | :return: 45 | """ 46 | self.current_query_vectors = None 47 | self.current_para_vectors = None 48 | self.cum_num_paras = [] 49 | 50 | def __call__(self, qids, train_time=False): 51 | 52 | # transform qids from strings to int 53 | self.qqids = qids 54 | self.qids = [self.qid2filemap[qid] for qid in qids] 55 | if self.current_query_vectors is None: # first time; read from disk 56 | 57 | 58 | for i, qid in enumerate(self.qids): 59 | en_ind = self.all_cum_num_paras[qid] 60 | st_ind = 0 if qid == 0 else self.all_cum_num_paras[qid - 1] 61 | self.current_para_vectors = self.all_para_vectors[ 62 | st_ind:en_ind] if self.current_para_vectors is None else torch.cat( 63 | [self.current_para_vectors, self.all_para_vectors[st_ind:en_ind]], dim=0) 64 | 65 | self.cum_num_paras.append(self.current_para_vectors.size(0)) 66 | 67 | self.current_query_vectors = torch.index_select(self.all_query_vectors, 0, torch.LongTensor(self.qids)) 68 | if 
self.args.cuda: 69 | self.current_query_vectors = Variable(self.current_query_vectors.cuda()) 70 | self.current_para_vectors = Variable(self.current_para_vectors.cuda()) 71 | 72 | # take inner product 73 | para_scores = torch.mm(self.current_query_vectors, self.current_para_vectors.t()) 74 | # now for each query, slice out the scores for the corresponding paras 75 | sorted_para_ids_per_query = [] 76 | sorted_para_scores_per_query = [] 77 | all_num_positive_paras = [] 78 | for i in range(para_scores.size(0)): # for each query 79 | st = 0 if i == 0 else self.cum_num_paras[i - 1] 80 | en = self.cum_num_paras[i] 81 | para_scores_for_query = para_scores[i, st:en] 82 | sorted_scores, para_ids_query = torch.sort(para_scores_for_query, descending=True) 83 | 84 | if train_time or self.args.cheat: 85 | # during train time, make sure that the top (may be few) paras have annotation 86 | # get correct_paras 87 | correct_para_ids = self.qid2correctparamap[self.reverse_qid2filemap[self.qids[i]]] 88 | # for some qids, there arent any labels, will have to handle them separately in model.update 89 | if len(correct_para_ids) > 0: 90 | np.random.shuffle(correct_para_ids) 91 | num_positive_paras = min(self.args.num_positive_paras, len(correct_para_ids)) 92 | correct_para_ids = correct_para_ids[:num_positive_paras] 93 | para_ids_query = para_ids_query.cpu().data.numpy().tolist() 94 | sorted_scores = sorted_scores.cpu().data.numpy().tolist() 95 | temp_para_ids_query = [] 96 | temp_sorted_scores = [] 97 | correct_para_inds = [] 98 | for i, p in enumerate(para_ids_query): 99 | if p not in correct_para_ids: 100 | temp_para_ids_query.append(p) 101 | temp_sorted_scores.append(sorted_scores[i]) 102 | else: 103 | correct_para_inds.append(i) 104 | para_ids_query = correct_para_ids + temp_para_ids_query 105 | sorted_scores = [sorted_scores[i] for i in correct_para_inds] + temp_sorted_scores 106 | para_ids_query = Variable(torch.LongTensor(para_ids_query)) 107 | sorted_scores = Variable(torch.FloatTensor(sorted_scores)) 108 | if self.args.cuda: 109 | para_ids_query = para_ids_query.cuda() 110 | sorted_scores = sorted_scores.cuda() 111 | 112 | all_num_positive_paras.append(num_positive_paras) 113 | else: 114 | all_num_positive_paras.append(0) 115 | 116 | sorted_para_ids_per_query.append(para_ids_query.data) 117 | sorted_para_scores_per_query.append(sorted_scores) 118 | 119 | return self.current_query_vectors, sorted_para_scores_per_query, sorted_para_ids_per_query, all_num_positive_paras 120 | 121 | def update_query_vectors(self, q_vectors): 122 | self.current_query_vectors = q_vectors 123 | 124 | def get_nearest_correct_para_vector(self): 125 | 126 | # gather the top para_id for each qid 127 | top_para_ids, incorrect_para_ids, mask = [], [], [] 128 | 129 | for i, qid in enumerate(self.qqids): 130 | st = 0 if i == 0 else self.cum_num_paras[i - 1] 131 | correct_paras = self.qid2correctparamap[self.reverse_qid2filemap[self.qids[i]]] 132 | np.random.shuffle(correct_paras) 133 | try: 134 | top_para_ids.append(correct_paras[0] + st) 135 | mask.append(1) 136 | except IndexError: 137 | top_para_ids.append(0 + st) # some question-paras have no answer occurrences 138 | mask.append(0) 139 | incorrect_paras = list(set(range(self.cum_num_paras[i] - st)) - set(correct_paras)) 140 | np.random.shuffle(incorrect_paras) 141 | if len(incorrect_paras) == 0: 142 | incorrect_para_ids.append(0 + st) 143 | mask[i] = 0 144 | else: 145 | incorrect_para_ids.append(incorrect_paras[0] + st) 146 | # now select the appropriate para vector 147 | 
top_para_ids = torch.cuda.LongTensor(top_para_ids) if self.args.cuda else torch.LongTensor(top_para_ids) 148 | incorrect_para_ids = torch.cuda.LongTensor(incorrect_para_ids) if self.args.cuda else torch.LongTensor( 149 | incorrect_para_ids) 150 | mask = torch.cuda.ByteTensor(mask) if self.args.cuda else torch.ByteTensor(mask) 151 | nearest_correct_paras = torch.index_select(self.current_para_vectors, 0, Variable(top_para_ids)) 152 | farthest_incorrect_paras = torch.index_select(self.current_para_vectors, 0, Variable(incorrect_para_ids)) 153 | return nearest_correct_paras, farthest_incorrect_paras, mask 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /scripts/reader/README.md: -------------------------------------------------------------------------------- 1 | # Document Reader 2 | 3 | ## Preprocessing 4 | 5 | `preprocess.py` takes a SQuAD-formatted dataset and outputs a preprocessed, training-ready file. Specifically, it handles tokenization, mapping character offsets to token offsets, and any additional featurization such as lemmatization, part-of-speech tagging, and named entity recognition. 6 | 7 | To preprocess SQuAD (assuming both input and output files are in `data/datasets`): 8 | 9 | ```bash 10 | python scripts/reader/preprocess.py data/datasets data/datasets --split SQuAD-v1.1-train 11 | ``` 12 | ```bash 13 | python scripts/reader/preprocess.py data/datasets data/datasets --split SQuAD-v1.1-dev 14 | ``` 15 | - _You need to have [SQuAD](../../README.md#qa-datasets) train-v1.1.json and dev-v1.1.json in data/datasets (here renamed as SQuAD-v1.1-.json)_ 16 | 17 | ## Training 18 | 19 | `train.py` is the main train script for the Document Reader. 20 | 21 | To get started with training a model on SQuAD with our best hyper parameters: 22 | 23 | ```bash 24 | python scripts/reader/train.py --embedding-file glove.840B.300d.txt --tune-partial 1000 25 | ``` 26 | - _You need to have the [glove embeddings](#note-on-word-embeddings) downloaded to data/embeddings/glove.840B.300d.txt._ 27 | - _You need to have done the preprocessing above._ 28 | 29 | The training has many options that you can tune: 30 | 31 | ``` 32 | Environment: 33 | --no-cuda Train on CPU, even if GPUs are available. (default: False) 34 | --gpu Run on a specific GPU (default: -1) 35 | --data-workers Number of subprocesses for data loading (default: 5) 36 | --parallel Use DataParallel on all available GPUs (default: False) 37 | --random-seed Random seed for all numpy/torch/cuda operations (for reproducibility). 38 | --num-epochs Train data iterations. 39 | --batch-size Batch size for training. 40 | --test-batch-size Batch size during validation/testing. 41 | 42 | Filesystem: 43 | --model-dir Directory for saved models/checkpoints/logs (default: /tmp/drqa-models). 44 | --model-name Unique model identifier (.mdl, .txt, .checkpoint) (default: ). 45 | --data-dir Directory of training/validation data (default: data/datasets). 46 | --train-file Preprocessed train file (default: SQuAD-v1.1-train-processed-corenlp.txt). 47 | --dev-file Preprocessed dev file (default: SQuAD-v1.1-dev-processed-corenlp.txt). 48 | --dev-json Unprocessed dev file to run validation while training on (used to get original text for getting spans and answer texts) (default: SQuAD-v1.1-dev.json). 49 | --embed-dir Directory of pre-trained embedding files (default: data/embeddings). 50 | --embedding-file Space-separated pretrained embeddings file (default: None). 
51 | 
52 | Saving/Loading:
53 | --checkpoint Save model + optimizer state after each epoch (default: False).
54 | --pretrained Path to a pretrained model to warm-start with (default: ).
55 | --expand-dictionary Expand dictionary of pretrained (--pretrained) model to include training/dev words of new data (default: False).
56 | 
57 | Preprocessing:
58 | --uncased-question Question words will be lower-cased (default: False).
59 | --uncased-doc Document words will be lower-cased (default: False).
60 | --restrict-vocab Only use pre-trained words in embedding_file (default: True).
61 | 
62 | General:
63 | --official-eval Validate with official SQuAD eval (default: True).
64 | --valid-metric The evaluation metric used for model selection (default: f1).
65 | --display-iter Log state after every N iterations (default: 25).
66 | --sort-by-len Sort batches by length for speed (default: True).
67 | 
68 | DrQA Reader Model Architecture:
69 | --model-type Model architecture type (default: rnn).
70 | --embedding-dim Embedding size if embedding_file is not given (default: 300).
71 | --hidden-size Hidden size of RNN units (default: 128).
72 | --doc-layers Number of encoding layers for document (default: 3).
73 | --question-layers Number of encoding layers for question (default: 3).
74 | --rnn-type RNN type: LSTM, GRU, or RNN (default: lstm).
75 | 
76 | DrQA Reader Model Details:
77 | --concat-rnn-layers Combine hidden states from each encoding layer (default: True).
78 | --question-merge The way of computing the question representation (default: self_attn).
79 | --use-qemb Whether to use weighted question embeddings (default: True).
80 | --use-in-question Whether to use in_question_* (cased, uncased, lemma) features (default: True).
81 | --use-pos Whether to use pos features (default: True).
82 | --use-ner Whether to use ner features (default: True).
83 | --use-lemma Whether to use lemma features (default: True).
84 | --use-tf Whether to use term frequency features (default: True).
85 | 
86 | DrQA Reader Optimization:
87 | --dropout-emb Dropout rate for word embeddings (default: 0.4).
88 | --dropout-rnn Dropout rate for RNN states (default: 0.4).
89 | --dropout-rnn-output Whether to dropout the RNN output (default: True).
90 | --optimizer Optimizer: sgd or adamax (default: adamax).
91 | --learning-rate Learning rate for SGD only (default: 0.1).
92 | --grad-clipping Gradient clipping (default: 10).
93 | --weight-decay Weight decay factor (default: 0).
94 | --momentum Momentum factor (default: 0).
95 | --fix-embeddings Keep word embeddings fixed (use pretrained) (default: True).
96 | --tune-partial Backprop through only the top N question words (default: 0).
97 | --rnn-padding Explicitly account for padding (and skip it) in RNN encoding (default: False).
98 | --max-len MAX_LEN The max span allowed during decoding (default: 15).
99 | ```
100 | 
101 | ### Note on Word Embeddings
102 | 
103 | Using pre-trained word embeddings is very important for performance. The models we provide were trained with cased GloVe embeddings trained on Common Crawl; however, we have also found that other embeddings, such as FastText, do quite well.
104 | 
105 | We suggest downloading the embedding files and storing them under `data/embeddings/<file>.txt` (this is the default for `--embed-dir`). The code expects space-separated plain-text files (`<token> <dim_1> <dim_2> ... <dim_d>`).
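For instance, each line holds one token followed by its vector components; a toy 4-dimensional file (illustrative values only, real GloVe/FastText files have 300 columns) looks like:

```
the 0.418 0.24968 -0.41242 0.1217
cat 0.23088 0.28283 0.6318 -0.59411
```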
106 | 
107 | - [GloVe: Common Crawl (cased)](http://nlp.stanford.edu/data/wordvecs/glove.840B.300d.zip)
108 | - [FastText: Wikipedia (uncased)](https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.vec)
109 | 
110 | ## Predicting
111 | 
112 | `predict.py` uses a trained Document Reader model to make predictions for an input dataset.
113 | 
114 | Required arguments:
115 | ```
116 | dataset SQuAD-like dataset to evaluate on (format B).
117 | ```
118 | 
119 | Optional arguments:
120 | ```
121 | --model Path to model to use.
122 | --embedding-file Expand dictionary to use all pretrained embeddings in this file.
123 | --out-dir Directory to write prediction file to (<dataset>-<model>.preds).
124 | --tokenizer String option specifying tokenizer type to use (e.g. 'corenlp').
125 | --num-workers Number of CPU processes (for tokenizing, etc).
126 | --no-cuda Use CPU only.
127 | --gpu Specify GPU device id to use.
128 | --batch-size Example batching size (Reduce in case of OOM).
129 | --top-n Store top N predicted spans per example.
130 | --official Only store single top span instead of top N list. (The SQuAD eval script takes a dict of qid: span).
131 | ```
132 | 
133 | Note: The CoreNLP NER annotator is not fully deterministic (depends on the order examples are processed). Predictions may fluctuate very slightly between runs if `num-workers` > 1 and the model was trained with `use-ner` on.
134 | 
135 | Evaluation is done with the official_eval.py script from the SQuAD creators, available [here](https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/). It is also available by default at `scripts/reader/official_eval.py` after running `./download.sh`.
136 | 
137 | ```bash
138 | python scripts/reader/official_eval.py /path/to/format/B/dataset.json /path/to/predictions/with/--official/flag/set.json
139 | ```
140 | 
141 | ## Interactive
142 | 
143 | The Document Reader can also be used interactively (like the [full pipeline](../../README.md#quick-start-demo)).
144 | 
145 | ```bash
146 | python scripts/reader/interactive.py --model /path/to/model
147 | ```
148 | 
149 | ```
150 | >>> text = "Mary had a little lamb, whose fleece was white as snow. And everywhere that Mary went the lamb was sure to go."
151 | >>> question = "What color is Mary's lamb?"
152 | >>> process(text, question)
153 | 
154 | +------+-------+---------+
155 | | Rank | Span | Score |
156 | +------+-------+---------+
157 | | 1 | white | 0.78002 |
158 | +------+-------+---------+
159 | ```
--------------------------------------------------------------------------------
/msr/reader/rnn_reader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Implementation of the RNN based DrQA reader."""
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from . import layers
13 | import logging
14 | 
15 | logger = logging.getLogger(__name__)
16 | # ------------------------------------------------------------------------------
17 | # Network
18 | # ------------------------------------------------------------------------------
19 | 
20 | class StackedRNNCell(nn.Module):
21 | """
22 | Implementation of a stacked RNN cell.
23 | """ 24 | 25 | def __init__(self, args, cell_type, in_size, h_size, num_layers=3): 26 | super(StackedRNNCell, self).__init__() 27 | self.cells = nn.ModuleList() 28 | self.args = args 29 | self.num_layers = num_layers 30 | self.cell_type = cell_type 31 | self.rnn = None 32 | if cell_type == 'lstm': 33 | self.rnn = nn.LSTMCell 34 | elif cell_type == 'gru': 35 | self.rnn = nn.GRUCell 36 | elif cell_type == 'rnn': 37 | self.rnn = nn.RNNCell 38 | if self.rnn is None: 39 | logger.info('[ Defaulting to LSTM cell_type ]') 40 | self.cell_type = 'lstm' 41 | self.rnn = nn.LSTMCell 42 | 43 | for _ in range(num_layers): 44 | if args.cuda: 45 | self.cells.append(self.rnn(in_size, h_size).cuda()) 46 | else: 47 | self.cells.append(self.rnn(in_size, h_size)) 48 | in_size = h_size 49 | 50 | def forward(self, x, hiddens): 51 | 52 | """ 53 | :param x: input embedding 54 | :param hiddens: an array of length num_layers, hiddens[j] contains (h_t, c_t) of jth layer 55 | :return: 56 | """ 57 | input = x 58 | hiddens_out = [] 59 | for l in range(self.num_layers): 60 | h_out = self.cells[l](input, hiddens[l]) 61 | hiddens_out.append(h_out) 62 | input = h_out[0] if self.cell_type == 'lstm' else h_out 63 | return hiddens_out 64 | 65 | 66 | class MultiStepReasoner(nn.Module): 67 | """ 68 | does multistep reasoning by taking the reader state and the previous query to generate a new query 69 | """ 70 | 71 | def __init__(self, args, input_dim, hidden_dim): 72 | super(MultiStepReasoner, self).__init__() 73 | self.args = args 74 | self.gru_cell = StackedRNNCell(args, 'gru', input_dim, hidden_dim, self.args.num_gru_layers) 75 | self.args.cuda = True 76 | self.linear1 = nn.Linear(self.args.num_gru_layers * hidden_dim, hidden_dim) 77 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 78 | if self.args.cuda: 79 | self.gru_cell = self.gru_cell.cuda() 80 | self.linear1 = self.linear1.cuda() 81 | self.linear2 = self.linear2.cuda() 82 | 83 | def forward(self, retriever_query, reader_state): 84 | hiddens = self.gru_cell(reader_state, [retriever_query for _ in range(self.args.num_gru_layers)]) 85 | hiddens = torch.cat(hiddens, dim=1) 86 | # pass it through a MLP 87 | out = self.linear2(F.relu(self.linear1(hiddens))) 88 | return out 89 | 90 | 91 | class RnnDocReader(nn.Module): 92 | RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN} 93 | 94 | def __init__(self, args, normalize=True): 95 | super(RnnDocReader, self).__init__() 96 | # Store config 97 | self.args = args 98 | 99 | # Word embeddings (+1 for padding) 100 | self.embedding = nn.Embedding(args.vocab_size, 101 | args.embedding_dim, 102 | padding_idx=0) 103 | 104 | # Projection for attention weighted question 105 | if args.use_qemb: 106 | self.qemb_match = layers.SeqAttnMatch(args.embedding_dim) 107 | 108 | # Input size to RNN: word emb + question emb + manual features 109 | doc_input_size = args.embedding_dim + args.num_features 110 | if args.use_qemb: 111 | doc_input_size += args.embedding_dim 112 | 113 | # RNN document encoder 114 | self.doc_rnn = layers.StackedBRNN( 115 | input_size=doc_input_size, 116 | hidden_size=args.hidden_size, 117 | num_layers=args.doc_layers, 118 | dropout_rate=args.dropout_rnn, 119 | dropout_output=args.dropout_rnn_output, 120 | concat_layers=args.concat_rnn_layers, 121 | rnn_type=self.RNN_TYPES[args.rnn_type], 122 | padding=args.rnn_padding, 123 | ) 124 | 125 | # RNN question encoder 126 | self.question_rnn = layers.StackedBRNN( 127 | input_size=args.embedding_dim, 128 | hidden_size=args.hidden_size, 129 | 
num_layers=args.question_layers, 130 | dropout_rate=args.dropout_rnn, 131 | dropout_output=args.dropout_rnn_output, 132 | concat_layers=args.concat_rnn_layers, 133 | rnn_type=self.RNN_TYPES[args.rnn_type], 134 | padding=args.rnn_padding, 135 | ) 136 | 137 | self.question_hidden, self.doc_hiddens = None, None 138 | 139 | # Output sizes of rnn encoders 140 | doc_hidden_size = 2 * args.hidden_size 141 | question_hidden_size = 2 * args.hidden_size 142 | if args.concat_rnn_layers: 143 | doc_hidden_size *= args.doc_layers 144 | question_hidden_size *= args.question_layers 145 | 146 | self.args.doc_hidden_size = doc_hidden_size 147 | # Question merging 148 | if args.question_merge not in ['avg', 'self_attn']: 149 | raise NotImplementedError('merge_mode = %s' % args.merge_mode) 150 | if args.question_merge == 'self_attn': 151 | self.self_attn = layers.LinearSeqAttn(question_hidden_size) 152 | 153 | # this is for computing the reader state 154 | self.reader_state_self_attn = layers.BilinearSeqAttn( 155 | doc_hidden_size, 156 | question_hidden_size, 157 | normalize=normalize, 158 | ) 159 | 160 | # Bilinear attention for span start/end 161 | self.start_attn = layers.BilinearSeqAttn( 162 | doc_hidden_size, 163 | question_hidden_size, 164 | normalize=normalize, 165 | ) 166 | self.end_attn = layers.BilinearSeqAttn( 167 | doc_hidden_size, 168 | question_hidden_size, 169 | normalize=normalize, 170 | ) 171 | 172 | def forward(self, x1, x1_f, x1_mask, x2, x2_mask): 173 | """Inputs: 174 | x1 = document word indices [batch * len_d] 175 | x1_f = document word features indices [batch * len_d * nfeat] 176 | x1_mask = document padding mask [batch * len_d] 177 | x2 = question word indices [batch * len_q] 178 | x2_mask = question padding mask [batch * len_q] 179 | """ 180 | if self.question_hidden is None: # read the paras only once 181 | # Embed both document and question 182 | x1_emb = self.embedding(x1) 183 | x2_emb = self.embedding(x2) 184 | 185 | # Dropout on embeddings 186 | if self.args.dropout_emb > 0: 187 | x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb, 188 | training=self.training) 189 | x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb, 190 | training=self.training) 191 | 192 | # Form document encoding inputs 193 | drnn_input = [x1_emb] 194 | # import pdb 195 | # pdb.set_trace() 196 | # Add attention-weighted question representation 197 | if self.args.use_qemb and self.question_hidden is None: 198 | x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask) 199 | drnn_input.append(x2_weighted_emb) 200 | 201 | # Add manual features 202 | if self.args.num_features > 0: 203 | drnn_input.append(x1_f) 204 | 205 | # Encode document with RNN 206 | self.doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask) 207 | 208 | 209 | # Encode question with RNN + merge hiddens 210 | if self.question_hidden is None: 211 | question_hiddens = self.question_rnn(x2_emb, x2_mask) 212 | if self.args.question_merge == 'avg': 213 | q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask) 214 | elif self.args.question_merge == 'self_attn': 215 | q_merge_weights = self.self_attn(question_hiddens, x2_mask) 216 | self.question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights) 217 | 218 | # reader_state = torch.max(doc_hiddens, dim=1)[0] 219 | reader_state_weights = F.softmax(self.reader_state_self_attn(self.doc_hiddens, self.question_hidden, x1_mask), dim=1) 220 | reader_state = layers.weighted_avg(self.doc_hiddens, reader_state_weights) 221 | 222 | # Predict start and end 
positions 223 | start_scores = self.start_attn(self.doc_hiddens, self.question_hidden, x1_mask) 224 | end_scores = self.end_attn(self.doc_hiddens, self.question_hidden, x1_mask) 225 | return start_scores, end_scores, reader_state 226 | 227 | 228 | def reset(self): 229 | self.question_hidden = None 230 | self.doc_hiddens = None 231 | 232 | def get_current_reader_query(self): 233 | return self.question_hidden 234 | 235 | def set_current_reader_query(self, query): 236 | self.question_hidden = query -------------------------------------------------------------------------------- /paragraph_encoder/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | from smart_open import smart_open 5 | import subprocess 6 | import logging 7 | import uuid 8 | import time 9 | 10 | 11 | logger = logging.getLogger() 12 | 13 | USER = os.getenv('USER') 14 | 15 | SRC_DIR = "." 16 | 17 | def str2bool(v): 18 | return v.lower() in ('yes', 'true', 't', '1', 'y') 19 | 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.register('type', 'bool', str2bool) 24 | 25 | # Runtime environment 26 | parser.add_argument('--no_cuda', type='bool', default=False) 27 | parser.add_argument('--test_time_cuda', type='bool', default=True) 28 | parser.add_argument('--gpu', type=int, default=-1) 29 | parser.add_argument('--data_workers', type=int, default=5) 30 | # Basics 31 | parser.add_argument('--model_dir', type=str) 32 | parser.add_argument('--model_name', type=str, default=None) 33 | parser.add_argument('--checkpoint', type='bool', default=True) 34 | parser.add_argument('--random_seed', type=int, default=42) 35 | parser.add_argument('--uncased_question', type='bool', default=False) 36 | parser.add_argument('--uncased_doc', type='bool', default=False) 37 | parser.add_argument('--use_only_distant_sup', type=int, default=1, 38 | help="For hotpotQA use only string matching as supervision or the relevant paragraph supervision given by them") 39 | 40 | # vocab dir 41 | parser.add_argument('--create_vocab', type=int, default=0, 42 | help='Create vocab files and write them even if they exist') 43 | 44 | # Data files 45 | parser.add_argument('--src', type=str, default='triviaqa', help='triviaqa or squad or qangaroo') 46 | 47 | parser.add_argument('--domain', type=str, default='web-open', help='domain web/wiki/web-open') 48 | 49 | parser.add_argument('--min_para_len', type=int, default=0, 50 | help='Minimum length of a paragraph. Triviaqa has a lot of small paragraphs.') 51 | parser.add_argument('--max_para_len', type=int, default=400, 52 | help='Maximum length of a paragraph. Ignore really long paragraphs too.') 53 | parser.add_argument('--max_train_questions', type=int, default=600000, 54 | help='Maximum number of questions to train on. 
TriviaQA is huge.') 55 | parser.add_argument('--num_train_in_memory', type=int, default=50000, 56 | help='Maximum number of questions to keep in memory at once.') 57 | parser.add_argument('--data_dir', type=str) 58 | parser.add_argument('--embed_dir', type=str) 59 | parser.add_argument('--word_embeddings_file', type=str, default='data/embeddings/fasttext') 60 | parser.add_argument('--eval_file', type=str, default='web-dev.json') 61 | parser.add_argument('--verified_eval_file', type=str, default='verified-web-dev.json') 62 | parser.add_argument('--eval_only', type=int, default=0, help='Load a saved model and evaluate on dev set') 63 | parser.add_argument('--eval_correct_paras', type=int, default=0, 64 | help='eval by sending only the correct paras of the doc') 65 | parser.add_argument('--train_correct_paras', type=int, default=0, 66 | help='train a model on the correct paras of the doc') 67 | parser.add_argument('--eval_verified', type=int, default=1, help='eval the verified dev set after each partition.') 68 | 69 | # '--train_file', type = str, default = 'train.txt' 70 | parser.add_argument('--train_file_name', type=str, default='processed_train') 71 | parser.add_argument('--dev_file_name', type=str, default='processed_dev') 72 | parser.add_argument('--test_file_name', type=str, default='processed_test') 73 | 74 | parser.add_argument('--embedding_file', type=str, default='crawl-300d-2M.txt') 75 | parser.add_argument('--embedding_table', type=str, default='embedding_table.mdl') 76 | 77 | parser.add_argument('--para_mode', type=int, default=1, 78 | help='represent a doc as a list of paras instead of a huge list of words') 79 | parser.add_argument('--small', type=int, default=0, 80 | help='small dataset') 81 | 82 | parser.add_argument('--eval_only_para_clf', type=int, default=0, help='Load a saved model and evaluate on dev set') 83 | parser.add_argument('--pretrained_words', type=int, default=1) 84 | parser.add_argument('--batch_size', type=int, default=128) 85 | parser.add_argument('--neg_sample', type=float, default=1.0) 86 | parser.add_argument('--test', type=int, default=0) 87 | 88 | parser.add_argument('--test_batch_size', type=int, default=128) 89 | parser.add_argument('--use_tfidf_retriever', type=int, default=0, help='An additional tf-idf retriever to weed out paras') 90 | parser.add_argument('--num_topk_paras', type=int, default=5, help='Number of paras to choose from') 91 | parser.add_argument('--save_para_clf_output', type=int, default=0, 92 | help='Save the top-k para returned by the para classifier') 93 | parser.add_argument('--save_para_clf_output_dir', type=str, default=None, 94 | help='Path where to save') 95 | 96 | 97 | parser.add_argument('--pretrained', type=str, default=None, help='Pre-trained model') 98 | 99 | parser.add_argument('--use_qemb', 100 | type='bool', 101 | default=True, 102 | help='Whether to use weighted question embeddings' 103 | ) 104 | parser.add_argument( 105 | '--use_in_question', 106 | type='bool', 107 | default=True, 108 | help='Whether to use in_question features' 109 | ) 110 | parser.add_argument( 111 | '--use_pos', 112 | type='bool', 113 | default=False, 114 | help='Whether to use pos features' 115 | ) 116 | parser.add_argument( 117 | '--use_ner', 118 | type='bool', 119 | default=False, 120 | help='Whether to use ner features' 121 | ) 122 | parser.add_argument( 123 | '--use_lemma', 124 | type='bool', 125 | default=False, 126 | help='Whether to use lemma features' 127 | ) 128 | parser.add_argument( 129 | '--use_tf', 130 | type='bool', 131 | 
default=False, 132 | help='Whether to use tf features' 133 | ) 134 | parser.add_argument( 135 | '--unlabeled', 136 | type='bool', 137 | default=False, 138 | help='Data is unlabeled (prediction only)' 139 | ) 140 | parser.add_argument( 141 | '--use_distant_supervision', 142 | type='bool', 143 | default=True, 144 | help='Whether to gather labels by distant supervision' 145 | ) 146 | parser.add_argument( 147 | '--use_single_answer_alias', 148 | type='bool', 149 | default=False, 150 | help='Whether to use one alias of the answer i.e. just use "Obama" or all aliases for "Obama"' 151 | ) 152 | parser.add_argument( 153 | '--fix_embeddings', 154 | type='bool', 155 | default=True, 156 | help='Keep word embeddings fixed (pretrained)' 157 | ) 158 | parser.add_argument( 159 | '--paraclf_hidden_size', type=int, default=300, help='Hidden size of paragraph classifier', 160 | ) 161 | parser.add_argument( 162 | '--num_epochs', 163 | type=int, 164 | default=100, 165 | help='Number of epochs (default 40)' 166 | ) 167 | parser.add_argument( 168 | '--display_iter', 169 | type=int, 170 | default=25, 171 | help='Print train error after every \ 172 | epoches (default 25)' 173 | ) 174 | parser.add_argument( 175 | '--dropout_emb', 176 | type=float, 177 | default=0.1, 178 | help='Dropout rate for word embeddings' 179 | ) 180 | 181 | parser.add_argument( 182 | '--optimizer', 183 | type=str, 184 | default='adamax', 185 | help='Optimizer: sgd or adamax (default)' 186 | ) 187 | parser.add_argument( 188 | '--learning_rate', 189 | '-lr', 190 | type=float, 191 | default=0.1, 192 | help='Learning rate for SGD (default 0.1)' 193 | ) 194 | parser.add_argument( 195 | '--grad_clipping', 196 | type=float, 197 | default=10, 198 | help='Gradient clipping (default 10.0)' 199 | ) 200 | parser.add_argument( 201 | '--use_annealing_schedule', 202 | type='bool', 203 | default=True, 204 | help='Whether to use an annealing schedule or not.' 
205 | ) 206 | parser.add_argument( 207 | '--weight_decay', 208 | type=float, 209 | default=0, 210 | help='Weight decay (default 0)' 211 | ) 212 | parser.add_argument( 213 | '--momentum', type=float, default=0, help='Momentum (default 0)' 214 | ) 215 | args = parser.parse_args() 216 | 217 | if len(args.embedding_file) == 0: 218 | args.embedding_file = None 219 | else: 220 | args.embedding_file = os.path.join(args.word_embeddings_file, args.embedding_file) 221 | if not os.path.isfile(args.embedding_file): 222 | raise IOError('No such file: %s' % args.embedding_file) 223 | args.embedding_table = os.path.join(args.data_dir, args.src, "embeddings", args.domain, args.embedding_table) 224 | args.final_model_dir = time.strftime("%Y%m%d-") + str(uuid.uuid4())[:8] 225 | args.model_dir = os.path.join(args.model_dir, args.final_model_dir) 226 | args.log_file = os.path.join(args.model_dir, 'log.txt') 227 | args.model_file = os.path.join(args.model_dir, 'model.mdl') 228 | args.para_mode = (args.para_mode == 1) 229 | args.eval_only_para_clf = (args.eval_only_para_clf == 1) 230 | args.eval_verified = (args.eval_verified == 1) 231 | args.use_tfidf_retriever = (args.use_tfidf_retriever == 1) 232 | args.pretrained_words = (args.pretrained_words == 1) 233 | args.use_only_distant_sup = (args.use_only_distant_sup == 1) 234 | 235 | args.vocab_dir = os.path.join(args.data_dir , args.src, 'vocab', args.domain +"/") 236 | if not os.path.exists(args.vocab_dir+'tok2ind.json'): 237 | args.create_vocab = True 238 | 239 | subprocess.call(['mkdir', '-p', args.model_dir]) 240 | 241 | # Embeddings options 242 | if args.embedding_file is not None: 243 | with smart_open(args.embedding_file) as f: 244 | dim = len(f.readline().decode('utf-8').strip().split(' ')) - 1 245 | args.embedding_dim = dim 246 | elif args.embedding_dim is None: 247 | raise RuntimeError( 248 | 'Either embedding_file or embedding_dim ' 249 | 'needs to be specified.' 250 | ) 251 | 252 | return args -------------------------------------------------------------------------------- /msr/reader/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """DrQA reader utilities.""" 8 | 9 | import json 10 | import time 11 | import logging 12 | import string 13 | import regex as re 14 | import torch 15 | import os 16 | from tqdm import tqdm 17 | 18 | from collections import Counter 19 | from .data import Dictionary 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | # ------------------------------------------------------------------------------ 25 | # Data loading 26 | # ------------------------------------------------------------------------------ 27 | 28 | 29 | def load_data(args, filename, skip_no_answer=False): 30 | """Load examples from preprocessed file. 31 | One example per line, JSON encoded. 32 | """ 33 | # Load JSON lines 34 | with open(filename) as f: 35 | examples = [json.loads(line) for line in f] 36 | 37 | # Make case insensitive? 
38 | if args.uncased_question or args.uncased_doc: 39 | for ex in examples: 40 | if args.uncased_question: 41 | ex['question'] = [w.lower() for w in ex['question']] 42 | if args.uncased_doc: 43 | ex['document'] = [w.lower() for w in ex['document']] 44 | 45 | # Skip unparsed (start/end) examples 46 | if skip_no_answer: 47 | examples = [ex for ex in examples if len(ex['answers']) > 0] 48 | 49 | return examples 50 | 51 | 52 | def load_text(filename): 53 | """Load the paragraphs only of a SQuAD dataset. Store as qid -> text.""" 54 | # Load JSON file 55 | with open(filename) as f: 56 | examples = json.load(f)['data'] 57 | 58 | texts = {} 59 | for article in examples: 60 | for paragraph in article['paragraphs']: 61 | for qa in paragraph['qas']: 62 | texts[qa['id']] = paragraph['context'] 63 | return texts 64 | 65 | 66 | def load_answers(filename): 67 | """Load the answers only of a SQuAD dataset. Store as qid -> [answers].""" 68 | # Load JSON file 69 | with open(filename) as f: 70 | examples = json.load(f)['data'] 71 | 72 | ans = {} 73 | for article in examples: 74 | for paragraph in article['paragraphs']: 75 | for qa in paragraph['qas']: 76 | ans[qa['id']] = list(map(lambda x: x['text'], qa['answers'])) 77 | return ans 78 | 79 | 80 | # ------------------------------------------------------------------------------ 81 | # Dictionary building 82 | # ------------------------------------------------------------------------------ 83 | 84 | 85 | def index_embedding_words(embedding_file): 86 | """Put all the words in embedding_file into a set.""" 87 | words = set() 88 | with open(embedding_file) as f: 89 | for line in f: 90 | w = Dictionary.normalize(line.rstrip().split(' ')[0]) 91 | words.add(w) 92 | return words 93 | 94 | 95 | def load_words(args, examples): 96 | """Iterate and index all the words in examples (documents + questions).""" 97 | def _insert(iterable): 98 | for w in iterable: 99 | w = Dictionary.normalize(w) 100 | if valid_words and w not in valid_words: 101 | continue 102 | words.add(w) 103 | 104 | if args.restrict_vocab and args.embedding_file: 105 | logger.info('Restricting to words in %s' % args.embedding_file) 106 | valid_words = index_embedding_words(args.embedding_file) 107 | logger.info('Num words in set = %d' % len(valid_words)) 108 | else: 109 | valid_words = None 110 | 111 | words = set() 112 | qids = list(examples.questions.keys()) 113 | for qid in tqdm(qids): 114 | question = examples.questions[qid] 115 | question_text = question.text 116 | context_text = [] 117 | pids_for_question = question.pids 118 | paras = [examples.paragraphs[pid] for pid in pids_for_question] 119 | for p in paras: 120 | context_text += p.text 121 | _insert(question_text) 122 | _insert(context_text) 123 | return words 124 | 125 | 126 | def build_word_dict(args, train_exs, dev_exs): 127 | """Return a dictionary from question and document words in 128 | provided examples. 
129 | """
130 | word_dict = Dictionary(args)
131 | if not args.create_vocab:
132 | logger.info('[ Reading vocab files from {} ]'.format(args.vocab_dir))
133 | word_dict.tok2ind = json.load(open(os.path.join(args.vocab_dir, 'tok2ind.json')))
134 | word_dict.ind2tok = json.load(open(os.path.join(args.vocab_dir, 'ind2tok.json')))
135 | return word_dict  # return the cached one
136 | 
137 | for w in load_words(args, train_exs):
138 | word_dict.add(w)
139 | for w in load_words(args, dev_exs):
140 | word_dict.add(w)
141 | 
142 | # save so we don't have to build it from scratch again
143 | word_dict.save()
144 | return word_dict
145 | 
146 | 
147 | def top_question_words(args, examples, word_dict):
148 | """Count and return the most common question words in provided examples."""
149 | word_count = Counter()
150 | for ex in examples:
151 | for w in ex['question']:
152 | w = Dictionary.normalize(w)
153 | if w in word_dict:
154 | word_count.update([w])
155 | return word_count.most_common(args.tune_partial)
156 | 
157 | 
158 | def build_feature_dict(args, examples):
159 | """Index features (one hot) from fields in examples and options."""
160 | if not args.create_vocab:
161 | return json.load(open(os.path.join(args.vocab_dir, 'feat_dict.json')))
162 | 
163 | def _insert(feature):
164 | if feature not in feature_dict:
165 | feature_dict[feature] = len(feature_dict)
166 | 
167 | feature_dict = {}
168 | 
169 | # Exact match features
170 | if args.use_in_question:
171 | _insert('in_question')
172 | _insert('in_question_uncased')
173 | with open(os.path.join(args.vocab_dir, 'feat_dict.json'), "w") as feat_dict_fp:
174 | json.dump(feature_dict, feat_dict_fp)
175 | return feature_dict
176 | 
177 | 
178 | # ------------------------------------------------------------------------------
179 | # Evaluation. Follows the official evaluation script for v1.1 of the SQuAD dataset.
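#
# NOTE: illustrative worked example (not part of the original repo) for the
# helpers below: normalize_answer("The Eiffel Tower!") -> "eiffel tower"; for
# prediction "Obama" against ground truth "Barack Obama", token precision = 1/1
# and recall = 1/2, so f1_score(...) = 2 * (1.0 * 0.5) / (1.0 + 0.5) ~= 0.667,
# while exact_match_score(...) is False.
#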
180 | # ------------------------------------------------------------------------------ 181 | 182 | 183 | def normalize_answer(s): 184 | """Lower text and remove punctuation, articles and extra whitespace.""" 185 | def remove_articles(text): 186 | return re.sub(r'\b(a|an|the)\b', ' ', text) 187 | 188 | def white_space_fix(text): 189 | return ' '.join(text.split()) 190 | 191 | def remove_punc(text): 192 | exclude = set(string.punctuation) 193 | return ''.join(ch for ch in text if ch not in exclude) 194 | 195 | def lower(text): 196 | return text.lower() 197 | 198 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 199 | 200 | 201 | def f1_score(prediction, ground_truth): 202 | """Compute the harmonic mean of precision and recall for answer tokens.""" 203 | prediction_tokens = normalize_answer(prediction).split() 204 | ground_truth_tokens = normalize_answer(ground_truth).split() 205 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 206 | num_same = sum(common.values()) 207 | if num_same == 0: 208 | return 0 209 | precision = 1.0 * num_same / len(prediction_tokens) 210 | recall = 1.0 * num_same / len(ground_truth_tokens) 211 | f1 = (2 * precision * recall) / (precision + recall) 212 | return f1 213 | 214 | 215 | def exact_match_score(prediction, ground_truth): 216 | """Check if the prediction is a (soft) exact match with the ground truth.""" 217 | return normalize_answer(prediction) == normalize_answer(ground_truth) 218 | 219 | 220 | def regex_match_score(prediction, pattern): 221 | """Check if the prediction matches the given regular expression.""" 222 | try: 223 | compiled = re.compile( 224 | pattern, 225 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE 226 | ) 227 | except BaseException: 228 | logger.warning('Regular expression failed to compile: %s' % pattern) 229 | return False 230 | return compiled.match(prediction) is not None 231 | 232 | 233 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 234 | """Given a prediction and multiple valid answers, return the score of 235 | the best prediction-answer_n pair given a metric function. 236 | """ 237 | scores_for_ground_truths = [] 238 | for ground_truth in ground_truths: 239 | score = metric_fn(prediction, ground_truth) 240 | scores_for_ground_truths.append(score) 241 | return max(scores_for_ground_truths) 242 | 243 | 244 | def logsumexp(inputs, dim=None, keepdim=False): 245 | """Numerically stable logsumexp. 246 | 247 | Args: 248 | inputs: A Variable with any shape. 249 | dim: An integer. 250 | keepdim: A boolean. 251 | 252 | Returns: 253 | Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)). 254 | """ 255 | # For a 1-D array x (any array along a single dimension), 256 | # log sum exp(x) = s + log sum exp(x - s) 257 | # with s = max(x) being a common choice.
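    # Illustrative sanity check (added comment): for inputs = torch.tensor([1000., 1000.]), a naive log(exp(x).sum()) overflows to inf, whereas this routine returns 1000 + log(2) ~= 1000.6931, since exp(x - s) stays within [0, 1].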
258 | if dim is None: 259 | inputs = inputs.view(-1) 260 | dim = 0 261 | s, _ = torch.max(inputs, dim=dim, keepdim=True) 262 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() 263 | if not keepdim: 264 | outputs = outputs.squeeze(dim) 265 | return outputs 266 | 267 | 268 | # ------------------------------------------------------------------------------ 269 | # Utility classes 270 | # ------------------------------------------------------------------------------ 271 | 272 | 273 | class AverageMeter(object): 274 | """Computes and stores the average and current value.""" 275 | 276 | def __init__(self): 277 | self.reset() 278 | 279 | def reset(self): 280 | self.val = 0 281 | self.avg = 0 282 | self.sum = 0 283 | self.count = 0 284 | 285 | def update(self, val, n=1): 286 | self.val = val 287 | self.sum += val * n 288 | self.count += n 289 | self.avg = self.sum / self.count 290 | 291 | 292 | class Timer(object): 293 | """Computes elapsed time.""" 294 | 295 | def __init__(self): 296 | self.running = True 297 | self.total = 0 298 | self.start = time.time() 299 | 300 | def reset(self): 301 | self.running = True 302 | self.total = 0 303 | self.start = time.time() 304 | return self 305 | 306 | def resume(self): 307 | if not self.running: 308 | self.running = True 309 | self.start = time.time() 310 | return self 311 | 312 | def stop(self): 313 | if self.running: 314 | self.running = False 315 | self.total += time.time() - self.start 316 | return self 317 | 318 | def time(self): 319 | if self.running: 320 | return self.total + time.time() - self.start 321 | return self.total 322 | -------------------------------------------------------------------------------- /paragraph_encoder/model/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Definitions of model layers/NN modules""" 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | 14 | 15 | # ------------------------------------------------------------------------------ 16 | # Modules 17 | # ------------------------------------------------------------------------------ 18 | 19 | 20 | class StackedBRNN(nn.Module): 21 | """Stacked Bi-directional RNNs. 22 | Differs from standard PyTorch library in that it has the option to save 23 | and concat the hidden states between layers. (i.e. the output hidden size 24 | for each sequence input is num_layers * hidden_size). 25 | """ 26 | 27 | def __init__(self, input_size, hidden_size, num_layers, 28 | dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM, 29 | concat_layers=False, padding=False): 30 | super(StackedBRNN, self).__init__() 31 | self.padding = padding 32 | self.dropout_output = dropout_output 33 | self.dropout_rate = dropout_rate 34 | self.num_layers = num_layers 35 | self.concat_layers = concat_layers 36 | self.rnns = nn.ModuleList() 37 | for i in range(num_layers): 38 | input_size = input_size if i == 0 else 2 * hidden_size 39 | self.rnns.append(rnn_type(input_size, hidden_size, 40 | num_layers=1, 41 | bidirectional=True)) 42 | 43 | def forward(self, x, x_mask): 44 | """Encode either padded or non-padded sequences. 45 | Can choose to either handle or ignore variable length sequences. 46 | Always handle padding in eval. 
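(Padding is handled by routing through nn.utils.rnn.pack_padded_sequence in _forward_padded, which keeps the RNN from consuming pad positions.)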
47 | Args: 48 | x: batch * len * hdim 49 | x_mask: batch * len (1 for padding, 0 for true) 50 | Output: 51 | x_encoded: batch * len * hdim_encoded 52 | """ 53 | if x_mask.data.sum() == 0: 54 | # No padding necessary. 55 | output = self._forward_unpadded(x, x_mask) 56 | elif self.padding or not self.training: 57 | # Pad if we care or if its during eval. 58 | output = self._forward_padded(x, x_mask) 59 | else: 60 | # We don't care. 61 | output = self._forward_unpadded(x, x_mask) 62 | 63 | return output.contiguous() 64 | 65 | def _forward_unpadded(self, x, x_mask): 66 | """Faster encoding that ignores any padding.""" 67 | # Transpose batch and sequence dims 68 | x = x.transpose(0, 1) 69 | 70 | # Encode all layers 71 | outputs = [x] 72 | for i in range(self.num_layers): 73 | rnn_input = outputs[-1] 74 | 75 | # Apply dropout to hidden input 76 | if self.dropout_rate > 0: 77 | rnn_input = F.dropout(rnn_input, 78 | p=self.dropout_rate, 79 | training=self.training) 80 | # Forward 81 | rnn_output = self.rnns[i](rnn_input)[0] 82 | outputs.append(rnn_output) 83 | 84 | # Concat hidden layers 85 | if self.concat_layers: 86 | output = torch.cat(outputs[1:], 2) 87 | else: 88 | output = outputs[-1] 89 | 90 | # Transpose back 91 | output = output.transpose(0, 1) 92 | 93 | # Dropout on output layer 94 | if self.dropout_output and self.dropout_rate > 0: 95 | output = F.dropout(output, 96 | p=self.dropout_rate, 97 | training=self.training) 98 | return output 99 | 100 | def _forward_padded(self, x, x_mask): 101 | """Slower (significantly), but more precise, encoding that handles 102 | padding. 103 | """ 104 | # Compute sorted sequence lengths 105 | lengths = x_mask.data.eq(0).long().sum(1).squeeze() 106 | _, idx_sort = torch.sort(lengths, dim=0, descending=True) 107 | _, idx_unsort = torch.sort(idx_sort, dim=0) 108 | 109 | lengths = list(lengths[idx_sort]) 110 | idx_sort = Variable(idx_sort) 111 | idx_unsort = Variable(idx_unsort) 112 | 113 | # Sort x 114 | x = x.index_select(0, idx_sort) 115 | 116 | # Transpose batch and sequence dims 117 | x = x.transpose(0, 1) 118 | 119 | # Pack it up 120 | rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths) 121 | 122 | # Encode all layers 123 | outputs = [rnn_input] 124 | for i in range(self.num_layers): 125 | rnn_input = outputs[-1] 126 | 127 | # Apply dropout to input 128 | if self.dropout_rate > 0: 129 | dropout_input = F.dropout(rnn_input.data, 130 | p=self.dropout_rate, 131 | training=self.training) 132 | rnn_input = nn.utils.rnn.PackedSequence(dropout_input, 133 | rnn_input.batch_sizes) 134 | outputs.append(self.rnns[i](rnn_input)[0]) 135 | 136 | # Unpack everything 137 | for i, o in enumerate(outputs[1:], 1): 138 | outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0] 139 | 140 | # Concat hidden layers or take final 141 | if self.concat_layers: 142 | output = torch.cat(outputs[1:], 2) 143 | else: 144 | output = outputs[-1] 145 | 146 | # Transpose and unsort 147 | output = output.transpose(0, 1) 148 | output = output.index_select(0, idx_unsort) 149 | 150 | # Pad up to original batch sequence length 151 | if output.size(1) != x_mask.size(1): 152 | padding = torch.zeros(output.size(0), 153 | x_mask.size(1) - output.size(1), 154 | output.size(2)).type(output.data.type()) 155 | output = torch.cat([output, Variable(padding)], 1) 156 | 157 | # Dropout on output layer 158 | if self.dropout_output and self.dropout_rate > 0: 159 | output = F.dropout(output, 160 | p=self.dropout_rate, 161 | training=self.training) 162 | return output 163 | 164 | 165 | class 
SeqAttnMatch(nn.Module): 166 | """Given sequences X and Y, match sequence Y to each element in X. 167 | * o_i = sum(alpha_j * y_j) for i in X 168 | * alpha_j = softmax(y_j * x_i) 169 | """ 170 | 171 | def __init__(self, input_size, identity=False): 172 | super(SeqAttnMatch, self).__init__() 173 | if not identity: 174 | self.linear = nn.Linear(input_size, input_size) 175 | else: 176 | self.linear = None 177 | 178 | def forward(self, x, y, y_mask): 179 | """ 180 | Args: 181 | x: batch * len1 * hdim 182 | y: batch * len2 * hdim 183 | y_mask: batch * len2 (1 for padding, 0 for true) 184 | Output: 185 | matched_seq: batch * len1 * hdim 186 | """ 187 | # Project vectors 188 | if self.linear: 189 | x_proj = self.linear(x.view(-1, x.size(2))).view(x.size()) 190 | x_proj = F.relu(x_proj) 191 | y_proj = self.linear(y.view(-1, y.size(2))).view(y.size()) 192 | y_proj = F.relu(y_proj) 193 | else: 194 | x_proj = x 195 | y_proj = y 196 | 197 | # Compute scores 198 | scores = x_proj.bmm(y_proj.transpose(2, 1)) 199 | 200 | # Mask padding 201 | y_mask = y_mask.unsqueeze(1).expand(scores.size()) 202 | scores.data.masked_fill_(y_mask.data, -float('inf')) 203 | 204 | # Normalize with softmax 205 | alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=-1) 206 | alpha = alpha_flat.view(-1, x.size(1), y.size(1)) 207 | 208 | # Take weighted average 209 | matched_seq = alpha.bmm(y) 210 | return matched_seq 211 | 212 | 213 | class BilinearSeqAttn(nn.Module): 214 | """A bilinear attention layer over a sequence X w.r.t y: 215 | * o_i = softmax(x_i'Wy) for x_i in X. 216 | Optionally don't normalize output weights. 217 | """ 218 | 219 | def __init__(self, x_size, y_size, identity=False, normalize=True): 220 | super(BilinearSeqAttn, self).__init__() 221 | self.normalize = normalize 222 | 223 | # If identity is true, we just use a dot product without transformation. 224 | if not identity: 225 | self.linear = nn.Linear(y_size, x_size) 226 | else: 227 | self.linear = None 228 | 229 | def forward(self, x, y, x_mask): 230 | """ 231 | Args: 232 | x: batch * len * hdim1 233 | y: batch * hdim2 234 | x_mask: batch * len (1 for padding, 0 for true) 235 | Output: 236 | alpha = batch * len 237 | """ 238 | Wy = self.linear(y) if self.linear is not None else y 239 | xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2) 240 | xWy.data.masked_fill_(x_mask.data, -float('inf')) 241 | if self.normalize: 242 | if self.training: 243 | # In training we output log-softmax for NLL 244 | alpha = F.log_softmax(xWy, dim=-1) 245 | else: 246 | # ...Otherwise 0-1 probabilities 247 | alpha = F.softmax(xWy, dim=-1) 248 | else: 249 | alpha = xWy.exp() 250 | return alpha 251 | 252 | 253 | class LinearSeqAttn(nn.Module): 254 | """Self attention over a sequence: 255 | * o_i = softmax(Wx_i) for x_i in X. 
256 | """ 257 | 258 | def __init__(self, input_size): 259 | super(LinearSeqAttn, self).__init__() 260 | self.linear = nn.Linear(input_size, 1) 261 | 262 | def forward(self, x, x_mask): 263 | """ 264 | Args: 265 | x: batch * len * hdim 266 | x_mask: batch * len (1 for padding, 0 for true) 267 | Output: 268 | alpha: batch * len 269 | """ 270 | x_flat = x.view(-1, x.size(-1)) 271 | scores = self.linear(x_flat).view(x.size(0), x.size(1)) 272 | scores.data.masked_fill_(x_mask.data, -float('inf')) 273 | alpha = F.softmax(scores, dim=-1) 274 | return alpha 275 | 276 | 277 | # ------------------------------------------------------------------------------ 278 | # Functional 279 | # ------------------------------------------------------------------------------ 280 | 281 | 282 | def uniform_weights(x, x_mask): 283 | """Return uniform weights over non-masked x (a sequence of vectors). 284 | Args: 285 | x: batch * len * hdim 286 | x_mask: batch * len (1 for padding, 0 for true) 287 | Output: 288 | x_avg: batch * hdim 289 | """ 290 | alpha = Variable(torch.ones(x.size(0), x.size(1))) 291 | if x.data.is_cuda: 292 | alpha = alpha.cuda() 293 | alpha = alpha * x_mask.eq(0).float() 294 | alpha = alpha / alpha.sum(1).expand(alpha.size()) 295 | return alpha 296 | 297 | 298 | def weighted_avg(x, weights): 299 | """Return a weighted average of x (a sequence of vectors). 300 | Args: 301 | x: batch * len * hdim 302 | weights: batch * len, sum(dim = 1) = 1 303 | Output: 304 | x_avg: batch * hdim 305 | """ 306 | return weights.unsqueeze(1).bmm(x).squeeze(1) -------------------------------------------------------------------------------- /msr/reader/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Definitions of model layers/NN modules""" 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | 14 | 15 | # ------------------------------------------------------------------------------ 16 | # Modules 17 | # ------------------------------------------------------------------------------ 18 | 19 | 20 | class StackedBRNN(nn.Module): 21 | """Stacked Bi-directional RNNs. 22 | 23 | Differs from standard PyTorch library in that it has the option to save 24 | and concat the hidden states between layers. (i.e. the output hidden size 25 | for each sequence input is num_layers * hidden_size). 26 | """ 27 | 28 | def __init__(self, input_size, hidden_size, num_layers, 29 | dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM, 30 | concat_layers=False, padding=False): 31 | super(StackedBRNN, self).__init__() 32 | self.padding = padding 33 | self.dropout_output = dropout_output 34 | self.dropout_rate = dropout_rate 35 | self.num_layers = num_layers 36 | self.concat_layers = concat_layers 37 | self.rnns = nn.ModuleList() 38 | for i in range(num_layers): 39 | input_size = input_size if i == 0 else 2 * hidden_size 40 | self.rnns.append(rnn_type(input_size, hidden_size, 41 | num_layers=1, 42 | bidirectional=True)) 43 | 44 | def forward(self, x, x_mask): 45 | """Encode either padded or non-padded sequences. 46 | 47 | Can choose to either handle or ignore variable length sequences. 48 | Always handle padding in eval. 
49 | 50 | Args: 51 | x: batch * len * hdim 52 | x_mask: batch * len (1 for padding, 0 for true) 53 | Output: 54 | x_encoded: batch * len * hdim_encoded 55 | """ 56 | if x_mask.data.sum() == 0: 57 | # No padding necessary. 58 | output = self._forward_unpadded(x, x_mask) 59 | elif self.padding or not self.training: 60 | # Pad if we care or if its during eval. 61 | output = self._forward_padded(x, x_mask) 62 | else: 63 | # We don't care. 64 | output = self._forward_unpadded(x, x_mask) 65 | 66 | return output.contiguous() 67 | 68 | def _forward_unpadded(self, x, x_mask): 69 | """Faster encoding that ignores any padding.""" 70 | # Transpose batch and sequence dims 71 | x = x.transpose(0, 1) 72 | 73 | # Encode all layers 74 | outputs = [x] 75 | for i in range(self.num_layers): 76 | rnn_input = outputs[-1] 77 | 78 | # Apply dropout to hidden input 79 | if self.dropout_rate > 0: 80 | rnn_input = F.dropout(rnn_input, 81 | p=self.dropout_rate, 82 | training=self.training) 83 | # Forward 84 | rnn_output = self.rnns[i](rnn_input)[0] 85 | outputs.append(rnn_output) 86 | 87 | # Concat hidden layers 88 | if self.concat_layers: 89 | output = torch.cat(outputs[1:], 2) 90 | else: 91 | output = outputs[-1] 92 | 93 | # Transpose back 94 | output = output.transpose(0, 1) 95 | 96 | # Dropout on output layer 97 | if self.dropout_output and self.dropout_rate > 0: 98 | output = F.dropout(output, 99 | p=self.dropout_rate, 100 | training=self.training) 101 | return output 102 | 103 | def _forward_padded(self, x, x_mask): 104 | """Slower (significantly), but more precise, encoding that handles 105 | padding. 106 | """ 107 | # Compute sorted sequence lengths 108 | lengths = x_mask.data.eq(0).long().sum(1).squeeze() 109 | _, idx_sort = torch.sort(lengths, dim=0, descending=True) 110 | _, idx_unsort = torch.sort(idx_sort, dim=0) 111 | 112 | lengths = list(lengths[idx_sort]) 113 | idx_sort = Variable(idx_sort) 114 | idx_unsort = Variable(idx_unsort) 115 | 116 | # Sort x 117 | x = x.index_select(0, idx_sort) 118 | 119 | # Transpose batch and sequence dims 120 | x = x.transpose(0, 1) 121 | 122 | # Pack it up 123 | rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths) 124 | 125 | # Encode all layers 126 | outputs = [rnn_input] 127 | for i in range(self.num_layers): 128 | rnn_input = outputs[-1] 129 | 130 | # Apply dropout to input 131 | if self.dropout_rate > 0: 132 | dropout_input = F.dropout(rnn_input.data, 133 | p=self.dropout_rate, 134 | training=self.training) 135 | rnn_input = nn.utils.rnn.PackedSequence(dropout_input, 136 | rnn_input.batch_sizes) 137 | outputs.append(self.rnns[i](rnn_input)[0]) 138 | 139 | # Unpack everything 140 | for i, o in enumerate(outputs[1:], 1): 141 | outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0] 142 | 143 | # Concat hidden layers or take final 144 | if self.concat_layers: 145 | output = torch.cat(outputs[1:], 2) 146 | else: 147 | output = outputs[-1] 148 | 149 | # Transpose and unsort 150 | output = output.transpose(0, 1) 151 | output = output.index_select(0, idx_unsort) 152 | 153 | # Pad up to original batch sequence length 154 | if output.size(1) != x_mask.size(1): 155 | padding = torch.zeros(output.size(0), 156 | x_mask.size(1) - output.size(1), 157 | output.size(2)).type(output.data.type()) 158 | output = torch.cat([output, Variable(padding)], 1) 159 | 160 | # Dropout on output layer 161 | if self.dropout_output and self.dropout_rate > 0: 162 | output = F.dropout(output, 163 | p=self.dropout_rate, 164 | training=self.training) 165 | return output 166 | 167 | 168 | class 
SeqAttnMatch(nn.Module): 169 | """Given sequences X and Y, match sequence Y to each element in X. 170 | 171 | * o_i = sum(alpha_j * y_j) for i in X 172 | * alpha_j = softmax(y_j * x_i) 173 | """ 174 | 175 | def __init__(self, input_size, identity=False): 176 | super(SeqAttnMatch, self).__init__() 177 | if not identity: 178 | self.linear = nn.Linear(input_size, input_size) 179 | else: 180 | self.linear = None 181 | 182 | def forward(self, x, y, y_mask): 183 | """ 184 | Args: 185 | x: batch * len1 * hdim 186 | y: batch * len2 * hdim 187 | y_mask: batch * len2 (1 for padding, 0 for true) 188 | Output: 189 | matched_seq: batch * len1 * hdim 190 | """ 191 | # Project vectors 192 | if self.linear: 193 | x_proj = self.linear(x.view(-1, x.size(2))).view(x.size()) 194 | x_proj = F.relu(x_proj) 195 | y_proj = self.linear(y.view(-1, y.size(2))).view(y.size()) 196 | y_proj = F.relu(y_proj) 197 | else: 198 | x_proj = x 199 | y_proj = y 200 | 201 | # Compute scores 202 | scores = x_proj.bmm(y_proj.transpose(2, 1)) 203 | 204 | # Mask padding 205 | y_mask = y_mask.unsqueeze(1).expand(scores.size()) 206 | scores.data.masked_fill_(y_mask.data, -float('inf')) 207 | 208 | # Normalize with softmax 209 | alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=-1) 210 | alpha = alpha_flat.view(-1, x.size(1), y.size(1)) 211 | 212 | # Take weighted average 213 | matched_seq = alpha.bmm(y) 214 | return matched_seq 215 | 216 | 217 | class BilinearSeqAttn(nn.Module): 218 | """A bilinear attention layer over a sequence X w.r.t y: 219 | 220 | * o_i = softmax(x_i'Wy) for x_i in X. 221 | 222 | Optionally don't normalize output weights. 223 | """ 224 | 225 | def __init__(self, x_size, y_size, identity=False, normalize=True): 226 | super(BilinearSeqAttn, self).__init__() 227 | self.normalize = normalize 228 | 229 | # If identity is true, we just use a dot product without transformation. 230 | if not identity: 231 | self.linear = nn.Linear(y_size, x_size) 232 | else: 233 | self.linear = None 234 | 235 | def forward(self, x, y, x_mask): 236 | """ 237 | Args: 238 | x: batch * len * hdim1 239 | y: batch * hdim2 240 | x_mask: batch * len (1 for padding, 0 for true) 241 | Output: 242 | alpha = batch * len 243 | """ 244 | Wy = self.linear(y) if self.linear is not None else y 245 | xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2) 246 | xWy.data.masked_fill_(x_mask.data, -1e32) 247 | if self.normalize: 248 | if self.training: 249 | alpha = xWy # removed logsoftmax from here, we will do it outside (possibly on multiple paras) 250 | else: 251 | # ...Otherwise 0-1 probabilities 252 | # alpha = F.softmax(xWy) # dont do softmax 253 | alpha = xWy 254 | else: 255 | alpha = xWy.exp() 256 | return alpha 257 | 258 | 259 | class LinearSeqAttn(nn.Module): 260 | """Self attention over a sequence: 261 | 262 | * o_i = softmax(Wx_i) for x_i in X. 
263 | """ 264 | 265 | def __init__(self, input_size): 266 | super(LinearSeqAttn, self).__init__() 267 | self.linear = nn.Linear(input_size, 1) 268 | 269 | def forward(self, x, x_mask): 270 | """ 271 | Args: 272 | x: batch * len * hdim 273 | x_mask: batch * len (1 for padding, 0 for true) 274 | Output: 275 | alpha: batch * len 276 | """ 277 | x_flat = x.view(-1, x.size(-1)) 278 | scores = self.linear(x_flat).view(x.size(0), x.size(1)) 279 | scores.data.masked_fill_(x_mask.data, -float('inf')) 280 | alpha = F.softmax(scores, dim=-1)  # dim added explicitly; implicit-dim softmax is deprecated 281 | return alpha 282 | 283 | 284 | # ------------------------------------------------------------------------------ 285 | # Functional 286 | # ------------------------------------------------------------------------------ 287 | 288 | 289 | def uniform_weights(x, x_mask): 290 | """Return uniform weights over non-masked x (a sequence of vectors). 291 | 292 | Args: 293 | x: batch * len * hdim 294 | x_mask: batch * len (1 for padding, 0 for true) 295 | Output: 296 | x_avg: batch * hdim 297 | """ 298 | alpha = Variable(torch.ones(x.size(0), x.size(1))) 299 | if x.data.is_cuda: 300 | alpha = alpha.cuda() 301 | alpha = alpha * x_mask.eq(0).float() 302 | alpha = alpha / alpha.sum(1).expand(alpha.size()) 303 | return alpha 304 | 305 | 306 | def weighted_avg(x, weights): 307 | """Return a weighted average of x (a sequence of vectors). 308 | 309 | Args: 310 | x: batch * len * hdim 311 | weights: batch * len, sum(dim = 1) = 1 312 | Output: 313 | x_avg: batch * hdim 314 | """ 315 | return weights.unsqueeze(1).bmm(x).squeeze(1) 316 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below).
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /paragraph_encoder/model/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # Few methods have been adapted from https://github.com/facebookresearch/DrQA 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Data processing/loading helpers.""" 9 | 10 | 11 | import numpy as np 12 | import json 13 | import logging 14 | from smart_open import smart_open 15 | import unicodedata 16 | import heapq 17 | import os 18 | 19 | from torch.utils.data import Dataset 20 | from torch.utils.data.sampler import Sampler 21 | from .vector import vectorize 22 | 23 | logger = logging.getLogger() 24 | 25 | 26 | # ------------------------------------------------------------------------------ 27 | # Dictionary class for tokens. 28 | # ------------------------------------------------------------------------------ 29 | 30 | 31 | class Dictionary(object): 32 | NULL = '<NULL>' 33 | UNK = '<UNK>' 34 | 35 | @staticmethod 36 | def normalize(token): 37 | return unicodedata.normalize('NFD', token) 38 | 39 | def __init__(self, args): 40 | self.args = args 41 | if not args.create_vocab: 42 | logger.info('[ Reading vocab files from {}]'.format(args.vocab_dir)) 43 | self.tok2ind = json.load(open(os.path.join(args.vocab_dir, 'tok2ind.json'))) 44 | self.ind2tok = json.load(open(os.path.join(args.vocab_dir, 'ind2tok.json'))) 45 | 46 | else: 47 | self.tok2ind = {self.NULL: 0, self.UNK: 1} 48 | self.ind2tok = {0: self.NULL, 1: self.UNK} 49 | self.oov_words = {} 50 | 51 | # Index words in embedding file 52 | if args.pretrained_words and args.embedding_file: 53 | logger.info('[ Indexing words in embedding file... ]') 54 | self.valid_words = set() 55 | with smart_open(args.embedding_file) as f: 56 | for line in f: 57 | w = self.normalize(line.decode('utf-8').rstrip().split(' ')[0]) 58 | self.valid_words.add(w) 59 | logger.info('[ Num words in set = %d ]' % len(self.valid_words)) 60 | else: 61 | self.valid_words = None 62 | 63 | def __len__(self): 64 | return len(self.tok2ind) 65 | 66 | def __iter__(self): 67 | return iter(self.tok2ind) 68 | 69 | def __contains__(self, key): 70 | if type(key) == int: 71 | return key in self.ind2tok 72 | elif type(key) == str: 73 | return self.normalize(key) in self.tok2ind 74 | 75 | def __getitem__(self, key): 76 | if type(key) == int: 77 | return self.ind2tok.get(key, self.UNK) 78 | if type(key) == str: 79 | return self.tok2ind.get(self.normalize(key), 80 | self.tok2ind.get(self.UNK)) 81 | 82 | def add(self, token): 83 | token = self.normalize(token) 84 | if self.valid_words and token not in self.valid_words: 85 | # logger.info('{} not a valid word'.format(token)) 86 | if token not in self.oov_words: 87 | self.oov_words[token] = len(self.oov_words) 88 | return 89 | if token not in self.tok2ind: 90 | index = len(self.tok2ind) 91 | self.tok2ind[token] = index 92 | self.ind2tok[index] = token 93 | 94 | def swap_top(self, top_words): 95 | """ 96 | Reindexes the dictionary to have top_words labelled 2:N. 97 | (0, 1 are for <NULL>, <UNK>) 98 | """ 99 | for idx, w in enumerate(top_words, 2): 100 | if w in self.tok2ind: 101 | w_2, idx_2 = self.ind2tok[idx], self.tok2ind[w] 102 | self.tok2ind[w], self.ind2tok[idx] = idx, w 103 | self.tok2ind[w_2], self.ind2tok[idx_2] = idx_2, w_2 104 | 105 | def save(self): 106 | 107 | fout = open(os.path.join(self.args.vocab_dir, "ind2tok.json"), "w") 108 | json.dump(self.ind2tok, fout) 109 | fout.close() 110 | fout = open(os.path.join(self.args.vocab_dir, "tok2ind.json"), "w") 111 | json.dump(self.tok2ind, fout) 112 | fout.close() 113 | logger.info("Dictionary saved at {}".format(self.args.vocab_dir)) 114 | 115 | 116 | # ------------------------------------------------------------------------------ 117 | # PyTorch dataset class for SQuAD (and SQuAD-like) data.
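# (Usage sketch for the Dictionary above, added as a comment; illustrative, assuming an args namespace with create_vocab=True and pretrained_words=False: the dict starts as {'<NULL>': 0, '<UNK>': 1}, so word_dict.add('paris') gives len(word_dict) == 3, word_dict['paris'] == 2, and any unseen token falls back to index 1.)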
118 | # ------------------------------------------------------------------------------ 119 | 120 | 121 | class SquadDataset(Dataset): 122 | def __init__(self, args, examples, word_dict, 123 | feature_dict, single_answer=False, para_mode=False, train_time=True): 124 | self.examples = examples 125 | self.word_dict = word_dict 126 | self.feature_dict = feature_dict 127 | self.args = args 128 | self.single_answer = single_answer 129 | self.para_mode = para_mode 130 | self.train_time = train_time 131 | 132 | def __len__(self): 133 | return len(self.examples) 134 | 135 | def __getitem__(self, index): 136 | return vectorize(self.args, self.examples[index], self.word_dict, self.feature_dict, self.single_answer, 137 | self.para_mode, self.train_time) 138 | 139 | def lengths(self): 140 | if not self.para_mode: 141 | return [(len(ex['document']), len(ex['question'])) for ex in self.examples] 142 | else: 143 | q_key = 'question_str' if (self.args.src == 'triviaqa' or self.args.src == 'qangaroo') else 'question' 144 | return [(len(ex['document']), max([len(para) for para in ex['document']]), len(ex[q_key])) for ex in self.examples] 145 | 146 | class MultiCorpusDataset(Dataset): 147 | def __init__(self, args, corpus, word_dict, 148 | feature_dict, single_answer=False, para_mode=True, train_time=True): 149 | self.corpus = corpus 150 | self.word_dict = word_dict 151 | self.feature_dict = feature_dict 152 | self.args = args 153 | self.single_answer = single_answer 154 | self.para_mode = para_mode 155 | self.train_time = train_time 156 | self.pid_list = list(self.corpus.paragraphs.keys()) 157 | def __len__(self): 158 | if self.para_mode: 159 | return len(self.corpus.paragraphs) 160 | else: 161 | return len(self.corpus.questions) 162 | 163 | def __getitem__(self, index): 164 | if self.para_mode: 165 | ex = {} 166 | pid = self.pid_list[index] 167 | para = self.corpus.paragraphs[pid] 168 | assert pid == para.pid 169 | ex['document'] = para.text 170 | ex['id'] = para.pid 171 | ex['ans_occurance'] = para.ans_occurance 172 | qid = para.qid 173 | question = self.corpus.questions[qid] 174 | ex['question'] = question.text 175 | assert pid in question.pids 176 | 177 | return vectorize(self.args, ex) 178 | else: 179 | raise NotImplementedError("later") 180 | 181 | 182 | # ------------------------------------------------------------------------------ 183 | # PyTorch sampler returning batches sorted by length (doc and question).
184 | # ------------------------------------------------------------------------------ 185 | 186 | class SortedBatchSampler(Sampler): 187 | def __init__(self, lengths, batch_size, shuffle=True, para_mode=False): 188 | self.lengths = lengths 189 | self.batch_size = batch_size 190 | self.shuffle = shuffle 191 | self.para_mode = para_mode 192 | 193 | def __iter__(self): 194 | if not self.para_mode: 195 | lengths = np.array( 196 | [(-l[0], -l[1], np.random.random()) for l in self.lengths], 197 | dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)] 198 | ) 199 | else: 200 | lengths = np.array([(-l[0], -l[1], -l[2], np.random.random()) for l in self.lengths], dtype=[('l1', np.int_), ('l2', np.int_), ('l3', np.int_), ('rand', np.float_)]) 201 | indices = np.argsort(lengths, order=('l1', 'l2', 'rand')) if not self.para_mode else np.argsort(lengths, order=('l1', 'l2', 'l3', 'rand')) 202 | batches = [indices[i:i + self.batch_size] 203 | for i in range(0, len(indices), self.batch_size)] 204 | if self.shuffle: 205 | np.random.shuffle(batches) 206 | return iter([i for batch in batches for i in batch]) 207 | 208 | def __len__(self): 209 | return len(self.lengths) 210 | 211 | class CorrectParaSortedBatchSampler(Sampler): 212 | """ 213 | This awesome sampler was written by Peng Qi (http://qipeng.me/) 214 | """ 215 | def __init__(self, dataset, batch_size, shuffle=True, para_mode=True): 216 | self.dataset = dataset 217 | self.batch_size = batch_size 218 | self.shuffle = shuffle 219 | self.para_mode = para_mode 220 | 221 | def __iter__(self): 222 | import sys 223 | correct_paras = [(ex[5] > 0).long().sum() for ex in self.dataset] 224 | 225 | # make sure the number of correct paras in each minibatch is about the same 226 | mean = sum(correct_paras) / len(correct_paras) 227 | target = mean * self.batch_size 228 | 229 | # also make sure the number of total paras in each minibatch is about the same 230 | lengths = [x[0] for x in self.dataset.lengths()] 231 | target2 = sum(lengths) / len(lengths) * self.batch_size 232 | 233 | heuristic_weight = 0.1 # heuristic importance of making sum_para_len uniform compared to making sum_correct_paras uniform 234 | 235 | indices = [x[0] for x in sorted(enumerate(zip(correct_paras, lengths)), key=lambda x: x[1], reverse=True)] 236 | 237 | batches = [[] for _ in range((len(self.dataset) + self.batch_size - 1) // self.batch_size)] 238 | 239 | batches_by_size = {0: {0: [(i, 0, 0) for i in range(len(batches))] } } 240 | 241 | K = 100 # "beam" size 242 | 243 | for idx in indices: 244 | costs = [] 245 | for size in batches_by_size: 246 | cost_reduction = -(2 * size + correct_paras[idx] - 2 * target) * correct_paras[idx] 247 | 248 | costs += [(size, cost_reduction)] 249 | 250 | costs = heapq.nlargest(K, costs, key=lambda x: x[1]) 251 | 252 | best_cand = None 253 | for size, cost in costs: 254 | best_size2 = -1 255 | best_reduction = -float('inf') 256 | for size2 in batches_by_size[size]: 257 | cost_reduction = -(2 * size2 + lengths[idx] - 2 * target2) * lengths[idx] 258 | 259 | if cost_reduction > best_reduction: 260 | best_size2 = size2 261 | best_reduction = cost_reduction 262 | 263 | assert best_size2 >= 0 264 | 265 | cost_reduction_all = cost + best_reduction * heuristic_weight 266 | if best_cand is None or cost_reduction_all > best_cand[2]: 267 | best_cand = (size, best_size2, cost_reduction_all, cost, best_reduction) 268 | 269 | assert best_cand is not None 270 | 271 | best_size, best_size2 = best_cand[:2] 272 | 273 | assert len(batches_by_size[best_size]) > 0 274 | 
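        # Added note on the heuristic above: measuring a batch's badness as (sum - target)**2, adding an example of weight c to a batch with current sum s changes that badness by (s + c - target)**2 - (s - target)**2 = (2*s + c - 2*target)*c, so the negated expression is the cost reduction, and the greedy step picks the batch this example improves the most.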
275 | # all batches of the same size are created equal 276 | best_batch, batches_by_size[best_size][best_size2] = batches_by_size[best_size][best_size2][0], batches_by_size[best_size][best_size2][1:] 277 | 278 | if len(batches_by_size[best_size][best_size2]) == 0: 279 | del batches_by_size[best_size][best_size2] 280 | if len(batches_by_size[best_size]) == 0: 281 | del batches_by_size[best_size] 282 | 283 | batches[best_batch[0]] += [idx] 284 | newsize = best_batch[1] + correct_paras[idx] 285 | newsize2 = best_batch[2] + lengths[idx] 286 | 287 | if len(batches[best_batch[0]]) < self.batch_size: 288 | # insert back 289 | if newsize not in batches_by_size: 290 | batches_by_size[newsize] = {} 291 | 292 | if newsize2 not in batches_by_size[newsize]: 293 | batches_by_size[newsize][newsize2] = [(best_batch[0], newsize, newsize2)] 294 | else: 295 | batches_by_size[newsize][newsize2] += [(best_batch[0], newsize, newsize2)] 296 | 297 | if self.shuffle: 298 | np.random.shuffle(batches) 299 | 300 | return iter([x for batch in batches for x in batch]) 301 | 302 | def __len__(self): 303 | return len(self.dataset) -------------------------------------------------------------------------------- /paragraph_encoder/model/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # Few methods have been adapted from https://github.com/facebookresearch/DrQA 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import string  # was missing; needed by normalize_answer below 9 | import torch 10 | import ujson as json 11 | import time 12 | import os 13 | from tqdm import tqdm 14 | import logging 15 | import re 16 | from collections import defaultdict 17 | import pprint 18 | from smart_open import smart_open 19 | 20 | # ------------------------------------------------------------------------------ 21 | # Data loading 22 | # ------------------------------------------------------------------------------ 23 | 24 | logger = logging.getLogger() 25 | pp = pprint.PrettyPrinter(indent=4) 26 | doc_count_map = defaultdict(int) # map of count of answers to number of docs so k -> N means N docs have k occurrences of the answer 27 | para_count_map = defaultdict(int) # map of count of answers to number of para 28 | orig_para_count_map = defaultdict(int) # map of count of answers to orig_para 29 | span_len_map = defaultdict(int) # map of span_len to count 30 | 31 | 32 | from collections import Counter 33 | from .data import Dictionary 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | # ------------------------------------------------------------------------------ 39 | # Data loading 40 | # ------------------------------------------------------------------------------ 41 | 42 | 43 | def load_data(args, filename, skip_no_answer=False): 44 | """Load examples from preprocessed file. 45 | One example per line, JSON encoded. 46 | """ 47 | # Load JSON lines 48 | with open(filename) as f: 49 | examples = [json.loads(line) for line in f] 50 | 51 | # Make case insensitive?
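    # (Added note: lowering here merges case variants, e.g. 'The' and 'the', so they share a single vocabulary entry and embedding row.)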
52 | if args.uncased_question or args.uncased_doc: 53 | for ex in examples: 54 | if args.uncased_question: 55 | ex['question'] = [w.lower() for w in ex['question']] 56 | if args.uncased_doc: 57 | ex['document'] = [w.lower() for w in ex['document']] 58 | 59 | # Skip unparsed (start/end) examples 60 | if skip_no_answer: 61 | examples = [ex for ex in examples if len(ex['answers']) > 0] 62 | 63 | return examples 64 | 65 | 66 | def load_text(filename): 67 | """Load the paragraphs only of a SQuAD dataset. Store as qid -> text.""" 68 | # Load JSON file 69 | with open(filename) as f: 70 | examples = json.load(f)['data'] 71 | 72 | texts = {} 73 | for article in examples: 74 | for paragraph in article['paragraphs']: 75 | for qa in paragraph['qas']: 76 | texts[qa['id']] = paragraph['context'] 77 | return texts 78 | 79 | 80 | def load_answers(filename): 81 | """Load the answers only of a SQuAD dataset. Store as qid -> [answers].""" 82 | # Load JSON file 83 | with open(filename) as f: 84 | examples = json.load(f)['data'] 85 | 86 | ans = {} 87 | for article in examples: 88 | for paragraph in article['paragraphs']: 89 | for qa in paragraph['qas']: 90 | ans[qa['id']] = list(map(lambda x: x['text'], qa['answers'])) 91 | return ans 92 | 93 | 94 | # ------------------------------------------------------------------------------ 95 | # Dictionary building 96 | # ------------------------------------------------------------------------------ 97 | 98 | 99 | def index_embedding_words(embedding_file): 100 | """Put all the words in embedding_file into a set.""" 101 | words = set() 102 | with open(embedding_file) as f: 103 | for line in f: 104 | w = Dictionary.normalize(line.rstrip().split(' ')[0]) 105 | words.add(w) 106 | return words 107 | 108 | 109 | def load_words(args, examples): 110 | """Iterate and index all the words in examples (documents + questions).""" 111 | def _insert(iterable): 112 | for w in iterable: 113 | w = Dictionary.normalize(w) 114 | if valid_words and w not in valid_words: 115 | continue 116 | words.add(w) 117 | 118 | valid_words = None 119 | 120 | words = set() 121 | # add words in the paragraph 122 | for pid, p in examples.paragraphs.items(): 123 | _insert(p.text) 124 | # add words in the question 125 | for qid, q in examples.questions.items(): 126 | _insert(q.text.split(" ")) # the question text has been tokenized but then joined with " " 127 | 128 | return words 129 | 130 | 131 | def build_word_dict(args, examples): 132 | """Return a dictionary from question and document words in 133 | provided examples. 
134 | """ 135 | word_dict = Dictionary(args) 136 | if not args.create_vocab: 137 | return word_dict 138 | 139 | for w in load_words(args, examples): 140 | word_dict.add(w) 141 | # save so we don't have to make it from scratch again 142 | word_dict.save() 143 | 144 | return word_dict 145 | 146 | 147 | def top_question_words(args, examples, word_dict): 148 | """Count and return the most common question words in provided examples.""" 149 | word_count = Counter() 150 | for ex in examples: 151 | for w in ex['question']: 152 | w = Dictionary.normalize(w) 153 | if w in word_dict: 154 | word_count.update([w]) 155 | return word_count.most_common(args.tune_partial) 156 | 157 | 158 | def build_feature_dict(args, examples): 159 | """Index features (one hot) from fields in examples and options.""" 160 | # if not args.create_vocab: 161 | return json.load(open(os.path.join(args.vocab_dir, 'feat_dict.json')))  # NOTE: unconditional early return; the feature-building code below is currently unreachable 162 | 163 | def _insert(feature): 164 | if feature not in feature_dict: 165 | feature_dict[feature] = len(feature_dict) 166 | 167 | feature_dict = {} 168 | 169 | # Exact match features 170 | if args.use_in_question: 171 | _insert('in_question') 172 | _insert('in_question_uncased') 173 | if args.use_lemma: 174 | _insert('in_question_lemma') 175 | 176 | # Part of speech tag features 177 | if args.use_pos: 178 | for ex in examples: 179 | for w in ex['pos']: 180 | _insert('pos=%s' % w) 181 | 182 | # Named entity tag features 183 | if args.use_ner: 184 | for ex in examples: 185 | for w in ex['ner']: 186 | _insert('ner=%s' % w) 187 | 188 | # Term frequency feature 189 | if args.use_tf: 190 | _insert('tf') 191 | return feature_dict 192 | 193 | 194 | # ------------------------------------------------------------------------------ 195 | # Evaluation. Follows official evaluation script for v1.1 of the SQuAD dataset.
196 | # ------------------------------------------------------------------------------ 197 | 198 | 199 | def normalize_answer(s): 200 | """Lower text and remove punctuation, articles and extra whitespace.""" 201 | def remove_articles(text): 202 | return re.sub(r'\b(a|an|the)\b', ' ', text) 203 | 204 | def white_space_fix(text): 205 | return ' '.join(text.split()) 206 | 207 | def remove_punc(text): 208 | exclude = set(string.punctuation) 209 | return ''.join(ch for ch in text if ch not in exclude) 210 | 211 | def lower(text): 212 | return text.lower() 213 | 214 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 215 | 216 | 217 | def f1_score(prediction, ground_truth): 218 | """Compute the harmonic mean of precision and recall for answer tokens.""" 219 | prediction_tokens = normalize_answer(prediction).split() 220 | ground_truth_tokens = normalize_answer(ground_truth).split() 221 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 222 | num_same = sum(common.values()) 223 | if num_same == 0: 224 | return 0 225 | precision = 1.0 * num_same / len(prediction_tokens) 226 | recall = 1.0 * num_same / len(ground_truth_tokens) 227 | f1 = (2 * precision * recall) / (precision + recall) 228 | return f1 229 | 230 | 231 | def exact_match_score(prediction, ground_truth): 232 | """Check if the prediction is a (soft) exact match with the ground truth.""" 233 | return normalize_answer(prediction) == normalize_answer(ground_truth) 234 | 235 | 236 | def regex_match_score(prediction, pattern): 237 | """Check if the prediction matches the given regular expression.""" 238 | try: 239 | compiled = re.compile( 240 | pattern, 241 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE 242 | ) 243 | except BaseException: 244 | logger.warning('Regular expression failed to compile: %s' % pattern) 245 | return False 246 | return compiled.match(prediction) is not None 247 | 248 | 249 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 250 | """Given a prediction and multiple valid answers, return the score of 251 | the best prediction-answer_n pair given a metric function. 252 | """ 253 | scores_for_ground_truths = [] 254 | for ground_truth in ground_truths: 255 | score = metric_fn(prediction, ground_truth) 256 | scores_for_ground_truths.append(score) 257 | return max(scores_for_ground_truths) 258 | 259 | 260 | def logsumexp(inputs, dim=None, keepdim=False): 261 | """Numerically stable logsumexp. 262 | Args: 263 | inputs: A Variable with any shape. 264 | dim: An integer. 265 | keepdim: A boolean. 266 | Returns: 267 | Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)). 268 | """ 269 | # For a 1-D array x (any array along a single dimension), 270 | # log sum exp(x) = s + log sum exp(x - s) 271 | # with s = max(x) being a common choice.
272 | if dim is None: 273 | inputs = inputs.view(-1) 274 | dim = 0 275 | s, _ = torch.max(inputs, dim=dim, keepdim=True) 276 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() 277 | if not keepdim: 278 | outputs = outputs.squeeze(dim) 279 | return outputs 280 | 281 | 282 | # ------------------------------------------------------------------------------ 283 | # Utility classes 284 | # ------------------------------------------------------------------------------ 285 | 286 | 287 | class AverageMeter(object): 288 | """Computes and stores the average and current value.""" 289 | 290 | def __init__(self): 291 | self.reset() 292 | 293 | def reset(self): 294 | self.val = 0 295 | self.avg = 0 296 | self.sum = 0 297 | self.count = 0 298 | 299 | def update(self, val, n=1): 300 | self.val = val 301 | self.sum += val * n 302 | self.count += n 303 | self.avg = self.sum / self.count 304 | 305 | 306 | class Timer(object): 307 | """Computes elapsed time.""" 308 | 309 | def __init__(self): 310 | self.running = True 311 | self.total = 0 312 | self.start = time.time() 313 | 314 | def reset(self): 315 | self.running = True 316 | self.total = 0 317 | self.start = time.time() 318 | return self 319 | 320 | def resume(self): 321 | if not self.running: 322 | self.running = True 323 | self.start = time.time() 324 | return self 325 | 326 | def stop(self): 327 | if self.running: 328 | self.running = False 329 | self.total += time.time() - self.start 330 | return self 331 | 332 | def time(self): 333 | if self.running: 334 | return self.total + time.time() - self.start 335 | return self.total 336 | 337 | def load_embeddings(args, word_dict): 338 | 339 | embeddings = torch.Tensor(len(word_dict), args.embedding_dim_orig) 340 | if not os.path.isfile(args.embedding_table): 341 | logger.info("Initializing embedding table randomly...") 342 | embeddings.normal_(0, 1) 343 | embeddings[0].fill_(0) 344 | 345 | # Fill in embeddings 346 | with smart_open(args.embedding_file) as f: 347 | for line in f: 348 | line = line.decode('utf-8') 349 | parsed = line.rstrip().split(' ') 350 | assert (len(parsed) == args.embedding_dim_orig + 1) 351 | w = word_dict.normalize(parsed[0]) 352 | if w in word_dict: 353 | vec = torch.Tensor([float(i) for i in parsed[1:]]) 354 | embeddings[word_dict[w]].copy_(vec) 355 | # save the embedding table 356 | logger.info('Saving the embedding table') 357 | torch.save(embeddings, args.embedding_table) 358 | else: 359 | logger.info('Loading embeddings from saved embeddings table') 360 | embeddings = torch.load(args.embedding_table) 361 | return embeddings 362 | 363 | # 364 | # ------------------------------------------------------------------------------ 365 | # Utility classes 366 | # ------------------------------------------------------------------------------ 367 | 368 | 369 | class AverageMeter(object): 370 | """ 371 | Computes and stores the average and current value. 372 | """ 373 | 374 | def __init__(self): 375 | self.reset() 376 | 377 | def reset(self): 378 | self.val = 0 379 | self.avg = 0 380 | self.sum = 0 381 | self.count = 0 382 | 383 | def update(self, val, n=1): 384 | self.val = val 385 | self.sum += val * n 386 | self.count += n 387 | self.avg = self.sum / self.count 388 | 389 | 390 | class Timer(object): 391 | """ 392 | Computes elapsed time. 
393 | """ 394 | 395 | def __init__(self): 396 | self.running = True 397 | self.total = 0 398 | self.start = time.time() 399 | 400 | def reset(self): 401 | self.running = True 402 | self.total = 0 403 | self.start = time.time() 404 | return self 405 | 406 | def resume(self): 407 | if not self.running: 408 | self.running = True 409 | self.start = time.time() 410 | return self 411 | 412 | def stop(self): 413 | if self.running: 414 | self.running = False 415 | self.total += time.time() - self.start 416 | return self 417 | 418 | def time(self): 419 | if self.running: 420 | return self.total + time.time() - self.start 421 | return self.total -------------------------------------------------------------------------------- /paragraph_encoder/train_para_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | import os 5 | import pickle 6 | import sys 7 | import logging 8 | import shutil 9 | from tqdm import tqdm 10 | 11 | from torch.autograd import Variable 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from torch.utils.data.sampler import RandomSampler 15 | 16 | import config 17 | from model import utils, data, vector 18 | from model.retriever import LSTMRetriever 19 | from multi_corpus import MultiCorpus 20 | 21 | from torch.utils.data.sampler import SequentialSampler, RandomSampler 22 | import math 23 | 24 | logger = logging.getLogger() 25 | 26 | global_timer = utils.Timer() 27 | stats = {'timer': global_timer, 'epoch': 0, 'best_valid': 0, 'best_verified_valid': 0, 'best_acc': 0, 'best_verified_acc': 0} 28 | 29 | def make_data_loader(args, corpus, train_time=False): 30 | 31 | dataset = data.MultiCorpusDataset( 32 | args, 33 | corpus, 34 | args.word_dict, 35 | args.feature_dict, 36 | single_answer=False, 37 | para_mode=args.para_mode, 38 | train_time=train_time 39 | ) 40 | sampler = SequentialSampler(dataset) if not train_time else RandomSampler(dataset) 41 | loader = torch.utils.data.DataLoader( 42 | dataset, 43 | batch_size=args.batch_size, 44 | sampler=sampler, 45 | num_workers=args.data_workers, 46 | collate_fn=vector.batchify(args, args.para_mode, train_time=train_time), 47 | pin_memory=True 48 | ) 49 | 50 | return loader 51 | 52 | 53 | 54 | def init_from_checkpoint(args): 55 | 56 | logger.info('Loading model from saved checkpoint {}'.format(args.pretrained)) 57 | model = torch.load(args.pretrained) 58 | word_dict = model['word_dict'] 59 | feature_dict = model['feature_dict'] 60 | 61 | args.vocab_size = len(word_dict) 62 | args.embedding_dim_orig = args.embedding_dim 63 | args.word_dict = word_dict 64 | args.feature_dict = feature_dict 65 | 66 | ret = LSTMRetriever(args, word_dict, feature_dict) 67 | # load saved param values 68 | ret.model.load_state_dict(model['state_dict']['para_clf']) 69 | optimizer = None 70 | parameters = ret.get_trainable_params() 71 | if args.optimizer == 'sgd': 72 | optimizer = optim.SGD(parameters, args.learning_rate, 73 | momentum=args.momentum, 74 | weight_decay=args.weight_decay) 75 | elif args.optimizer == 'adamax': 76 | optimizer = optim.Adamax(parameters, 77 | weight_decay=args.weight_decay) 78 | elif args.optimizer == 'nag': 79 | optimizer = NAG(parameters, args.learning_rate, momentum=args.momentum, 80 | weight_decay=args.weight_decay) 81 | else: 82 | raise RuntimeError('Unsupported optimizer: %s' % args.optimizer) 83 | optimizer.load_state_dict(model['state_dict']['optimizer']) 84 | logger.info('Model loaded...') 85 | return ret, 
86 | 
87 | 
88 | def init_from_scratch(args, train_exs):
89 | 
90 |     logger.info('Initializing model from scratch')
91 |     word_dict = feature_dict = None
92 |     # create or get vocab
93 |     word_dict = utils.build_word_dict(args, train_exs)
94 |     if word_dict is not None:
95 |         args.vocab_size = len(word_dict)
96 |         args.embedding_dim_orig = args.embedding_dim
97 |         args.word_dict = word_dict
98 |         args.feature_dict = feature_dict
99 | 
100 |     ret = LSTMRetriever(args, word_dict, feature_dict)
101 | 
102 |     # --------------------------------------------------------------------------
103 |     # Set up the optimizer
104 |     # --------------------------------------------------------------------------
105 | 
106 |     parameters = ret.get_trainable_params()
107 | 
108 | 
109 |     optimizer = None
110 |     if parameters is not None and len(parameters) > 0:
111 |         if args.optimizer == 'sgd':
112 |             optimizer = optim.SGD(parameters, args.learning_rate,
113 |                                   momentum=args.momentum,
114 |                                   weight_decay=args.weight_decay)
115 |         elif args.optimizer == 'adamax':
116 |             optimizer = optim.Adamax(parameters,
117 |                                      weight_decay=args.weight_decay)
118 |         elif args.optimizer == 'nag':  # note: NAG is not imported in this file
119 |             optimizer = NAG(parameters, args.learning_rate, momentum=args.momentum,
120 |                             weight_decay=args.weight_decay)
121 |         else:
122 |             raise RuntimeError('Unsupported optimizer: %s' % args.optimizer)
123 |     # else: no trainable parameters, so the optimizer stays None
124 | 
125 | 
126 |     return ret, optimizer, word_dict, feature_dict
127 | 
128 | def train_binary_classification(args, ret_model, optimizer, train_loader, verified_dev_loader=None):
129 | 
130 |     args.train_time = True
131 |     para_loss = utils.AverageMeter()
132 |     ret_model.model.train()
133 |     for idx, ex in enumerate(train_loader):
134 |         if ex is None:
135 |             continue
136 | 
137 |         inputs = [e if e is None or type(e) != type(ex[0]) else Variable(e.cuda(non_blocking=True))
138 |                   for e in ex[:]]
139 |         ret_input = [*inputs[:4]]
140 |         scores, _, _ = ret_model.score_paras(*ret_input)
141 |         y_num_occurrences = Variable(ex[-2])
142 |         labels = (y_num_occurrences > 0).float()
143 |         labels = labels.cuda()
144 |         # BCE logits loss
145 |         batch_para_loss = F.binary_cross_entropy_with_logits(scores.squeeze(1), labels)
146 |         optimizer.zero_grad()
147 |         batch_para_loss.backward()
148 | 
149 |         torch.nn.utils.clip_grad_norm_(ret_model.get_trainable_params(),
150 |                                        2.0)
151 |         optimizer.step()
152 |         para_loss.update(batch_para_loss.item())
153 |         if math.isnan(para_loss.avg):  # drop into the debugger if the loss diverges
154 |             import pdb
155 |             pdb.set_trace()
156 | 
157 |         if idx % 25 == 0 and idx > 0:
158 |             logger.info('Epoch = {} | iter={}/{} | para loss = {:2.4f}'.format(
159 |                 stats['epoch'],
160 |                 idx, len(train_loader),
161 |                 para_loss.avg))
162 |             para_loss.reset()
163 | 
164 | 
165 | def eval_binary_classification(args, ret_model, corpus, dev_loader, verified_dev_loader=None, save_scores=True):
166 |     total_exs = 0
167 |     args.train_time = False
168 |     ret_model.model.eval()
169 |     accuracy = 0.0
170 |     for idx, ex in enumerate(tqdm(dev_loader)):
171 |         if ex is None:
172 |             raise BrokenPipeError  # a None batch signals a broken data pipeline
173 | 
174 |         inputs = [e if e is None or type(e) != type(ex[0]) else Variable(e.cuda(non_blocking=True))
175 |                   for e in ex[:]]
176 |         ret_input = [*inputs[:4]]
177 |         total_exs += ex[0].size(0)
178 | 
179 |         scores, _, _ = ret_model.score_paras(*ret_input)
180 | 
181 |         scores = torch.sigmoid(scores)
182 |         y_num_occurrences = Variable(ex[-2])
183 |         labels = (y_num_occurrences > 0).float()
184 |         labels = labels.data.numpy()
185 |         scores = scores.cpu().data.numpy()
186 |         scores = scores.reshape((-1))
187 |         if save_scores:
188 |             for i, pid in enumerate(ex[-1]):
189 |                 corpus.paragraphs[pid].model_score = scores[i]
190 | 
191 |         scores = scores > 0.5
192 |         a = scores == labels
193 |         accuracy += a.sum()
194 | 
195 |     logger.info('Eval accuracy = {} '.format(accuracy / total_exs))
196 |     top1 = get_topk(corpus)
197 |     return top1
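# A minimal sketch (editor's addition, not from the original file) of the loss
# convention used above: binary_cross_entropy_with_logits fuses the sigmoid
# with the BCE loss, which is why training passes raw scores while eval
# applies sigmoid explicitly before thresholding at 0.5:
#
#     import torch
#     import torch.nn.functional as F
#     scores = torch.randn(8)                     # raw logits, as from score_paras()
#     labels = torch.randint(0, 2, (8,)).float()  # 1 iff the answer occurs in the para
#     a = F.binary_cross_entropy_with_logits(scores, labels)
#     b = F.binary_cross_entropy(torch.sigmoid(scores), labels)
#     assert torch.allclose(a, b)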
198 | 
199 | def print_vectors(args, para_vectors, question_vectors, corpus, train=False, test=False):
200 |     all_question_vectors = []
201 |     all_para_vectors = []
202 |     qid2idx = {}
203 |     cum_num_lens = []
204 |     all_correct_ans = {}
205 |     cum_num_len = 0
206 |     for question_i, qid in enumerate(corpus.questions):
207 |         labels = []
208 |         all_question_vectors.append(question_vectors[qid])
209 |         qid2idx[qid] = question_i
210 |         cum_num_len += len(corpus.questions[qid].pids)
211 |         cum_num_lens.append(cum_num_len)
212 |         for para_i, pid in enumerate(corpus.questions[qid].pids):
213 |             if corpus.paragraphs[pid].ans_occurance > 0:
214 |                 labels.append(para_i)
215 |             all_para_vectors.append(para_vectors[pid])
216 |         all_correct_ans[qid] = labels
217 |     all_para_vectors = np.stack(all_para_vectors)
218 |     all_question_vectors = np.stack(all_question_vectors)
219 |     assert all_para_vectors.shape[0] == cum_num_lens[-1]
220 |     assert all_question_vectors.shape[0] == len(cum_num_lens)
221 |     assert all_question_vectors.shape[0] == len(qid2idx)
222 |     assert all_question_vectors.shape[0] == len(all_correct_ans)
223 | 
224 |     # saving code
225 |     if train:
226 |         OUT_DIR = os.path.join(args.save_dir, args.src, args.domain, "train/")
227 |     else:
228 |         if args.is_test == 0:
229 |             OUT_DIR = os.path.join(args.save_dir, args.src, args.domain, "dev/")
230 |         else:
231 |             OUT_DIR = os.path.join(args.save_dir, args.src, args.domain, "test/")
232 | 
233 | 
234 |     logger.info("Printing vectors at {}".format(OUT_DIR))
235 |     if not os.path.exists(OUT_DIR):
236 |         os.makedirs(OUT_DIR)
237 |     else:
238 |         shutil.rmtree(OUT_DIR, ignore_errors=True)
239 |         os.makedirs(OUT_DIR)
240 | 
241 | 
242 |     json.dump(qid2idx, open(OUT_DIR + 'map.json', 'w'))
243 |     json.dump(all_correct_ans, open(OUT_DIR + 'correct_paras.json', 'w'))
244 |     all_cumlen = np.array(cum_num_lens)
245 |     np.save(OUT_DIR + "document", all_para_vectors)
246 |     np.save(OUT_DIR + "question", all_question_vectors)
247 |     np.save(OUT_DIR + "all_cumlen", all_cumlen)
248 | 
249 | 
250 | def save_vectors(args, ret_model, corpus, data_loader, verified_dev_loader=None, save_scores=True, train=False, test=False):
251 |     total_exs = 0
252 |     args.train_time = False
253 |     ret_model.model.eval()
254 |     para_vectors = {}
255 |     question_vectors = {}
256 |     for idx, ex in enumerate(tqdm(data_loader)):
257 |         if ex is None:
258 |             raise BrokenPipeError  # a None batch signals a broken data pipeline
259 | 
260 |         inputs = [e if e is None or type(e) != type(ex[0]) else Variable(e.cuda(non_blocking=True))
261 |                   for e in ex[:]]
262 |         ret_input = [*inputs[:4]]
263 |         total_exs += ex[0].size(0)
264 | 
265 |         scores, doc, ques = ret_model.score_paras(*ret_input)
266 |         scores = scores.cpu().data.numpy()
267 |         scores = scores.reshape((-1))
268 | 
269 |         if save_scores:
270 |             for i, pid in enumerate(ex[-1]):
271 |                 para_vectors[pid] = doc[i]
272 |             for i, qid in enumerate([corpus.paragraphs[pid].qid for pid in ex[-1]]):
273 |                 if qid not in question_vectors:
274 |                     question_vectors[qid] = ques[i]
275 |             for i, pid in enumerate(ex[-1]):
276 |                 corpus.paragraphs[pid].model_score = scores[i]
277 | 
278 |     get_topk(corpus)
279 |     print_vectors(args, para_vectors, question_vectors, corpus, train, test)
280 | 
281 | 
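# A minimal sketch (editor's addition, not from the original file) of reading
# the artifacts print_vectors() writes; the shapes follow the asserts above,
# and "dev/" stands in for the actual save_dir/src/domain path:
#
#     import json
#     import numpy as np
#     out_dir = "dev/"
#     docs = np.load(out_dir + "document.npy")      # one row per paragraph
#     ques = np.load(out_dir + "question.npy")      # one row per question
#     cumlen = np.load(out_dir + "all_cumlen.npy")  # prefix sums of para counts
#     qid2idx = json.load(open(out_dir + "map.json"))
#
# For question index i (from qid2idx), its paragraph rows in docs are
# cumlen[i-1]:cumlen[i] (0:cumlen[0] when i == 0).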
282 | 
283 | def get_topk(corpus):
284 |     top1 = 0
285 |     top3 = 0
286 |     top5 = 0
287 |     for qid in corpus.questions:
288 | 
289 |         para_scores = [(corpus.paragraphs[pid].model_score, corpus.paragraphs[pid].ans_occurance) for pid in corpus.questions[qid].pids]
290 |         sorted_para_scores = sorted(para_scores, key=lambda x: x[0], reverse=True)
291 | 
292 |         if sorted_para_scores[0][1] > 0:
293 |             top1 += 1
294 |         if sum([ans[1] for ans in sorted_para_scores[:3]]) > 0:
295 |             top3 += 1
296 |         if sum([ans[1] for ans in sorted_para_scores[:5]]) > 0:
297 |             top5 += 1
298 | 
299 |     top1 = top1 / len(corpus.questions)
300 |     top3 = top3 / len(corpus.questions)
301 |     top5 = top5 / len(corpus.questions)
302 | 
303 |     logger.info('top1 = {}, top3 = {}, top5 = {} '.format(top1, top3, top5))
304 |     return top1
305 | 
306 | def get_topk_tfidf(corpus):
307 |     top1 = 0
308 |     top3 = 0
309 |     top5 = 0
310 |     for qid in corpus.questions:
311 | 
312 |         para_scores = [(corpus.paragraphs[pid].tfidf_score, corpus.paragraphs[pid].ans_occurance) for pid in
313 |                        corpus.questions[qid].pids]
314 |         # ascending sort: tfidf_score is treated here as "lower is better"
315 |         sorted_para_scores = sorted(para_scores, key=lambda x: x[0])
316 | 
317 |         if sorted_para_scores[0][1] > 0:
318 |             top1 += 1
319 |         if sum([ans[1] for ans in sorted_para_scores[:3]]) > 0:
320 |             top3 += 1
321 |         if sum([ans[1] for ans in sorted_para_scores[:5]]) > 0:
322 |             top5 += 1
323 | 
324 |     logger.info(
325 |         'top1 = {}, top3 = {}, top5 = {} '.format(top1 / len(corpus.questions), top3 / len(corpus.questions),
326 |                                                   top5 / len(corpus.questions)))
327 | 
328 | 
329 | def run_predictions(args, data_loader, model, eval_on_train_set=False):
330 | 
331 |     args.train_time = False
332 |     top_1 = 0
333 |     top_3 = 0
334 |     top_5 = 0
335 |     total_num_questions = 0
336 |     map_counter = 0
337 |     cum_num_lens = []
338 |     qid2idx = {}
339 |     sum_num_paras = 0
340 |     all_correct_answers = {}
341 | 
342 |     for ex_counter, ex in tqdm(enumerate(data_loader)):
343 | 
344 |         ret_input = [*ex]
345 |         y_num_occurrences = ex[3]
346 |         labels = (y_num_occurrences > 0)
347 |         try:
348 |             topk_paras, docs, ques = model.return_topk(5, *ret_input)
349 |         except RuntimeError:  # drop into the debugger on an unexpected failure
350 |             import pdb
351 |             pdb.set_trace()
352 | 
353 |         num_paras = ex[1]
354 |         qids = ex[-1]
355 | 
356 |         if args.save_para_clf_output:
357 |             docs = docs.cpu().data.numpy()
358 |             ques = ques.cpu().data.numpy()
359 |             if ex_counter == 0:
360 |                 documents = docs
361 |                 questions = ques
362 |             else:
363 |                 documents = np.concatenate([documents, docs])
364 |                 questions = np.concatenate([questions, ques])
365 | 
366 | 
367 |         # create map and cum_num_lens
368 | 
369 |         for i, qid in enumerate(qids):
370 |             qid2idx[qid] = map_counter
371 |             sum_num_paras += num_paras[i]
372 |             cum_num_lens.append(sum_num_paras)
373 |             all_correct_answers[map_counter] = []
374 | 
375 |             st = sum(num_paras[:i])
376 |             for j in range(num_paras[i]):
377 |                 if labels[st + j] == 1:
378 |                     all_correct_answers[map_counter].append(j)
379 | 
380 |             # Test case:
381 |             assert len(all_correct_answers[map_counter]) == sum(labels.data.numpy()[st: st + num_paras[i]])
382 | 
383 |             map_counter += 1
384 | 
385 | 
386 | 
387 |         counter = 0
388 |         for q_counter, ranked_para_ids in enumerate(topk_paras):
389 |             total_num_questions += 1
390 |             for i, no_paras in enumerate(ranked_para_ids):
391 |                 if labels[counter + no_paras] == 1:
392 |                     if i <= 4:
393 |                         top_5 += 1
394 |                     if i <= 2:
395 |                         top_3 += 1
396 |                     if i <= 0:
397 |                         top_1 += 1
398 |                     break
399 |             counter += num_paras[q_counter]
400 | 
401 | 
402 | 
403 |     logger.info('Accuracy of para classifier when evaluated on the annotated dev set.')
404 |     logger.info('top-1: {:2.4f}, top-3: {:2.4f}, top-5: {:2.4f}'.format(
405 |         (top_1 * 1.0 / total_num_questions),
406 |         (top_3 * 1.0 / total_num_questions),
407 |         (top_5 * 1.0 / total_num_questions)))
408 | 
409 | 
410 |     # saving code
411 |     if args.save_para_clf_output:
412 |         if eval_on_train_set:
413 |             OUT_DIR = "/iesl/canvas/sdhuliawala/vectors_web/train/"
414 |         else:
415 |             OUT_DIR = "/iesl/canvas/sdhuliawala/vectors_web/dev/"
416 | 
417 |         if not os.path.exists(OUT_DIR):
418 |             os.makedirs(OUT_DIR)
419 |         else:
420 |             shutil.rmtree(OUT_DIR, ignore_errors=True)
421 |             os.makedirs(OUT_DIR)
422 | 
423 | 
424 |         # Test cases
425 |         assert cum_num_lens[-1] == documents.shape[0]
426 |         assert questions.shape[0] == documents.shape[0]
427 |         assert len(cum_num_lens) == len(qid2idx)
428 |         assert len(cum_num_lens) == len(all_correct_answers)
429 | 
430 |         json.dump(qid2idx, open(OUT_DIR + 'map.json', 'w'))
431 |         json.dump(all_correct_answers, open(OUT_DIR + 'correct_paras.json', 'w'))
432 |         all_cumlen = np.array(cum_num_lens)
433 |         np.save(OUT_DIR + "document", documents)
434 |         np.save(OUT_DIR + "question", questions)
435 |         np.save(OUT_DIR + "all_cumlen", all_cumlen)
436 |     return (top_1 * 1.0 / total_num_questions), (top_3 * 1.0 / total_num_questions), (top_5 * 1.0 / total_num_questions)
437 | 
438 | 
439 | def save(args, model, optimizer, filename, epoch=None):
440 | 
441 |     params = {
442 |         'state_dict': {
443 |             'para_clf': model.state_dict(),
444 |             'optimizer': optimizer.state_dict()
445 |         },
446 |         'word_dict': args.word_dict,
447 |         'feature_dict': args.feature_dict
448 |     }
449 |     args.word_dict = None
450 |     args.feature_dict = None
451 |     params['config'] = vars(args)
452 |     if epoch is not None:
453 |         params['epoch'] = epoch
454 |     try:
455 |         torch.save(params, filename)
456 |         # bad hack to avoid saving the dictionaries twice
457 |         args.word_dict = params['word_dict']
458 |         args.feature_dict = params['feature_dict']
459 |     except Exception:
460 |         logger.warning('[ WARN: Saving failed... continuing anyway. ]')
461 | 
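# A minimal sketch (editor's addition, not from the original file) of the
# checkpoint layout save() produces and init_from_checkpoint() consumes;
# the filename is hypothetical:
#
#     import torch
#     ckpt = torch.load("para_encoder.mdl.ckpt", map_location='cpu')
#     ckpt['state_dict']['para_clf']   # LSTMRetriever weights
#     ckpt['state_dict']['optimizer']  # optimizer state
#     ckpt['word_dict']                # vocabulary saved alongside the weights
#     ckpt['feature_dict']
#     ckpt['config']                   # vars(args) at save time
#     ckpt.get('epoch')                # present when saved with epoch=...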
462 | 
463 | # ------------------------------------------------------------------------------
464 | # Main.
465 | # ------------------------------------------------------------------------------
466 | 
467 | def main(args):
468 | 
469 |     # PRINT CONFIG
470 |     logger.info('-' * 100)
471 |     logger.info('CONFIG:\n%s' % json.dumps(vars(args), indent=4, sort_keys=True))
472 | 
473 |     # the small setting has no test split, so disable testing
474 |     if args.small == 1:
475 |         args.test = 0
476 | 
477 |     if args.small == 1:
478 |         args.train_file_name = args.train_file_name + "_small"
479 |         args.dev_file_name = args.dev_file_name + "_small"
480 |         if args.test == 1:
481 |             args.test_file_name = args.test_file_name + "_small"
482 | 
483 |     args.train_file_name = args.train_file_name + ".pkl"
484 |     args.dev_file_name = args.dev_file_name + ".pkl"
485 |     if args.test == 1:
486 |         args.test_file_name = args.test_file_name + ".pkl"
487 | 
488 |     logger.info("Loading pickle files")
489 |     fin = open(os.path.join(args.data_dir, args.src, "data", args.domain, args.train_file_name), "rb")
490 |     all_train_exs = pickle.load(fin)
491 |     fin.close()
492 | 
493 |     fin = open(os.path.join(args.data_dir, args.src, "data", args.domain, args.dev_file_name), "rb")
494 |     all_dev_exs = pickle.load(fin)
495 |     fin.close()
496 |     if args.test == 1:
497 |         fin = open(os.path.join(args.data_dir, args.src, "data", args.domain, args.test_file_name), "rb")
498 |         all_test_exs = pickle.load(fin)
499 |         fin.close()
500 | 
501 |     logger.info("Loading done!")
502 | 
503 |     logger.info("Num train examples {}".format(len(all_train_exs.paragraphs)))
504 |     logger.info("Num dev examples {}".format(len(all_dev_exs.paragraphs)))
505 |     if args.test == 1:
506 |         logger.info("Num test examples {}".format(len(all_test_exs.paragraphs)))
507 | 
508 | 
509 |     if args.pretrained is None:
510 |         ret_model, optimizer, word_dict, feature_dict = init_from_scratch(args, all_train_exs)
511 |     else:
512 |         ret_model, optimizer, word_dict, feature_dict = init_from_checkpoint(args)
513 | 
514 |     # make data loaders
515 |     logger.info("Making data loaders...")
516 |     if word_dict is None:
517 |         args.word_dict = utils.build_word_dict(args, (all_train_exs, all_dev_exs))
518 |         word_dict = args.word_dict
519 | 
520 |     train_loader = make_data_loader(args, all_train_exs, train_time=not args.eval_only)
521 |     dev_loader = make_data_loader(args, all_dev_exs)
522 |     if args.test:
523 |         test_loader = make_data_loader(args, all_test_exs)
524 | 
525 | 
526 |     if args.eval_only:
527 |         logger.info("Saving dev paragraph vectors")
528 |         save_vectors(args, ret_model, all_dev_exs, dev_loader, verified_dev_loader=None)
529 | 
530 | 
531 |         logger.info("Saving train paragraph vectors")
532 |         save_vectors(args, ret_model, all_train_exs, train_loader, verified_dev_loader=None, train=True)
533 |         if args.test:
534 |             args.is_test = 1
535 |             logger.info("Saving test paragraph vectors")
536 |             save_vectors(args, ret_model, all_test_exs, test_loader, verified_dev_loader=None)
537 | 
538 |     else:
539 |         get_topk_tfidf(all_dev_exs)
540 |         for epoch in range(args.num_epochs):
541 |             stats['epoch'] = epoch
542 |             train_binary_classification(args, ret_model, optimizer, train_loader, verified_dev_loader=None)
543 |             logger.info('checkpointing model at {}'.format(args.model_file))
544 |             # checkpointing
545 |             save(args, ret_model.model, optimizer, args.model_file + ".ckpt", epoch=stats['epoch'])
546 | 
547 |             logger.info("Evaluating on the full dev set....")
548 |             top1 = eval_binary_classification(args, ret_model, all_dev_exs, dev_loader, verified_dev_loader=None)
549 |             if stats['best_acc'] < top1:
550 |                 stats['best_acc'] = top1
551 |                 logger.info('Best accuracy {}'.format(stats['best_acc']))
552 |                 logger.info('Saving model at {}'.format(args.model_file))
553 |                 logger.info("Logs saved at {}".format(args.log_file))
554 |                 save(args, ret_model.model, optimizer, args.model_file, epoch=stats['epoch'])
555 | 
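# A minimal sketch (editor's addition, not from the original file) of the
# input layout main() expects; "web" and "full" stand in for whatever
# args.src and args.domain are set to, and "train"/"dev"/"test" for the
# configured base file names:
#
#     <data_dir>/web/data/full/train.pkl   # pickled MultiCorpus
#     <data_dir>/web/data/full/dev.pkl
#     <data_dir>/web/data/full/test.pkl    # only read when args.test == 1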
556 | 
557 | 
558 | if __name__ == '__main__':
559 |     logger.info('-' * 100)
560 |     # Parse cmdline args and set up the environment
561 |     args = config.get_args()
562 | 
563 |     # Set cuda
564 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
565 |     if args.cuda:
566 |         torch.cuda.set_device(args.gpu)
567 | 
568 |     # Set random state
569 |     np.random.seed(args.random_seed)
570 |     torch.manual_seed(args.random_seed)
571 |     if args.cuda:
572 |         torch.cuda.manual_seed(args.random_seed)
573 | 
574 |     # Set logging
575 |     logger.setLevel(logging.INFO)
576 |     fmt = logging.Formatter('%(asctime)s: %(message)s', '%m/%d/%Y %I:%M:%S %p')
577 |     console = logging.StreamHandler()
578 |     console.setFormatter(fmt)
579 |     logger.addHandler(console)
580 |     if args.log_file:
581 |         if args.checkpoint:
582 |             logfile = logging.FileHandler(args.log_file, 'a')
583 |         else:
584 |             logfile = logging.FileHandler(args.log_file, 'w')
585 |         logfile.setFormatter(fmt)
586 |         logger.addHandler(logfile)
587 |     logger.info('[ COMMAND: %s ]' % ' '.join(sys.argv))
588 | 
589 |     # Run!
590 |     main(args)
--------------------------------------------------------------------------------