├── LICENSE ├── README.md ├── build_db.py ├── build_tfidf.py ├── drqa_retriever ├── __init__.py ├── doc_db.py ├── tfidf_doc_ranker.py └── utils.py ├── drqa_tokenizers ├── __init__.py ├── corenlp_tokenizer.py ├── regexp_tokenizer.py ├── simple_tokenizer.py ├── spacy_tokenizer.py └── tokenizer.py ├── filter_subset_wiki.py ├── inference_tfidf.py ├── requirements.txt ├── run.sh └── run_inference.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For DrQA software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README for retrieval-based baselines 2 | 3 | This repo provides guidelines for training and testing retrieval-based baselines for [NeurIPS Competition on Efficient Open-domain Question Answering](http://efficientqa.github.io/). 4 | 5 | We provide two retrieval-based baselines: 6 | 7 | - TF-IDF: TF-IDF retrieval built on fixed-length passages, adapted from the [DrQA system's implementation](https://github.com/facebookresearch/DrQA). 8 | - DPR: A learned dense passage retriever, detailed in [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) (Karpukhin et al, 2020). Our baseline is adapted from the [original implementation](https://github.com/facebookresearch/DPR). 9 | 10 | 11 | Note that for both baselines, we use text blocks of 100 words as passages and a BERT-base multi-passage reader. See more details in the [DPR paper](https://arxiv.org/pdf/2004.04906.pdf). 12 | 13 | We provide two variants for each model, using (1) full Wikipedia (`full`) and (2) A subset of Wikipedia articles which are found relevant to the questions on the train data (`subset`). In particular, we think of `subset` as a naive way to reduce the disk memory usage for the retrieval-based baselines. 
14 | 
15 | 
16 | *Note: If you want to try parameter-only baselines (T5-based) for the competition, an implementation of a T5-based closed-book QA model is available [here](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa).*
17 | 
18 | *Note: If you want simple guidelines for making end-to-end QA predictions using pretrained models, please refer to [this tutorial](https://github.com/efficientqa/efficientqa.github.io/blob/master/getting_started.md).*
19 | 
20 | ## Content
21 | 
22 | 1. [Getting ready](#getting-ready)
23 | 2. [TFIDF retrieval](#tfidf-retrieval)
24 | 3. [DPR retrieval](#dpr-retrieval)
25 | 4. [DPR reader](#dpr-reader)
26 | 5. [Result](#result)
27 | 
28 | ## Getting ready
29 | 
30 | ### Git clone
31 | 
32 | ```bash
33 | git clone https://github.com/facebookresearch/DPR.git # dependency
34 | git clone https://github.com/efficientqa/retrieval-based-baselines.git # this repo
35 | ```
36 | 
37 | ### Download data
38 | 
39 | Follow the [DPR repo][dpr] to download the NQ data and the Wikipedia DB. Specifically, after running `cd DPR` and letting `base_dir` be your base directory for data and pretrained models:
40 | 
41 | 
42 | 1. Download QA pairs by `python3 data/download_data.py --resource data.retriever.qas --output_dir ${base_dir}` and `python3 data/download_data.py --resource data.retriever.nq --output_dir ${base_dir}`.
43 | 2. Download the Wikipedia DB by `python3 data/download_data.py --resource data.wikipedia_split --output_dir ${base_dir}`.
44 | 3. Download gold question-passage pairs by `python3 data/download_data.py --resource data.gold_passages_info --output_dir ${base_dir}`.
45 | 
46 | Optionally, if you want to try the `subset` variant, run `cd ../retrieval-based-baselines; python3 filter_subset_wiki.py --db_path ${base_dir}/data/wikipedia_split/psgs_w100.tsv --data_path ${base_dir}/data/retriever/nq-train.json`. This script creates a new passage DB containing only the passages whose source articles are paired with questions in the original NQ train data (78,050 unique articles; 1,642,855 unique passages).
47 | This new DB will be stored at `${base_dir}/data/wikipedia_split/psgs_w100_subset.tsv`.
48 | 
49 | From now on, we will denote the Wikipedia DB (either full or subset) as `db_path`.
50 | 
51 | 
52 | ## TFIDF retrieval
53 | 
54 | Make sure you are in the `retrieval-based-baselines` directory to run the TFIDF scripts (largely adapted from the [DrQA repo][drqa]).
55 | 
56 | **Step 1**: Run `pip install -r requirements.txt`.
57 | 
58 | **Step 2**: Build the SQLite DB via:
59 | ```
60 | mkdir -p ${base_dir}/tfidf
61 | python3 build_db.py ${db_path} ${base_dir}/tfidf/db.db --num-workers 60
62 | ```
63 | **Step 3**: Build the TF-IDF index via:
64 | ```
65 | python3 build_tfidf.py ${base_dir}/tfidf/db.db ${base_dir}/tfidf
66 | ```
67 | It will save the TF-IDF index in `${base_dir}/tfidf`.
68 | 
69 | **Step 4**: Run the inference code to save retrieval results:
70 | ```
71 | python3 inference_tfidf.py --qa_file ${base_dir}/data/retriever/qas/nq-{train|dev|test}.csv --db_path ${db_path} --out_file ${base_dir}/tfidf/nq-{train|dev|test}.json --tfidf_path {path_to_tfidf_index}
72 | ```
73 | 
74 | The resulting files, `${base_dir}/tfidf/nq-{train|dev|test}.json`, are ready to be fed into the DPR reader.
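
As a quick sanity check, you can query the saved index directly from Python (run from the `retrieval-based-baselines` directory). The snippet below is only a minimal sketch, not one of the provided scripts; it assumes the paths produced by the commands above, where the default index file is named `db-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz` — adjust both paths to your own `${base_dir}`.

```python
# Minimal sketch: query the TF-IDF index and print the top passages.
from drqa_retriever import DocDB, TfidfDocRanker

tfidf_path = "tfidf/db-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz"  # adjust
db_path = "tfidf/db.db"  # adjust

ranker = TfidfDocRanker(tfidf_path=tfidf_path, strict=False)
db = DocDB(db_path=db_path)

doc_ids, scores = ranker.closest_docs("who wrote the declaration of independence", k=5)
for doc_id, score in zip(doc_ids, scores):
    # Each "document" is a 100-word passage prefixed with its article title.
    print(doc_id, round(float(score), 2), db.get_doc_text(doc_id)[:80])
```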
75 | 
76 | ## DPR retrieval
77 | 
78 | Follow the [DPR repo][dpr] to train the DPR retriever and run inference. You can follow the steps up to [Retriever validation](https://github.com/facebookresearch/DPR/tree/master#retriever-validation-against-the-entire-set-of-documents).
79 | 
80 | 
81 | If you want to use the retriever checkpoint provided by DPR, follow these three steps.
82 | 
83 | **Step 1**: Make sure you are in the `DPR` directory, and download the retriever checkpoint by `python3 data/download_data.py --resource checkpoint.retriever.multiset.bert-base-encoder --output_dir ${base_dir}`.
84 | 
85 | **Step 2**: Save passage vectors by following [Generating representations](https://github.com/facebookresearch/DPR/tree/master#retriever-validation-against-the-entire-set-of-documents). Note that you can replace `ctx_file` with your own `db_path` if you are trying the `subset` ("seen only") version. In particular, you can run
86 | ```
87 | python3 generate_dense_embeddings.py --model_file ${base_dir}/checkpoint/retriever/multiset/bert-base-encoder.cp --ctx_file ${db_path} --shard_id {0-19} --num_shards 20 --out_file ${base_dir}/dpr_ctx
88 | ```
89 | 
90 | **Step 3**: Save retrieval results by following [Retriever validation](https://github.com/facebookresearch/DPR/tree/master#retriever-validation-against-the-entire-set-of-documents). In particular, you can run
91 | ```
92 | mkdir -p ${base_dir}/dpr_retrieval
93 | python3 dense_retriever.py \
94 | --model_file ${base_dir}/checkpoint/retriever/multiset/bert-base-encoder.cp \
95 | --ctx_file ${db_path} \
96 | --qa_file ${base_dir}/data/retriever/qas/nq-{train|dev|test}.csv \
97 | --encoded_ctx_file ${base_dir}/'dpr_ctx*' \
98 | --out_file ${base_dir}/dpr_retrieval/nq-{train|dev|test}.json \
99 | --n-docs 200 \
100 | --save_or_load_index # saves the dense index the first time it is built, and loads it on subsequent runs
101 | ```
102 | 
103 | Now, `${base_dir}/dpr_retrieval/nq-{train|dev|test}.json` is ready to be fed into the DPR reader.
104 | 
105 | ## DPR reader
106 | 
107 | *Note*: The following instructions are identical to those in the [DPR README](https://github.com/facebookresearch/DPR#optional-reader-model-input-data-pre-processing), but rewritten with the hyperparameters used for our baselines.
108 | 
109 | The following instructions train the reader on the TFIDF retrieval results saved in `${base_dir}/tfidf/nq-{train|dev|test}.json`. In order to use the DPR retrieval results instead, simply replace these paths with `${base_dir}/dpr_retrieval/nq-{train|dev|test}.json`.
110 | 
111 | **Step 1**: Preprocess data.
112 | 
113 | ```
114 | python3 preprocess_reader_data.py \
115 | --retriever_results ${base_dir}/tfidf/nq-{train|dev|test}.json \
116 | --gold_passages ${base_dir}/data/gold_passages_info/nq_{train|dev|test}.json \
117 | --do_lower_case \
118 | --pretrained_model_cfg bert-base-uncased \
119 | --encoder_model_type hf_bert \
120 | --out_file ${base_dir}/tfidf/nq-{train|dev|test}-tfidf \
121 | --is_train_set # specify this only for the train data
122 | ```
123 | 
124 | **Step 2**: Train the reader.
125 | ```
126 | python3 train_reader.py \
127 | --encoder_model_type hf_bert \
128 | --pretrained_model_cfg bert-base-uncased \
129 | --train_file ${base_dir}/tfidf/'nq-train*.pkl' \
130 | --dev_file ${base_dir}/tfidf/'nq-dev*.pkl' \
131 | --output_dir ${base_dir}/checkpoints/reader_from_tfidf \
132 | --seed 42 \
133 | --learning_rate 1e-5 \
134 | --eval_step 2000 \
135 | --eval_top_docs 50 \
136 | --warmup_steps 0 \
137 | --sequence_length 350 \
138 | --batch_size 16 \
139 | --passages_per_question 24 \
140 | --num_train_epochs 100000 \
141 | --dev_batch_size 72 \
142 | --passages_per_question_predict 50
143 | ```
144 | 
145 | **Step 3**: Test the reader.
146 | ```
147 | python3 train_reader.py \
148 | --prediction_results_file ${base_dir}/checkpoints/reader_from_tfidf/dev_predictions.json \
149 | --eval_top_docs 10 20 40 50 80 100 \
150 | --dev_file ${base_dir}/tfidf/'nq-dev*.pkl' \
151 | --model_file ${base_dir}/checkpoints/reader_from_tfidf/{checkpoint file} \
152 | --dev_batch_size 80 \
153 | --passages_per_question_predict 100 \
154 | --sequence_length 350
155 | ```
156 | 
157 | [drqa]: https://github.com/facebookresearch/DrQA/
158 | [dpr]: https://github.com/facebookresearch/DPR
159 | 
160 | ## Result
161 | 
162 | |Model|Exact Match|Disk usage (GB)|
163 | |---|---|---|
164 | |TFIDF-full|32.0|20.1|
165 | |TFIDF-subset|31.0|2.8|
166 | |DPR-full|41.0|66.4|
167 | |DPR-subset|34.8|5.9|
168 | 
169 | 
170 | 
--------------------------------------------------------------------------------
/build_db.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """A script to read in and store documents in a sqlite database."""
8 | 
9 | import argparse
10 | import sqlite3
11 | import json
12 | import os
13 | import logging
14 | import importlib.util
15 | 
16 | from multiprocessing import Pool as ProcessPool
17 | from tqdm import tqdm
18 | from drqa_retriever import utils
19 | 
20 | logger = logging.getLogger()
21 | logger.setLevel(logging.INFO)
22 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
23 | console = logging.StreamHandler()
24 | console.setFormatter(fmt)
25 | logger.addHandler(console)
26 | 
27 | 
28 | # ------------------------------------------------------------------------------
29 | # Import helper
30 | # ------------------------------------------------------------------------------
31 | 
32 | 
33 | PREPROCESS_FN = None
34 | 
35 | 
36 | def init(filename):
37 |     global PREPROCESS_FN
38 |     if filename:
39 |         PREPROCESS_FN = import_module(filename).preprocess
40 | 
41 | 
42 | def import_module(filename):
43 |     """Import a module given a full path to the file."""
44 |     spec = importlib.util.spec_from_file_location('doc_filter', filename)
45 |     module = importlib.util.module_from_spec(spec)
46 |     spec.loader.exec_module(module)
47 |     return module
48 | 
49 | 
50 | # ------------------------------------------------------------------------------
51 | # Store corpus.
52 | # ------------------------------------------------------------------------------
53 | 
54 | 
55 | def iter_files(path):
56 |     """Walk through all files located under a root path."""
57 |     if os.path.isfile(path):
58 |         yield path
59 |     elif os.path.isdir(path):
60 |         for dirpath, _, filenames in os.walk(path):
61 |             for f in filenames:
62 |                 yield os.path.join(dirpath, f)
63 |     else:
64 |         raise RuntimeError('Path %s is invalid' % path)
65 | 
66 | 
67 | def get_contents(filename):
68 |     """Parse the contents of a file.
Each line is a JSON encoded document.""" 69 | global PREPROCESS_FN 70 | documents = [] 71 | if filename.endswith(".tsv.gz"): 72 | raise NotImplementedError("TODO") 73 | elif filename.endswith(".tsv"): 74 | import csv 75 | with open(filename) as tsvfile: 76 | reader = csv.reader(tsvfile, delimiter='\t') 77 | # file format: doc_id, doc_text, title 78 | for (doc_id, doc_text, title) in reader: 79 | if doc_id=="id": continue 80 | documents.append((doc_id, utils.normalize(title) + " " + doc_text)) 81 | else: 82 | with open(filename) as f: 83 | for line in f: 84 | # Parse document 85 | doc = json.loads(line) 86 | # Maybe preprocess the document with custom function 87 | if PREPROCESS_FN: 88 | doc = PREPROCESS_FN(doc) 89 | # Skip if it is empty or None 90 | if not doc: 91 | continue 92 | # Add the document 93 | documents.append((utils.normalize(doc['id']), doc['text'])) 94 | return documents 95 | 96 | def store_contents(data_path, save_path, preprocess, num_workers=None): 97 | """Preprocess and store a corpus of documents in sqlite. 98 | 99 | Args: 100 | data_path: Root path to directory (or directory of directories) of files 101 | containing json encoded documents (must have `id` and `text` fields). 102 | save_path: Path to output sqlite db. 103 | preprocess: Path to file defining a custom `preprocess` function. Takes 104 | in and outputs a structured doc. 105 | num_workers: Number of parallel processes to use when reading docs. 106 | """ 107 | if os.path.isfile(save_path): 108 | raise RuntimeError('%s already exists! Not overwriting.' % save_path) 109 | 110 | logger.info('Reading into database...') 111 | conn = sqlite3.connect(save_path) 112 | c = conn.cursor() 113 | c.execute("CREATE TABLE documents (id PRIMARY KEY, text);") 114 | 115 | if num_workers is None or num_workers==1: 116 | files = [f for f in iter_files(data_path)] 117 | count = 0 118 | with tqdm(total=len(files)) as pbar: 119 | for f in files: 120 | pairs = get_contents(f) 121 | count += len(pairs) 122 | c.executemany("INSERT INTO documents VALUES (?,?)", pairs) 123 | pbar.update() 124 | else: 125 | workers = ProcessPool(num_workers, initializer=init, initargs=(preprocess,)) 126 | files = [f for f in iter_files(data_path)] 127 | count = 0 128 | with tqdm(total=len(files)) as pbar: 129 | for pairs in tqdm(workers.imap_unordered(get_contents, files)): 130 | count += len(pairs) 131 | c.executemany("INSERT INTO documents VALUES (?,?)", pairs) 132 | pbar.update() 133 | logger.info('Read %d docs.' % count) 134 | logger.info('Committing...') 135 | conn.commit() 136 | conn.close() 137 | 138 | 139 | # ------------------------------------------------------------------------------ 140 | # Main. 
141 | # ------------------------------------------------------------------------------ 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('data_path', type=str, help='/path/to/data') 147 | parser.add_argument('save_path', type=str, help='/path/to/saved/db.db') 148 | parser.add_argument('--preprocess', type=str, default=None, 149 | help=('File path to a python module that defines ' 150 | 'a `preprocess` function')) 151 | parser.add_argument('--num-workers', type=int, default=None, 152 | help='Number of CPU processes (for tokenizing, etc)') 153 | args = parser.parse_args() 154 | 155 | store_contents( 156 | args.data_path, args.save_path, args.preprocess, args.num_workers 157 | ) 158 | -------------------------------------------------------------------------------- /build_tfidf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to build the tf-idf document matrices for retrieval.""" 8 | 9 | import numpy as np 10 | import scipy.sparse as sp 11 | import argparse 12 | import os 13 | import math 14 | import logging 15 | 16 | from multiprocessing import Pool as ProcessPool 17 | from multiprocessing.util import Finalize 18 | from functools import partial 19 | from collections import Counter 20 | 21 | import drqa_retriever as retriever 22 | import drqa_tokenizers as tokenizers 23 | 24 | logger = logging.getLogger() 25 | logger.setLevel(logging.INFO) 26 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 27 | console = logging.StreamHandler() 28 | console.setFormatter(fmt) 29 | logger.addHandler(console) 30 | 31 | 32 | # ------------------------------------------------------------------------------ 33 | # Multiprocessing functions 34 | # ------------------------------------------------------------------------------ 35 | 36 | DOC2IDX = None 37 | PROCESS_TOK = None 38 | PROCESS_DB = None 39 | 40 | 41 | def init(tokenizer_class, db_class, db_opts): 42 | global PROCESS_TOK, PROCESS_DB 43 | PROCESS_TOK = tokenizer_class() 44 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 45 | PROCESS_DB = db_class(**db_opts) 46 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) 47 | 48 | 49 | def fetch_text(doc_id): 50 | global PROCESS_DB 51 | return PROCESS_DB.get_doc_text(doc_id) 52 | 53 | 54 | def tokenize(text): 55 | global PROCESS_TOK 56 | return PROCESS_TOK.tokenize(text) 57 | 58 | 59 | # ------------------------------------------------------------------------------ 60 | # Build article --> word count sparse matrix. 61 | # ------------------------------------------------------------------------------ 62 | 63 | 64 | def count(ngram, hash_size, doc_id): 65 | """Fetch the text of a document and compute hashed ngrams counts.""" 66 | global DOC2IDX 67 | row, col, data = [], [], [] 68 | # Tokenize 69 | tokens = tokenize(retriever.utils.normalize(fetch_text(doc_id))) 70 | 71 | # Get ngrams from tokens, with stopword/punctuation filtering. 72 | ngrams = tokens.ngrams( 73 | n=ngram, uncased=True, filter_fn=retriever.utils.filter_ngram 74 | ) 75 | 76 | # Hash ngrams and count occurences 77 | counts = Counter([retriever.utils.hash(gram, hash_size) for gram in ngrams]) 78 | 79 | # Return in sparse matrix data format. 
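    # Each key of `counts` is a hashed n-gram id, used directly as a row index;
    # the column is this document's index from DOC2IDX, and the value is the count.
    # These COO triplets are later stacked across documents and assembled into a
    # (hash_size, num_docs) CSR matrix in get_count_matrix().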
80 | row.extend(counts.keys()) 81 | col.extend([DOC2IDX[doc_id]] * len(counts)) 82 | data.extend(counts.values()) 83 | return row, col, data 84 | 85 | 86 | def get_count_matrix(args, db, db_opts): 87 | """Form a sparse word to document count matrix (inverted index). 88 | 89 | M[i, j] = # times word i appears in document j. 90 | """ 91 | # Map doc_ids to indexes 92 | global DOC2IDX 93 | db_class = retriever.get_class(db) 94 | with db_class(**db_opts) as doc_db: 95 | doc_ids = doc_db.get_doc_ids() 96 | DOC2IDX = {doc_id: i for i, doc_id in enumerate(doc_ids)} 97 | 98 | # Setup worker pool 99 | tok_class = tokenizers.get_class(args.tokenizer) 100 | workers = ProcessPool( 101 | args.num_workers, 102 | initializer=init, 103 | initargs=(tok_class, db_class, db_opts) 104 | ) 105 | 106 | # Compute the count matrix in steps (to keep in memory) 107 | logger.info('Mapping...') 108 | row, col, data = [], [], [] 109 | step = max(int(len(doc_ids) / 10), 1) 110 | batches = [doc_ids[i:i + step] for i in range(0, len(doc_ids), step)] 111 | _count = partial(count, args.ngram, args.hash_size) 112 | for i, batch in enumerate(batches): 113 | logger.info('-' * 25 + 'Batch %d/%d' % (i + 1, len(batches)) + '-' * 25) 114 | for b_row, b_col, b_data in workers.imap_unordered(_count, batch): 115 | row.extend(b_row) 116 | col.extend(b_col) 117 | data.extend(b_data) 118 | workers.close() 119 | workers.join() 120 | 121 | logger.info('Creating sparse matrix...') 122 | count_matrix = sp.csr_matrix( 123 | (data, (row, col)), shape=(args.hash_size, len(doc_ids)) 124 | ) 125 | count_matrix.sum_duplicates() 126 | return count_matrix, (DOC2IDX, doc_ids) 127 | 128 | 129 | # ------------------------------------------------------------------------------ 130 | # Transform count matrix to different forms. 131 | # ------------------------------------------------------------------------------ 132 | 133 | 134 | def get_tfidf_matrix(cnts): 135 | """Convert the word count matrix into tfidf one. 136 | 137 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 138 | * tf = term frequency in document 139 | * N = number of documents 140 | * Nt = number of occurences of term in all documents 141 | """ 142 | Ns = get_doc_freqs(cnts) 143 | idfs = np.log((cnts.shape[1] - Ns + 0.5) / (Ns + 0.5)) 144 | idfs[idfs < 0] = 0 145 | idfs = sp.diags(idfs, 0) 146 | tfs = cnts.log1p() 147 | tfidfs = idfs.dot(tfs) 148 | return tfidfs 149 | 150 | 151 | def get_doc_freqs(cnts): 152 | """Return word --> # of docs it appears in.""" 153 | binary = (cnts > 0).astype(int) 154 | freqs = np.array(binary.sum(1)).squeeze() 155 | return freqs 156 | 157 | 158 | # ------------------------------------------------------------------------------ 159 | # Main. 160 | # ------------------------------------------------------------------------------ 161 | 162 | 163 | if __name__ == '__main__': 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('db_path', type=str, default=None, 166 | help='Path to sqlite db holding document texts') 167 | parser.add_argument('out_dir', type=str, default=None, 168 | help='Directory for saving output files') 169 | parser.add_argument('--ngram', type=int, default=2, 170 | help=('Use up to N-size n-grams ' 171 | '(e.g. 2 = unigrams + bigrams)')) 172 | parser.add_argument('--hash-size', type=int, default=int(math.pow(2, 24)), 173 | help='Number of buckets to use for hashing ngrams') 174 | parser.add_argument('--tokenizer', type=str, default='simple', 175 | help=("String option specifying tokenizer type to use " 176 | "(e.g. 
'corenlp')")) 177 | parser.add_argument('--num-workers', type=int, default=None, 178 | help='Number of CPU processes (for tokenizing, etc)') 179 | args = parser.parse_args() 180 | 181 | logging.info('Counting words...') 182 | count_matrix, doc_dict = get_count_matrix( 183 | args, 'sqlite', {'db_path': args.db_path} 184 | ) 185 | 186 | logger.info('Making tfidf vectors...') 187 | tfidf = get_tfidf_matrix(count_matrix) 188 | 189 | logger.info('Getting word-doc frequencies...') 190 | freqs = get_doc_freqs(count_matrix) 191 | 192 | basename = os.path.splitext(os.path.basename(args.db_path))[0] 193 | basename += ('-tfidf-ngram=%d-hash=%d-tokenizer=%s' % 194 | (args.ngram, args.hash_size, args.tokenizer)) 195 | filename = os.path.join(args.out_dir, basename) 196 | 197 | logger.info('Saving to %s.npz' % filename) 198 | metadata = { 199 | 'doc_freqs': freqs, 200 | 'tokenizer': args.tokenizer, 201 | 'hash_size': args.hash_size, 202 | 'ngram': args.ngram, 203 | 'doc_dict': doc_dict 204 | } 205 | retriever.utils.save_sparse_csr(filename, tfidf, metadata) 206 | -------------------------------------------------------------------------------- /drqa_retriever/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | DATA_DIR = "/home/sewon/analysis_multiple/DrQA/data" 11 | DEFAULTS = { 12 | 'db_path': os.path.join(DATA_DIR, 'wikipedia/docs.db'), 13 | 'tfidf_path': os.path.join( 14 | DATA_DIR, 15 | 'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz' 16 | ), 17 | } 18 | 19 | 20 | def set_default(key, value): 21 | global DEFAULTS 22 | DEFAULTS[key] = value 23 | 24 | 25 | def get_class(name): 26 | if name == 'tfidf': 27 | return TfidfDocRanker 28 | if name == 'sqlite': 29 | return DocDB 30 | raise RuntimeError('Invalid retriever class: %s' % name) 31 | 32 | 33 | from .doc_db import DocDB 34 | from .tfidf_doc_ranker import TfidfDocRanker 35 | -------------------------------------------------------------------------------- /drqa_retriever/doc_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Documents, in a sqlite database.""" 8 | 9 | import sqlite3 10 | from . import utils 11 | 12 | 13 | class DocDB(object): 14 | """Sqlite backed document storage. 15 | 16 | Implements get_doc_text(doc_id). 
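    Also implements get_doc_ids() for listing all stored document ids, and can
    be used as a context manager (the connection is closed on exit).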
17 | """ 18 | 19 | def __init__(self, db_path=None): 20 | self.path = db_path 21 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 22 | 23 | def __enter__(self): 24 | return self 25 | 26 | def __exit__(self, *args): 27 | self.close() 28 | 29 | def path(self): 30 | """Return the path to the file that backs this database.""" 31 | return self.path 32 | 33 | def close(self): 34 | """Close the connection to the database.""" 35 | self.connection.close() 36 | 37 | def get_doc_ids(self): 38 | """Fetch all ids of docs stored in the db.""" 39 | cursor = self.connection.cursor() 40 | cursor.execute("SELECT id FROM documents") 41 | results = [r[0] for r in cursor.fetchall()] 42 | cursor.close() 43 | return results 44 | 45 | def get_doc_text(self, doc_id): 46 | """Fetch the raw text of the doc for 'doc_id'.""" 47 | cursor = self.connection.cursor() 48 | cursor.execute( 49 | "SELECT text FROM documents WHERE id = ?", 50 | (utils.normalize(doc_id),) 51 | ) 52 | result = cursor.fetchone() 53 | cursor.close() 54 | return result if result is None else result[0] 55 | 56 | 57 | -------------------------------------------------------------------------------- /drqa_retriever/tfidf_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Rank documents with TF-IDF scores""" 8 | 9 | import logging 10 | import numpy as np 11 | import scipy.sparse as sp 12 | 13 | from multiprocessing.pool import ThreadPool 14 | from functools import partial 15 | 16 | from drqa_retriever import utils 17 | from drqa_retriever import DEFAULTS 18 | import drqa_tokenizers as tokenizers 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class TfidfDocRanker(object): 24 | """Loads a pre-weighted inverted index of token/document terms. 25 | Scores new queries by taking sparse dot products. 26 | """ 27 | 28 | def __init__(self, tfidf_path=None, strict=True): 29 | """ 30 | Args: 31 | tfidf_path: path to saved model file 32 | strict: fail on empty queries or continue (and return empty result) 33 | """ 34 | # Load from disk 35 | tfidf_path = tfidf_path or DEFAULTS['tfidf_path'] 36 | logger.info('Loading %s' % tfidf_path) 37 | matrix, metadata = utils.load_sparse_csr(tfidf_path) 38 | self.doc_mat = matrix 39 | self.ngrams = metadata['ngram'] 40 | self.hash_size = metadata['hash_size'] 41 | self.tokenizer = tokenizers.get_class(metadata['tokenizer'])() 42 | self.doc_freqs = metadata['doc_freqs'].squeeze() 43 | self.doc_dict = metadata['doc_dict'] 44 | self.num_docs = len(self.doc_dict[0]) 45 | self.strict = strict 46 | 47 | def get_doc_index(self, doc_id): 48 | """Convert doc_id --> doc_index""" 49 | return self.doc_dict[0][doc_id] 50 | 51 | def get_doc_id(self, doc_index): 52 | """Convert doc_index --> doc_id""" 53 | return self.doc_dict[1][doc_index] 54 | 55 | def closest_docs(self, query, k=1): 56 | """Closest docs by dot product between query and documents 57 | in tfidf weighted word vector space. 
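
        Args:
            query: raw query text.
            k: number of top documents to return.
        Returns:
            (doc_ids, doc_scores) for the k highest-scoring documents.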
58 | """ 59 | spvec = self.text2spvec(query) 60 | res = spvec * self.doc_mat 61 | 62 | if len(res.data) <= k: 63 | o_sort = np.argsort(-res.data) 64 | else: 65 | o = np.argpartition(-res.data, k)[0:k] 66 | o_sort = o[np.argsort(-res.data[o])] 67 | 68 | doc_scores = res.data[o_sort] 69 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] 70 | return doc_ids, doc_scores 71 | 72 | def batch_closest_docs(self, queries, k=1, num_workers=None): 73 | """Process a batch of closest_docs requests multithreaded. 74 | Note: we can use plain threads here as scipy is outside of the GIL. 75 | """ 76 | with ThreadPool(num_workers) as threads: 77 | closest_docs = partial(self.closest_docs, k=k) 78 | results = threads.map(closest_docs, queries) 79 | return results 80 | 81 | def parse(self, query): 82 | """Parse the query into tokens (either ngrams or tokens).""" 83 | tokens = self.tokenizer.tokenize(query) 84 | return tokens.ngrams(n=self.ngrams, uncased=True, 85 | filter_fn=utils.filter_ngram) 86 | 87 | def text2spvec(self, query): 88 | """Create a sparse tfidf-weighted word vector from query. 89 | 90 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 91 | """ 92 | # Get hashed ngrams 93 | words = self.parse(utils.normalize(query)) 94 | wids = [utils.hash(w, self.hash_size) for w in words] 95 | 96 | if len(wids) == 0: 97 | if self.strict: 98 | raise RuntimeError('No valid word in: %s' % query) 99 | else: 100 | logger.warning('No valid word in: %s' % query) 101 | return sp.csr_matrix((1, self.hash_size)) 102 | 103 | # Count TF 104 | wids_unique, wids_counts = np.unique(wids, return_counts=True) 105 | tfs = np.log1p(wids_counts) 106 | 107 | # Count IDF 108 | Ns = self.doc_freqs[wids_unique] 109 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) 110 | idfs[idfs < 0] = 0 111 | 112 | # TF-IDF 113 | data = np.multiply(tfs, idfs) 114 | 115 | # One row, sparse csr matrix 116 | indptr = np.array([0, len(wids_unique)]) 117 | spvec = sp.csr_matrix( 118 | (data, wids_unique, indptr), shape=(1, self.hash_size) 119 | ) 120 | 121 | return spvec 122 | -------------------------------------------------------------------------------- /drqa_retriever/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Various retriever utilities.""" 8 | 9 | import regex 10 | import unicodedata 11 | import numpy as np 12 | import scipy.sparse as sp 13 | from sklearn.utils import murmurhash3_32 14 | 15 | 16 | # ------------------------------------------------------------------------------ 17 | # Sparse matrix saving/loading helpers. 
18 | # ------------------------------------------------------------------------------ 19 | 20 | 21 | def save_sparse_csr(filename, matrix, metadata=None): 22 | data = { 23 | 'data': matrix.data, 24 | 'indices': matrix.indices, 25 | 'indptr': matrix.indptr, 26 | 'shape': matrix.shape, 27 | 'metadata': metadata, 28 | } 29 | np.savez(filename, **data) 30 | 31 | 32 | def load_sparse_csr(filename): 33 | loader = np.load(filename, allow_pickle=True) 34 | matrix = sp.csr_matrix((loader['data'], loader['indices'], 35 | loader['indptr']), shape=loader['shape']) 36 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None 37 | 38 | 39 | # ------------------------------------------------------------------------------ 40 | # Token hashing. 41 | # ------------------------------------------------------------------------------ 42 | 43 | 44 | def hash(token, num_buckets): 45 | """Unsigned 32 bit murmurhash for feature hashing.""" 46 | return murmurhash3_32(token, positive=True) % num_buckets 47 | 48 | 49 | # ------------------------------------------------------------------------------ 50 | # Text cleaning. 51 | # ------------------------------------------------------------------------------ 52 | 53 | 54 | STOPWORDS = { 55 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 56 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 57 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 58 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 59 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 60 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 61 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 62 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 63 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 64 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 65 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 66 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 67 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 68 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 69 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 70 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 71 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 72 | } 73 | 74 | 75 | def normalize(text): 76 | """Resolve different type of unicode encodings.""" 77 | return unicodedata.normalize('NFD', text) 78 | 79 | 80 | def filter_word(text): 81 | """Take out english stopwords, punctuation, and compound endings.""" 82 | text = normalize(text) 83 | if regex.match(r'^\p{P}+$', text): 84 | return True 85 | if text.lower() in STOPWORDS: 86 | return True 87 | return False 88 | 89 | 90 | def filter_ngram(gram, mode='any'): 91 | """Decide whether to keep or discard an n-gram. 
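    Returns True if the n-gram should be thrown out, False to keep it.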
92 | 93 | Args: 94 | gram: list of tokens (length N) 95 | mode: Option to throw out ngram if 96 | 'any': any single token passes filter_word 97 | 'all': all tokens pass filter_word 98 | 'ends': book-ended by filterable tokens 99 | """ 100 | filtered = [filter_word(w) for w in gram] 101 | if mode == 'any': 102 | return any(filtered) 103 | elif mode == 'all': 104 | return all(filtered) 105 | elif mode == 'ends': 106 | return filtered[0] or filtered[-1] 107 | else: 108 | raise ValueError('Invalid mode: %s' % mode) 109 | -------------------------------------------------------------------------------- /drqa_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | DEFAULTS = { 11 | 'corenlp_classpath': os.getenv('CLASSPATH') 12 | } 13 | 14 | 15 | def set_default(key, value): 16 | global DEFAULTS 17 | DEFAULTS[key] = value 18 | 19 | 20 | from .corenlp_tokenizer import CoreNLPTokenizer 21 | from .regexp_tokenizer import RegexpTokenizer 22 | from .simple_tokenizer import SimpleTokenizer 23 | 24 | # Spacy is optional 25 | try: 26 | from .spacy_tokenizer import SpacyTokenizer 27 | except ImportError: 28 | pass 29 | 30 | 31 | def get_class(name): 32 | if name == 'spacy': 33 | return SpacyTokenizer 34 | if name == 'corenlp': 35 | return CoreNLPTokenizer 36 | if name == 'regexp': 37 | return RegexpTokenizer 38 | if name == 'simple': 39 | return SimpleTokenizer 40 | 41 | raise RuntimeError('Invalid tokenizer: %s' % name) 42 | 43 | 44 | def get_annotators_for_args(args): 45 | annotators = set() 46 | if args.use_pos: 47 | annotators.add('pos') 48 | if args.use_lemma: 49 | annotators.add('lemma') 50 | if args.use_ner: 51 | annotators.add('ner') 52 | return annotators 53 | 54 | 55 | def get_annotators_for_model(model): 56 | return get_annotators_for_args(model.args) 57 | -------------------------------------------------------------------------------- /drqa_tokenizers/corenlp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Simple wrapper around the Stanford CoreNLP pipeline. 8 | 9 | Serves commands to a java subprocess running the jar. Requires java 8. 10 | """ 11 | 12 | import copy 13 | import json 14 | import pexpect 15 | 16 | from .tokenizer import Tokens, Tokenizer 17 | from . import DEFAULTS 18 | 19 | 20 | class CoreNLPTokenizer(Tokenizer): 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: set that can include pos, lemma, and ner. 
26 | classpath: Path to the corenlp directory of jars 27 | mem: Java heap memory 28 | """ 29 | self.classpath = (kwargs.get('classpath') or 30 | DEFAULTS['corenlp_classpath']) 31 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 32 | self.mem = kwargs.get('mem', '2g') 33 | self._launch() 34 | 35 | def _launch(self): 36 | """Start the CoreNLP jar with pexpect.""" 37 | annotators = ['tokenize', 'ssplit'] 38 | if 'ner' in self.annotators: 39 | annotators.extend(['pos', 'lemma', 'ner']) 40 | elif 'lemma' in self.annotators: 41 | annotators.extend(['pos', 'lemma']) 42 | elif 'pos' in self.annotators: 43 | annotators.extend(['pos']) 44 | annotators = ','.join(annotators) 45 | options = ','.join(['untokenizable=noneDelete', 46 | 'invertible=true']) 47 | cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath, 48 | 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', 49 | annotators, '-tokenize.options', options, 50 | '-outputFormat', 'json', '-prettyPrint', 'false'] 51 | 52 | # We use pexpect to keep the subprocess alive and feed it commands. 53 | # Because we don't want to get hit by the max terminal buffer size, 54 | # we turn off canonical input processing to have unlimited bytes. 55 | self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60) 56 | self.corenlp.setecho(False) 57 | self.corenlp.sendline('stty -icanon') 58 | self.corenlp.sendline(' '.join(cmd)) 59 | self.corenlp.delaybeforesend = 0 60 | self.corenlp.delayafterread = 0 61 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 62 | 63 | @staticmethod 64 | def _convert(token): 65 | if token == '-LRB-': 66 | return '(' 67 | if token == '-RRB-': 68 | return ')' 69 | if token == '-LSB-': 70 | return '[' 71 | if token == '-RSB-': 72 | return ']' 73 | if token == '-LCB-': 74 | return '{' 75 | if token == '-RCB-': 76 | return '}' 77 | return token 78 | 79 | def tokenize(self, text): 80 | # Since we're feeding text to the commandline, we're waiting on seeing 81 | # the NLP> prompt. Hacky! 82 | if 'NLP>' in text: 83 | raise RuntimeError('Bad token (NLP>) in text!') 84 | 85 | # Sending q will cause the process to quit -- manually override 86 | if text.lower().strip() == 'q': 87 | token = text.strip() 88 | index = text.index(token) 89 | data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')] 90 | return Tokens(data, self.annotators) 91 | 92 | # Minor cleanup before tokenizing. 
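        # Newlines are replaced with spaces because the text is piped to the
        # CoreNLP prompt over a pseudo-terminal, where an embedded newline
        # would terminate the input early.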
93 | clean_text = text.replace('\n', ' ') 94 | 95 | self.corenlp.sendline(clean_text.encode('utf-8')) 96 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 97 | 98 | # Skip to start of output (may have been stderr logging messages) 99 | output = self.corenlp.before 100 | start = output.find(b'{"sentences":') 101 | output = json.loads(output[start:].decode('utf-8')) 102 | 103 | data = [] 104 | tokens = [t for s in output['sentences'] for t in s['tokens']] 105 | for i in range(len(tokens)): 106 | # Get whitespace 107 | start_ws = tokens[i]['characterOffsetBegin'] 108 | if i + 1 < len(tokens): 109 | end_ws = tokens[i + 1]['characterOffsetBegin'] 110 | else: 111 | end_ws = tokens[i]['characterOffsetEnd'] 112 | 113 | data.append(( 114 | self._convert(tokens[i]['word']), 115 | text[start_ws: end_ws], 116 | (tokens[i]['characterOffsetBegin'], 117 | tokens[i]['characterOffsetEnd']), 118 | tokens[i].get('pos', None), 119 | tokens[i].get('lemma', None), 120 | tokens[i].get('ner', None) 121 | )) 122 | return Tokens(data, self.annotators) 123 | -------------------------------------------------------------------------------- /drqa_tokenizers/regexp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers. 8 | 9 | However it is purely in Python, supports robust untokenization, unicode, 10 | and requires minimal dependencies. 11 | """ 12 | 13 | import regex 14 | import logging 15 | from .tokenizer import Tokens, Tokenizer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class RegexpTokenizer(Tokenizer): 21 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*' 22 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)' 23 | r'\.(?=\p{Z})') 24 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)' 25 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++' 26 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM) 27 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't" 28 | CONTRACTION1 = r"can(?=not\b)" 29 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b" 30 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})' 31 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})' 32 | END_DQUOTE = r'(?%s)|(?P