├── LICENSE ├── README.md ├── build_db.py ├── build_tfidf.py ├── drqa_retriever ├── __init__.py ├── doc_db.py ├── tfidf_doc_ranker.py └── utils.py ├── drqa_tokenizers ├── __init__.py ├── corenlp_tokenizer.py ├── regexp_tokenizer.py ├── simple_tokenizer.py ├── spacy_tokenizer.py └── tokenizer.py ├── filter_subset_wiki.py ├── inference_tfidf.py ├── requirements.txt ├── run.sh └── run_inference.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For DrQA software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README for retrieval-based baselines 2 | 3 | This repo provides guidelines for training and testing retrieval-based baselines for the [NeurIPS Competition on Efficient Open-domain Question Answering](http://efficientqa.github.io/). 4 | 5 | We provide two retrieval-based baselines: 6 | 7 | - TF-IDF: TF-IDF retrieval built on fixed-length passages, adapted from the [DrQA system's implementation](https://github.com/facebookresearch/DrQA). 8 | - DPR: A learned dense passage retriever, detailed in [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) (Karpukhin et al., 2020). Our baseline is adapted from the [original implementation](https://github.com/facebookresearch/DPR). 9 | 10 | 11 | Note that for both baselines, we use text blocks of 100 words as passages and a BERT-base multi-passage reader. See more details in the [DPR paper](https://arxiv.org/pdf/2004.04906.pdf). 12 | 13 | We provide two variants of each model, using (1) full Wikipedia (`full`) and (2) a subset of Wikipedia articles that are relevant to the questions in the training data (`subset`). In particular, we think of `subset` as a naive way to reduce the disk usage of the retrieval-based baselines.
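Both baselines retrieve over the same unit: Wikipedia text split into 100-word blocks, stored as `(id, text, title)` rows in a TSV file (the `psgs_w100.tsv` layout that `build_db.py` and `filter_subset_wiki.py` read). The sketch below only illustrates that format; the helper names are made up for illustration and this is not the script that produced the released DB.

```python
# Illustrative sketch of the 100-word passage TSV format; not part of this repo or DPR.
import csv

def chunk_text(text, words_per_passage=100):
    """Split raw article text into consecutive blocks of ~100 words."""
    words = text.split()
    for i in range(0, len(words), words_per_passage):
        yield " ".join(words[i:i + words_per_passage])

def write_passage_tsv(articles, out_path):
    """Write passages in the (id, text, title) layout expected by build_db.py."""
    with open(out_path, "w", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["id", "text", "title"])
        passage_id = 1
        for title, text in articles:
            for passage in chunk_text(text):
                writer.writerow([passage_id, passage, title])
                passage_id += 1

write_passage_tsv([("Example article", "some long article text ...")], "toy_psgs_w100.tsv")
```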
14 | 15 | 16 | *Note: If you want to try parameter-only baselines (T5-based) for the competition, please note that an implementation of the T5-based closed-book QA model is available [here](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa).* 17 | 18 | *Note: If you want simple guidelines on making end-to-end QA predictions using pretrained models, please refer to [this tutorial](https://github.com/efficientqa/efficientqa.github.io/blob/master/getting_started.md).* 19 | 20 | ## Content 21 | 22 | 1. [Getting ready](#getting-ready) 23 | 2. [TFIDF retrieval](#tfidf-retrieval) 24 | 3. [DPR retrieval](#dpr-retrieval) 25 | 4. [DPR reader](#dpr-reader) 26 | 5. [Result](#result) 27 | 28 | ## Getting ready 29 | 30 | ### Git clone 31 | 32 | ```bash 33 | git clone https://github.com/facebookresearch/DPR.git # dependency 34 | git clone https://github.com/efficientqa/retrieval-based-baselines.git # this repo 35 | ``` 36 | 37 | ### Download data 38 | 39 | Follow the [DPR repo][dpr] to download the NQ data and the Wikipedia DB. Specifically, after running `cd DPR`, let `base_dir` be your base directory for storing data and pretrained models, and then: 40 | 41 | 42 | 1. Download the QA pairs with `python3 data/download_data.py --resource data.retriever.qas --output_dir ${base_dir}` and `python3 data/download_data.py --resource data.retriever.nq --output_dir ${base_dir}`. 43 | 2. Download the Wikipedia DB with `python3 data/download_data.py --resource data.wikipedia_split --output_dir ${base_dir}`. 44 | 3. Download the gold question-passage pairs with `python3 data/download_data.py --resource data.gold_passages_info --output_dir ${base_dir}`. 45 | 46 | Optionally, if you want to try the `subset` variant, run `cd ../retrieval-based-baselines; python3 filter_subset_wiki.py --db_path ${base_dir}/data/wikipedia_split/psgs_w100.tsv --data_path ${base_dir}/data/retriever/nq-train.json`. This script creates a new passage DB containing only passages whose source articles are paired with a question in the original NQ training data (78,050 unique articles; 1,642,855 unique passages). 47 | This new DB will be stored at `${base_dir}/data/wikipedia_split/psgs_w100_subset.tsv`. 48 | 49 | From now on, we will denote the Wikipedia DB (either full or subset) as `db_path`. 50 | 51 | 52 | ## TFIDF retrieval 53 | 54 | Make sure you are in the `retrieval-based-baselines` directory when running the TFIDF scripts (largely adapted from the [DrQA repo][drqa]). 55 | 56 | **Step 1**: Run `pip install -r requirements.txt`. 57 | 58 | **Step 2**: Build the SQLite DB: 59 | ``` 60 | mkdir -p ${base_dir}/tfidf 61 | python3 build_db.py ${db_path} ${base_dir}/tfidf/db.db --num-workers 60 62 | ``` 63 | **Step 3**: Run the following command to build the TFIDF index: 64 | ``` 65 | python3 build_tfidf.py ${base_dir}/tfidf/db.db ${base_dir}/tfidf 66 | ``` 67 | It will save the TF-IDF index in `${base_dir}/tfidf`. 68 | 69 | **Step 4**: Run the inference code to save the retrieval results: 70 | ``` 71 | python3 inference_tfidf.py --qa_file ${base_dir}/data/retriever/qas/nq-{train|dev|test}.csv --db_path ${db_path} --out_file ${base_dir}/tfidf/nq-{train|dev|test}.json --tfidf_path {path_to_tfidf_index} 72 | ``` 73 | 74 | The resulting files, `${base_dir}/tfidf/nq-{train|dev|test}.json`, are ready to be fed into the DPR reader. 75 | 76 | ## DPR retrieval 77 | 78 | Follow the [DPR repo][dpr] to train the DPR retriever and run inference. You can follow the steps up to [Retriever validation](https://github.com/facebookresearch/DPR/tree/master#retriever-validation-against-the-entire-set-of-documents).
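At a high level, the dense retrieval step encodes each question with DPR's BERT-based question encoder and runs a maximum inner product search against the precomputed passage vectors. The toy sketch below shows only that search mechanic using FAISS on random vectors; it is not the actual DPR code (which `generate_dense_embeddings.py` and `dense_retriever.py` wrap), and the array names are made up for illustration.

```python
# Toy illustration of maximum inner product search over passage embeddings.
# Assumes numpy and faiss are installed; in the real pipeline the vectors come
# from generate_dense_embeddings.py and the search is done by dense_retriever.py.
import numpy as np
import faiss

dim = 768                                                   # BERT-base hidden size
passage_vecs = np.random.rand(1000, dim).astype("float32")  # stand-in for the dpr_ctx shards
question_vec = np.random.rand(1, dim).astype("float32")     # stand-in for an encoded question

index = faiss.IndexFlatIP(dim)                              # exact inner-product index
index.add(passage_vecs)
scores, passage_ids = index.search(question_vec, 200)       # top 200 passages, as with --n-docs 200
print(passage_ids[0][:5], scores[0][:5])
```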
79 | 80 | 81 | If you want to use the retriever checkpoint provided by DPR, follow these three steps. 82 | 83 | **Step 1**: Make sure you are in the `DPR` directory, and download the retriever checkpoint with `python3 data/download_data.py --resource checkpoint.retriever.multiset.bert-base-encoder --output_dir ${base_dir}`. 84 | 85 | **Step 2**: Save the passage vectors by following [Generating representations](https://github.com/facebookresearch/DPR/tree/master#retriever-validation-against-the-entire-set-of-documents). Note that you can replace `ctx_file` with your own `db_path` if you are trying the `subset` ("seen only") version. In particular, you can run: 86 | ``` 87 | python3 generate_dense_embeddings.py --model_file ${base_dir}/checkpoint/retriever/multiset/bert-base-encoder.cp --ctx_file ${db_path} --shard_id {0-19} --num_shards 20 --out_file ${base_dir}/dpr_ctx 88 | ``` 89 | 90 | **Step 3**: Save the retrieval results by following [Retriever validation](https://github.com/facebookresearch/DPR/tree/master#retriever-validation-against-the-entire-set-of-documents), using the same retriever checkpoint as in Step 2. In particular, you can run: 91 | ``` 92 | mkdir -p ${base_dir}/dpr_retrieval 93 | python3 dense_retriever.py \ 94 | --model_file ${base_dir}/checkpoint/retriever/multiset/bert-base-encoder.cp \ 95 | --ctx_file ${db_path} \ 96 | --qa_file ${base_dir}/data/retriever/qas/nq-{train|dev|test}.csv \ 97 | --encoded_ctx_file ${base_dir}/'dpr_ctx*' \ 98 | --out_file ${base_dir}/dpr_retrieval/nq-{train|dev|test}.json \ 99 | --n-docs 200 \ 100 | --save_or_load_index # saves the dense index the first time it is built, and loads it on subsequent runs. 101 | ``` 102 | 103 | Now, `${base_dir}/dpr_retrieval/nq-{train|dev|test}.json` is ready to be fed into the DPR reader. 104 | 105 | ## DPR reader 106 | 107 | *Note*: The following instructions are identical to those in the [DPR README](https://github.com/facebookresearch/DPR#optional-reader-model-input-data-pre-processing), but rewritten with the hyperparameters specified for our baselines. 108 | 109 | The following instructions are for training the reader using the TFIDF results, saved in `${base_dir}/tfidf/nq-{train|dev|test}.json`. In order to use the DPR retrieval results, simply replace the paths to these files with `${base_dir}/dpr_retrieval/nq-{train|dev|test}.json`. 110 | 111 | **Step 1**: Preprocess the data. 112 | 113 | ``` 114 | python3 preprocess_reader_data.py \ 115 | --retriever_results ${base_dir}/tfidf/nq-{train|dev|test}.json \ 116 | --gold_passages ${base_dir}/data/gold_passages_info/nq_{train|dev|test}.json \ 117 | --do_lower_case \ 118 | --pretrained_model_cfg bert-base-uncased \ 119 | --encoder_model_type hf_bert \ 120 | --out_file ${base_dir}/tfidf/nq-{train|dev|test}-tfidf \ 121 | --is_train_set # specify this only for the train data 122 | ``` 123 | 124 | **Step 2**: Train the reader. 125 | ``` 126 | python3 train_reader.py \ 127 | --encoder_model_type hf_bert \ 128 | --pretrained_model_cfg bert-base-uncased \ 129 | --train_file ${base_dir}/tfidf/'nq-train*.pkl' \ 130 | --dev_file ${base_dir}/tfidf/'nq-dev*.pkl' \ 131 | --output_dir ${base_dir}/checkpoints/reader_from_tfidf \ 132 | --seed 42 \ 133 | --learning_rate 1e-5 \ 134 | --eval_step 2000 \ 135 | --eval_top_docs 50 \ 136 | --warmup_steps 0 \ 137 | --sequence_length 350 \ 138 | --batch_size 16 \ 139 | --passages_per_question 24 \ 140 | --num_train_epochs 100000 \ 141 | --dev_batch_size 72 \ 142 | --passages_per_question_predict 50 143 | ``` 144 | 145 | **Step 3**: Test the reader.
146 | ``` 147 | python3 train_reader.py \ 148 | --prediction_results_file ${base_dir}/checkpoints/reader_from_tfidf/dev_predictions.json \ 149 | --eval_top_docs 10 20 40 50 80 100 \ 150 | --dev_file ${base_dir}/tfidf/'nq-dev*.pkl' \ 151 | --model_file ${base_dir}/checkpoints/reader_from_tfidf/{checkpoint file} \ 152 | --dev_batch_size 80 \ 153 | --passages_per_question_predict 100 \ 154 | --sequence_length 350 155 | ``` 156 | 157 | [drqa]: https://github.com/facebookresearch/DrQA/ 158 | [dpr]: https://github.com/facebookresearch/DPR 159 | 160 | ## Result 161 | 162 | |Model|Exact Match|Disk usage (GB)| 163 | |---|---|---| 164 | |TFIDF-full|32.0|20.1| 165 | |TFIDF-subset|31.0|2.8| 166 | |DPR-full|41.0|66.4| 167 | |DPR-subset|34.8|5.9| 168 | 169 | 170 | -------------------------------------------------------------------------------- /build_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to read in and store documents in a sqlite database.""" 8 | 9 | import argparse 10 | import sqlite3 11 | import json 12 | import os 13 | import logging 14 | import importlib.util 15 | 16 | from multiprocessing import Pool as ProcessPool 17 | from tqdm import tqdm 18 | from drqa_retriever import utils 19 | 20 | logger = logging.getLogger() 21 | logger.setLevel(logging.INFO) 22 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 23 | console = logging.StreamHandler() 24 | console.setFormatter(fmt) 25 | logger.addHandler(console) 26 | 27 | 28 | # ------------------------------------------------------------------------------ 29 | # Import helper 30 | # ------------------------------------------------------------------------------ 31 | 32 | 33 | PREPROCESS_FN = None 34 | 35 | 36 | def init(filename): 37 | global PREPROCESS_FN 38 | if filename: 39 | PREPROCESS_FN = import_module(filename).preprocess 40 | 41 | 42 | def import_module(filename): 43 | """Import a module given a full path to the file.""" 44 | spec = importlib.util.spec_from_file_location('doc_filter', filename) 45 | module = importlib.util.module_from_spec(spec) 46 | spec.loader.exec_module(module) 47 | return module 48 | 49 | 50 | # ------------------------------------------------------------------------------ 51 | # Store corpus. 52 | # ------------------------------------------------------------------------------ 53 | 54 | 55 | def iter_files(path): 56 | """Walk through all files located under a root path.""" 57 | if os.path.isfile(path): 58 | yield path 59 | elif os.path.isdir(path): 60 | for dirpath, _, filenames in os.walk(path): 61 | for f in filenames: 62 | yield os.path.join(dirpath, f) 63 | else: 64 | raise RuntimeError('Path %s is invalid' % path) 65 | 66 | 67 | def get_contents(filename): 68 | """Parse the contents of a file.
Each line is a JSON encoded document.""" 69 | global PREPROCESS_FN 70 | documents = [] 71 | if filename.endswith(".tsv.gz"): 72 | raise NotImplementedError("TODO") 73 | elif filename.endswith(".tsv"): 74 | import csv 75 | with open(filename) as tsvfile: 76 | reader = csv.reader(tsvfile, delimiter='\t') 77 | # file format: doc_id, doc_text, title 78 | for (doc_id, doc_text, title) in reader: 79 | if doc_id=="id": continue 80 | documents.append((doc_id, utils.normalize(title) + " " + doc_text)) 81 | else: 82 | with open(filename) as f: 83 | for line in f: 84 | # Parse document 85 | doc = json.loads(line) 86 | # Maybe preprocess the document with custom function 87 | if PREPROCESS_FN: 88 | doc = PREPROCESS_FN(doc) 89 | # Skip if it is empty or None 90 | if not doc: 91 | continue 92 | # Add the document 93 | documents.append((utils.normalize(doc['id']), doc['text'])) 94 | return documents 95 | 96 | def store_contents(data_path, save_path, preprocess, num_workers=None): 97 | """Preprocess and store a corpus of documents in sqlite. 98 | 99 | Args: 100 | data_path: Root path to directory (or directory of directories) of files 101 | containing json encoded documents (must have `id` and `text` fields). 102 | save_path: Path to output sqlite db. 103 | preprocess: Path to file defining a custom `preprocess` function. Takes 104 | in and outputs a structured doc. 105 | num_workers: Number of parallel processes to use when reading docs. 106 | """ 107 | if os.path.isfile(save_path): 108 | raise RuntimeError('%s already exists! Not overwriting.' % save_path) 109 | 110 | logger.info('Reading into database...') 111 | conn = sqlite3.connect(save_path) 112 | c = conn.cursor() 113 | c.execute("CREATE TABLE documents (id PRIMARY KEY, text);") 114 | 115 | if num_workers is None or num_workers==1: 116 | files = [f for f in iter_files(data_path)] 117 | count = 0 118 | with tqdm(total=len(files)) as pbar: 119 | for f in files: 120 | pairs = get_contents(f) 121 | count += len(pairs) 122 | c.executemany("INSERT INTO documents VALUES (?,?)", pairs) 123 | pbar.update() 124 | else: 125 | workers = ProcessPool(num_workers, initializer=init, initargs=(preprocess,)) 126 | files = [f for f in iter_files(data_path)] 127 | count = 0 128 | with tqdm(total=len(files)) as pbar: 129 | for pairs in tqdm(workers.imap_unordered(get_contents, files)): 130 | count += len(pairs) 131 | c.executemany("INSERT INTO documents VALUES (?,?)", pairs) 132 | pbar.update() 133 | logger.info('Read %d docs.' % count) 134 | logger.info('Committing...') 135 | conn.commit() 136 | conn.close() 137 | 138 | 139 | # ------------------------------------------------------------------------------ 140 | # Main. 
141 | # ------------------------------------------------------------------------------ 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('data_path', type=str, help='/path/to/data') 147 | parser.add_argument('save_path', type=str, help='/path/to/saved/db.db') 148 | parser.add_argument('--preprocess', type=str, default=None, 149 | help=('File path to a python module that defines ' 150 | 'a `preprocess` function')) 151 | parser.add_argument('--num-workers', type=int, default=None, 152 | help='Number of CPU processes (for tokenizing, etc)') 153 | args = parser.parse_args() 154 | 155 | store_contents( 156 | args.data_path, args.save_path, args.preprocess, args.num_workers 157 | ) 158 | -------------------------------------------------------------------------------- /build_tfidf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to build the tf-idf document matrices for retrieval.""" 8 | 9 | import numpy as np 10 | import scipy.sparse as sp 11 | import argparse 12 | import os 13 | import math 14 | import logging 15 | 16 | from multiprocessing import Pool as ProcessPool 17 | from multiprocessing.util import Finalize 18 | from functools import partial 19 | from collections import Counter 20 | 21 | import drqa_retriever as retriever 22 | import drqa_tokenizers as tokenizers 23 | 24 | logger = logging.getLogger() 25 | logger.setLevel(logging.INFO) 26 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 27 | console = logging.StreamHandler() 28 | console.setFormatter(fmt) 29 | logger.addHandler(console) 30 | 31 | 32 | # ------------------------------------------------------------------------------ 33 | # Multiprocessing functions 34 | # ------------------------------------------------------------------------------ 35 | 36 | DOC2IDX = None 37 | PROCESS_TOK = None 38 | PROCESS_DB = None 39 | 40 | 41 | def init(tokenizer_class, db_class, db_opts): 42 | global PROCESS_TOK, PROCESS_DB 43 | PROCESS_TOK = tokenizer_class() 44 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 45 | PROCESS_DB = db_class(**db_opts) 46 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) 47 | 48 | 49 | def fetch_text(doc_id): 50 | global PROCESS_DB 51 | return PROCESS_DB.get_doc_text(doc_id) 52 | 53 | 54 | def tokenize(text): 55 | global PROCESS_TOK 56 | return PROCESS_TOK.tokenize(text) 57 | 58 | 59 | # ------------------------------------------------------------------------------ 60 | # Build article --> word count sparse matrix. 61 | # ------------------------------------------------------------------------------ 62 | 63 | 64 | def count(ngram, hash_size, doc_id): 65 | """Fetch the text of a document and compute hashed ngrams counts.""" 66 | global DOC2IDX 67 | row, col, data = [], [], [] 68 | # Tokenize 69 | tokens = tokenize(retriever.utils.normalize(fetch_text(doc_id))) 70 | 71 | # Get ngrams from tokens, with stopword/punctuation filtering. 72 | ngrams = tokens.ngrams( 73 | n=ngram, uncased=True, filter_fn=retriever.utils.filter_ngram 74 | ) 75 | 76 | # Hash ngrams and count occurences 77 | counts = Counter([retriever.utils.hash(gram, hash_size) for gram in ngrams]) 78 | 79 | # Return in sparse matrix data format. 
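# (row = hashed n-gram id, col = this document's index from DOC2IDX, data = count)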
80 | row.extend(counts.keys()) 81 | col.extend([DOC2IDX[doc_id]] * len(counts)) 82 | data.extend(counts.values()) 83 | return row, col, data 84 | 85 | 86 | def get_count_matrix(args, db, db_opts): 87 | """Form a sparse word to document count matrix (inverted index). 88 | 89 | M[i, j] = # times word i appears in document j. 90 | """ 91 | # Map doc_ids to indexes 92 | global DOC2IDX 93 | db_class = retriever.get_class(db) 94 | with db_class(**db_opts) as doc_db: 95 | doc_ids = doc_db.get_doc_ids() 96 | DOC2IDX = {doc_id: i for i, doc_id in enumerate(doc_ids)} 97 | 98 | # Setup worker pool 99 | tok_class = tokenizers.get_class(args.tokenizer) 100 | workers = ProcessPool( 101 | args.num_workers, 102 | initializer=init, 103 | initargs=(tok_class, db_class, db_opts) 104 | ) 105 | 106 | # Compute the count matrix in steps (to keep in memory) 107 | logger.info('Mapping...') 108 | row, col, data = [], [], [] 109 | step = max(int(len(doc_ids) / 10), 1) 110 | batches = [doc_ids[i:i + step] for i in range(0, len(doc_ids), step)] 111 | _count = partial(count, args.ngram, args.hash_size) 112 | for i, batch in enumerate(batches): 113 | logger.info('-' * 25 + 'Batch %d/%d' % (i + 1, len(batches)) + '-' * 25) 114 | for b_row, b_col, b_data in workers.imap_unordered(_count, batch): 115 | row.extend(b_row) 116 | col.extend(b_col) 117 | data.extend(b_data) 118 | workers.close() 119 | workers.join() 120 | 121 | logger.info('Creating sparse matrix...') 122 | count_matrix = sp.csr_matrix( 123 | (data, (row, col)), shape=(args.hash_size, len(doc_ids)) 124 | ) 125 | count_matrix.sum_duplicates() 126 | return count_matrix, (DOC2IDX, doc_ids) 127 | 128 | 129 | # ------------------------------------------------------------------------------ 130 | # Transform count matrix to different forms. 131 | # ------------------------------------------------------------------------------ 132 | 133 | 134 | def get_tfidf_matrix(cnts): 135 | """Convert the word count matrix into tfidf one. 136 | 137 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 138 | * tf = term frequency in document 139 | * N = number of documents 140 | * Nt = number of occurences of term in all documents 141 | """ 142 | Ns = get_doc_freqs(cnts) 143 | idfs = np.log((cnts.shape[1] - Ns + 0.5) / (Ns + 0.5)) 144 | idfs[idfs < 0] = 0 145 | idfs = sp.diags(idfs, 0) 146 | tfs = cnts.log1p() 147 | tfidfs = idfs.dot(tfs) 148 | return tfidfs 149 | 150 | 151 | def get_doc_freqs(cnts): 152 | """Return word --> # of docs it appears in.""" 153 | binary = (cnts > 0).astype(int) 154 | freqs = np.array(binary.sum(1)).squeeze() 155 | return freqs 156 | 157 | 158 | # ------------------------------------------------------------------------------ 159 | # Main. 160 | # ------------------------------------------------------------------------------ 161 | 162 | 163 | if __name__ == '__main__': 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('db_path', type=str, default=None, 166 | help='Path to sqlite db holding document texts') 167 | parser.add_argument('out_dir', type=str, default=None, 168 | help='Directory for saving output files') 169 | parser.add_argument('--ngram', type=int, default=2, 170 | help=('Use up to N-size n-grams ' 171 | '(e.g. 2 = unigrams + bigrams)')) 172 | parser.add_argument('--hash-size', type=int, default=int(math.pow(2, 24)), 173 | help='Number of buckets to use for hashing ngrams') 174 | parser.add_argument('--tokenizer', type=str, default='simple', 175 | help=("String option specifying tokenizer type to use " 176 | "(e.g. 
'corenlp')")) 177 | parser.add_argument('--num-workers', type=int, default=None, 178 | help='Number of CPU processes (for tokenizing, etc)') 179 | args = parser.parse_args() 180 | 181 | logging.info('Counting words...') 182 | count_matrix, doc_dict = get_count_matrix( 183 | args, 'sqlite', {'db_path': args.db_path} 184 | ) 185 | 186 | logger.info('Making tfidf vectors...') 187 | tfidf = get_tfidf_matrix(count_matrix) 188 | 189 | logger.info('Getting word-doc frequencies...') 190 | freqs = get_doc_freqs(count_matrix) 191 | 192 | basename = os.path.splitext(os.path.basename(args.db_path))[0] 193 | basename += ('-tfidf-ngram=%d-hash=%d-tokenizer=%s' % 194 | (args.ngram, args.hash_size, args.tokenizer)) 195 | filename = os.path.join(args.out_dir, basename) 196 | 197 | logger.info('Saving to %s.npz' % filename) 198 | metadata = { 199 | 'doc_freqs': freqs, 200 | 'tokenizer': args.tokenizer, 201 | 'hash_size': args.hash_size, 202 | 'ngram': args.ngram, 203 | 'doc_dict': doc_dict 204 | } 205 | retriever.utils.save_sparse_csr(filename, tfidf, metadata) 206 | -------------------------------------------------------------------------------- /drqa_retriever/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | DATA_DIR = "/home/sewon/analysis_multiple/DrQA/data" 11 | DEFAULTS = { 12 | 'db_path': os.path.join(DATA_DIR, 'wikipedia/docs.db'), 13 | 'tfidf_path': os.path.join( 14 | DATA_DIR, 15 | 'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz' 16 | ), 17 | } 18 | 19 | 20 | def set_default(key, value): 21 | global DEFAULTS 22 | DEFAULTS[key] = value 23 | 24 | 25 | def get_class(name): 26 | if name == 'tfidf': 27 | return TfidfDocRanker 28 | if name == 'sqlite': 29 | return DocDB 30 | raise RuntimeError('Invalid retriever class: %s' % name) 31 | 32 | 33 | from .doc_db import DocDB 34 | from .tfidf_doc_ranker import TfidfDocRanker 35 | -------------------------------------------------------------------------------- /drqa_retriever/doc_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Documents, in a sqlite database.""" 8 | 9 | import sqlite3 10 | from . import utils 11 | 12 | 13 | class DocDB(object): 14 | """Sqlite backed document storage. 15 | 16 | Implements get_doc_text(doc_id). 
17 | """ 18 | 19 | def __init__(self, db_path=None): 20 | self.path = db_path 21 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 22 | 23 | def __enter__(self): 24 | return self 25 | 26 | def __exit__(self, *args): 27 | self.close() 28 | 29 | def path(self): 30 | """Return the path to the file that backs this database.""" 31 | return self.path 32 | 33 | def close(self): 34 | """Close the connection to the database.""" 35 | self.connection.close() 36 | 37 | def get_doc_ids(self): 38 | """Fetch all ids of docs stored in the db.""" 39 | cursor = self.connection.cursor() 40 | cursor.execute("SELECT id FROM documents") 41 | results = [r[0] for r in cursor.fetchall()] 42 | cursor.close() 43 | return results 44 | 45 | def get_doc_text(self, doc_id): 46 | """Fetch the raw text of the doc for 'doc_id'.""" 47 | cursor = self.connection.cursor() 48 | cursor.execute( 49 | "SELECT text FROM documents WHERE id = ?", 50 | (utils.normalize(doc_id),) 51 | ) 52 | result = cursor.fetchone() 53 | cursor.close() 54 | return result if result is None else result[0] 55 | 56 | 57 | -------------------------------------------------------------------------------- /drqa_retriever/tfidf_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Rank documents with TF-IDF scores""" 8 | 9 | import logging 10 | import numpy as np 11 | import scipy.sparse as sp 12 | 13 | from multiprocessing.pool import ThreadPool 14 | from functools import partial 15 | 16 | from drqa_retriever import utils 17 | from drqa_retriever import DEFAULTS 18 | import drqa_tokenizers as tokenizers 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class TfidfDocRanker(object): 24 | """Loads a pre-weighted inverted index of token/document terms. 25 | Scores new queries by taking sparse dot products. 26 | """ 27 | 28 | def __init__(self, tfidf_path=None, strict=True): 29 | """ 30 | Args: 31 | tfidf_path: path to saved model file 32 | strict: fail on empty queries or continue (and return empty result) 33 | """ 34 | # Load from disk 35 | tfidf_path = tfidf_path or DEFAULTS['tfidf_path'] 36 | logger.info('Loading %s' % tfidf_path) 37 | matrix, metadata = utils.load_sparse_csr(tfidf_path) 38 | self.doc_mat = matrix 39 | self.ngrams = metadata['ngram'] 40 | self.hash_size = metadata['hash_size'] 41 | self.tokenizer = tokenizers.get_class(metadata['tokenizer'])() 42 | self.doc_freqs = metadata['doc_freqs'].squeeze() 43 | self.doc_dict = metadata['doc_dict'] 44 | self.num_docs = len(self.doc_dict[0]) 45 | self.strict = strict 46 | 47 | def get_doc_index(self, doc_id): 48 | """Convert doc_id --> doc_index""" 49 | return self.doc_dict[0][doc_id] 50 | 51 | def get_doc_id(self, doc_index): 52 | """Convert doc_index --> doc_id""" 53 | return self.doc_dict[1][doc_index] 54 | 55 | def closest_docs(self, query, k=1): 56 | """Closest docs by dot product between query and documents 57 | in tfidf weighted word vector space. 
58 | """ 59 | spvec = self.text2spvec(query) 60 | res = spvec * self.doc_mat 61 | 62 | if len(res.data) <= k: 63 | o_sort = np.argsort(-res.data) 64 | else: 65 | o = np.argpartition(-res.data, k)[0:k] 66 | o_sort = o[np.argsort(-res.data[o])] 67 | 68 | doc_scores = res.data[o_sort] 69 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] 70 | return doc_ids, doc_scores 71 | 72 | def batch_closest_docs(self, queries, k=1, num_workers=None): 73 | """Process a batch of closest_docs requests multithreaded. 74 | Note: we can use plain threads here as scipy is outside of the GIL. 75 | """ 76 | with ThreadPool(num_workers) as threads: 77 | closest_docs = partial(self.closest_docs, k=k) 78 | results = threads.map(closest_docs, queries) 79 | return results 80 | 81 | def parse(self, query): 82 | """Parse the query into tokens (either ngrams or tokens).""" 83 | tokens = self.tokenizer.tokenize(query) 84 | return tokens.ngrams(n=self.ngrams, uncased=True, 85 | filter_fn=utils.filter_ngram) 86 | 87 | def text2spvec(self, query): 88 | """Create a sparse tfidf-weighted word vector from query. 89 | 90 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 91 | """ 92 | # Get hashed ngrams 93 | words = self.parse(utils.normalize(query)) 94 | wids = [utils.hash(w, self.hash_size) for w in words] 95 | 96 | if len(wids) == 0: 97 | if self.strict: 98 | raise RuntimeError('No valid word in: %s' % query) 99 | else: 100 | logger.warning('No valid word in: %s' % query) 101 | return sp.csr_matrix((1, self.hash_size)) 102 | 103 | # Count TF 104 | wids_unique, wids_counts = np.unique(wids, return_counts=True) 105 | tfs = np.log1p(wids_counts) 106 | 107 | # Count IDF 108 | Ns = self.doc_freqs[wids_unique] 109 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) 110 | idfs[idfs < 0] = 0 111 | 112 | # TF-IDF 113 | data = np.multiply(tfs, idfs) 114 | 115 | # One row, sparse csr matrix 116 | indptr = np.array([0, len(wids_unique)]) 117 | spvec = sp.csr_matrix( 118 | (data, wids_unique, indptr), shape=(1, self.hash_size) 119 | ) 120 | 121 | return spvec 122 | -------------------------------------------------------------------------------- /drqa_retriever/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Various retriever utilities.""" 8 | 9 | import regex 10 | import unicodedata 11 | import numpy as np 12 | import scipy.sparse as sp 13 | from sklearn.utils import murmurhash3_32 14 | 15 | 16 | # ------------------------------------------------------------------------------ 17 | # Sparse matrix saving/loading helpers. 
18 | # ------------------------------------------------------------------------------ 19 | 20 | 21 | def save_sparse_csr(filename, matrix, metadata=None): 22 | data = { 23 | 'data': matrix.data, 24 | 'indices': matrix.indices, 25 | 'indptr': matrix.indptr, 26 | 'shape': matrix.shape, 27 | 'metadata': metadata, 28 | } 29 | np.savez(filename, **data) 30 | 31 | 32 | def load_sparse_csr(filename): 33 | loader = np.load(filename, allow_pickle=True) 34 | matrix = sp.csr_matrix((loader['data'], loader['indices'], 35 | loader['indptr']), shape=loader['shape']) 36 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None 37 | 38 | 39 | # ------------------------------------------------------------------------------ 40 | # Token hashing. 41 | # ------------------------------------------------------------------------------ 42 | 43 | 44 | def hash(token, num_buckets): 45 | """Unsigned 32 bit murmurhash for feature hashing.""" 46 | return murmurhash3_32(token, positive=True) % num_buckets 47 | 48 | 49 | # ------------------------------------------------------------------------------ 50 | # Text cleaning. 51 | # ------------------------------------------------------------------------------ 52 | 53 | 54 | STOPWORDS = { 55 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 56 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 57 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 58 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 59 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 60 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 61 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 62 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 63 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 64 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 65 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 66 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 67 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 68 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 69 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 70 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 71 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 72 | } 73 | 74 | 75 | def normalize(text): 76 | """Resolve different type of unicode encodings.""" 77 | return unicodedata.normalize('NFD', text) 78 | 79 | 80 | def filter_word(text): 81 | """Take out english stopwords, punctuation, and compound endings.""" 82 | text = normalize(text) 83 | if regex.match(r'^\p{P}+$', text): 84 | return True 85 | if text.lower() in STOPWORDS: 86 | return True 87 | return False 88 | 89 | 90 | def filter_ngram(gram, mode='any'): 91 | """Decide whether to keep or discard an n-gram. 
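Returns True if the n-gram should be discarded.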
92 | 93 | Args: 94 | gram: list of tokens (length N) 95 | mode: Option to throw out ngram if 96 | 'any': any single token passes filter_word 97 | 'all': all tokens pass filter_word 98 | 'ends': book-ended by filterable tokens 99 | """ 100 | filtered = [filter_word(w) for w in gram] 101 | if mode == 'any': 102 | return any(filtered) 103 | elif mode == 'all': 104 | return all(filtered) 105 | elif mode == 'ends': 106 | return filtered[0] or filtered[-1] 107 | else: 108 | raise ValueError('Invalid mode: %s' % mode) 109 | -------------------------------------------------------------------------------- /drqa_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | DEFAULTS = { 11 | 'corenlp_classpath': os.getenv('CLASSPATH') 12 | } 13 | 14 | 15 | def set_default(key, value): 16 | global DEFAULTS 17 | DEFAULTS[key] = value 18 | 19 | 20 | from .corenlp_tokenizer import CoreNLPTokenizer 21 | from .regexp_tokenizer import RegexpTokenizer 22 | from .simple_tokenizer import SimpleTokenizer 23 | 24 | # Spacy is optional 25 | try: 26 | from .spacy_tokenizer import SpacyTokenizer 27 | except ImportError: 28 | pass 29 | 30 | 31 | def get_class(name): 32 | if name == 'spacy': 33 | return SpacyTokenizer 34 | if name == 'corenlp': 35 | return CoreNLPTokenizer 36 | if name == 'regexp': 37 | return RegexpTokenizer 38 | if name == 'simple': 39 | return SimpleTokenizer 40 | 41 | raise RuntimeError('Invalid tokenizer: %s' % name) 42 | 43 | 44 | def get_annotators_for_args(args): 45 | annotators = set() 46 | if args.use_pos: 47 | annotators.add('pos') 48 | if args.use_lemma: 49 | annotators.add('lemma') 50 | if args.use_ner: 51 | annotators.add('ner') 52 | return annotators 53 | 54 | 55 | def get_annotators_for_model(model): 56 | return get_annotators_for_args(model.args) 57 | -------------------------------------------------------------------------------- /drqa_tokenizers/corenlp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Simple wrapper around the Stanford CoreNLP pipeline. 8 | 9 | Serves commands to a java subprocess running the jar. Requires java 8. 10 | """ 11 | 12 | import copy 13 | import json 14 | import pexpect 15 | 16 | from .tokenizer import Tokens, Tokenizer 17 | from . import DEFAULTS 18 | 19 | 20 | class CoreNLPTokenizer(Tokenizer): 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: set that can include pos, lemma, and ner. 
26 | classpath: Path to the corenlp directory of jars 27 | mem: Java heap memory 28 | """ 29 | self.classpath = (kwargs.get('classpath') or 30 | DEFAULTS['corenlp_classpath']) 31 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 32 | self.mem = kwargs.get('mem', '2g') 33 | self._launch() 34 | 35 | def _launch(self): 36 | """Start the CoreNLP jar with pexpect.""" 37 | annotators = ['tokenize', 'ssplit'] 38 | if 'ner' in self.annotators: 39 | annotators.extend(['pos', 'lemma', 'ner']) 40 | elif 'lemma' in self.annotators: 41 | annotators.extend(['pos', 'lemma']) 42 | elif 'pos' in self.annotators: 43 | annotators.extend(['pos']) 44 | annotators = ','.join(annotators) 45 | options = ','.join(['untokenizable=noneDelete', 46 | 'invertible=true']) 47 | cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath, 48 | 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', 49 | annotators, '-tokenize.options', options, 50 | '-outputFormat', 'json', '-prettyPrint', 'false'] 51 | 52 | # We use pexpect to keep the subprocess alive and feed it commands. 53 | # Because we don't want to get hit by the max terminal buffer size, 54 | # we turn off canonical input processing to have unlimited bytes. 55 | self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60) 56 | self.corenlp.setecho(False) 57 | self.corenlp.sendline('stty -icanon') 58 | self.corenlp.sendline(' '.join(cmd)) 59 | self.corenlp.delaybeforesend = 0 60 | self.corenlp.delayafterread = 0 61 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 62 | 63 | @staticmethod 64 | def _convert(token): 65 | if token == '-LRB-': 66 | return '(' 67 | if token == '-RRB-': 68 | return ')' 69 | if token == '-LSB-': 70 | return '[' 71 | if token == '-RSB-': 72 | return ']' 73 | if token == '-LCB-': 74 | return '{' 75 | if token == '-RCB-': 76 | return '}' 77 | return token 78 | 79 | def tokenize(self, text): 80 | # Since we're feeding text to the commandline, we're waiting on seeing 81 | # the NLP> prompt. Hacky! 82 | if 'NLP>' in text: 83 | raise RuntimeError('Bad token (NLP>) in text!') 84 | 85 | # Sending q will cause the process to quit -- manually override 86 | if text.lower().strip() == 'q': 87 | token = text.strip() 88 | index = text.index(token) 89 | data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')] 90 | return Tokens(data, self.annotators) 91 | 92 | # Minor cleanup before tokenizing. 
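# Embedded newlines would be sent to the CoreNLP prompt as separate inputs, so replace them with spaces.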
93 | clean_text = text.replace('\n', ' ') 94 | 95 | self.corenlp.sendline(clean_text.encode('utf-8')) 96 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 97 | 98 | # Skip to start of output (may have been stderr logging messages) 99 | output = self.corenlp.before 100 | start = output.find(b'{"sentences":') 101 | output = json.loads(output[start:].decode('utf-8')) 102 | 103 | data = [] 104 | tokens = [t for s in output['sentences'] for t in s['tokens']] 105 | for i in range(len(tokens)): 106 | # Get whitespace 107 | start_ws = tokens[i]['characterOffsetBegin'] 108 | if i + 1 < len(tokens): 109 | end_ws = tokens[i + 1]['characterOffsetBegin'] 110 | else: 111 | end_ws = tokens[i]['characterOffsetEnd'] 112 | 113 | data.append(( 114 | self._convert(tokens[i]['word']), 115 | text[start_ws: end_ws], 116 | (tokens[i]['characterOffsetBegin'], 117 | tokens[i]['characterOffsetEnd']), 118 | tokens[i].get('pos', None), 119 | tokens[i].get('lemma', None), 120 | tokens[i].get('ner', None) 121 | )) 122 | return Tokens(data, self.annotators) 123 | -------------------------------------------------------------------------------- /drqa_tokenizers/regexp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers. 8 | 9 | However it is purely in Python, supports robust untokenization, unicode, 10 | and requires minimal dependencies. 11 | """ 12 | 13 | import regex 14 | import logging 15 | from .tokenizer import Tokens, Tokenizer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class RegexpTokenizer(Tokenizer): 21 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*' 22 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)' 23 | r'\.(?=\p{Z})') 24 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)' 25 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++' 26 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM) 27 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't" 28 | CONTRACTION1 = r"can(?=not\b)" 29 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b" 30 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})' 31 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})' 32 | END_DQUOTE = r'(?%s)|(?P%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|' 47 | '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|' 48 | '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|' 49 | '(?<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' % 50 | (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN, 51 | self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2, 52 | self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE, 53 | self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT, 54 | self.NON_WS), 55 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 56 | ) 57 | if len(kwargs.get('annotators', {})) > 0: 58 | logger.warning('%s only tokenizes! 
Skipping annotators: %s' % 59 | (type(self).__name__, kwargs.get('annotators'))) 60 | self.annotators = set() 61 | self.substitutions = kwargs.get('substitutions', True) 62 | 63 | def tokenize(self, text): 64 | data = [] 65 | matches = [m for m in self._regexp.finditer(text)] 66 | for i in range(len(matches)): 67 | # Get text 68 | token = matches[i].group() 69 | 70 | # Make normalizations for special token types 71 | if self.substitutions: 72 | groups = matches[i].groupdict() 73 | if groups['sdquote']: 74 | token = "``" 75 | elif groups['edquote']: 76 | token = "''" 77 | elif groups['ssquote']: 78 | token = "`" 79 | elif groups['esquote']: 80 | token = "'" 81 | elif groups['dash']: 82 | token = '--' 83 | elif groups['ellipses']: 84 | token = '...' 85 | 86 | # Get whitespace 87 | span = matches[i].span() 88 | start_ws = span[0] 89 | if i + 1 < len(matches): 90 | end_ws = matches[i + 1].span()[0] 91 | else: 92 | end_ws = span[1] 93 | 94 | # Format data 95 | data.append(( 96 | token, 97 | text[start_ws: end_ws], 98 | span, 99 | )) 100 | return Tokens(data, self.annotators) 101 | -------------------------------------------------------------------------------- /drqa_tokenizers/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Basic tokenizer that splits text into alpha-numeric tokens and 8 | non-whitespace tokens. 9 | """ 10 | 11 | import regex 12 | import logging 13 | from .tokenizer import Tokens, Tokenizer 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class SimpleTokenizer(Tokenizer): 19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 20 | NON_WS = r'[^\p{Z}\p{C}]' 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: None or empty set (only tokenizes). 26 | """ 27 | self._regexp = regex.compile( 28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 30 | ) 31 | if len(kwargs.get('annotators', {})) > 0: 32 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 33 | (type(self).__name__, kwargs.get('annotators'))) 34 | self.annotators = set() 35 | 36 | def tokenize(self, text): 37 | data = [] 38 | matches = [m for m in self._regexp.finditer(text)] 39 | for i in range(len(matches)): 40 | # Get text 41 | token = matches[i].group() 42 | 43 | # Get whitespace 44 | span = matches[i].span() 45 | start_ws = span[0] 46 | if i + 1 < len(matches): 47 | end_ws = matches[i + 1].span()[0] 48 | else: 49 | end_ws = span[1] 50 | 51 | # Format data 52 | data.append(( 53 | token, 54 | text[start_ws: end_ws], 55 | span, 56 | )) 57 | return Tokens(data, self.annotators) 58 | -------------------------------------------------------------------------------- /drqa_tokenizers/spacy_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Tokenizer that is backed by spaCy (spacy.io). 8 | 9 | Requires spaCy package and the spaCy english model. 
10 | """ 11 | 12 | import spacy 13 | import copy 14 | from .tokenizer import Tokens, Tokenizer 15 | 16 | 17 | class SpacyTokenizer(Tokenizer): 18 | 19 | def __init__(self, **kwargs): 20 | """ 21 | Args: 22 | annotators: set that can include pos, lemma, and ner. 23 | model: spaCy model to use (either path, or keyword like 'en'). 24 | """ 25 | model = kwargs.get('model', 'en') 26 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 27 | nlp_kwargs = {'parser': False} 28 | if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 29 | nlp_kwargs['tagger'] = False 30 | if 'ner' not in self.annotators: 31 | nlp_kwargs['entity'] = False 32 | self.nlp = spacy.load(model, **nlp_kwargs) 33 | 34 | def tokenize(self, text): 35 | # We don't treat new lines as tokens. 36 | clean_text = text.replace('\n', ' ') 37 | tokens = self.nlp.tokenizer(clean_text) 38 | if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 39 | self.nlp.tagger(tokens) 40 | if 'ner' in self.annotators: 41 | self.nlp.entity(tokens) 42 | 43 | data = [] 44 | for i in range(len(tokens)): 45 | # Get whitespace 46 | start_ws = tokens[i].idx 47 | if i + 1 < len(tokens): 48 | end_ws = tokens[i + 1].idx 49 | else: 50 | end_ws = tokens[i].idx + len(tokens[i].text) 51 | 52 | data.append(( 53 | tokens[i].text, 54 | text[start_ws: end_ws], 55 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 56 | tokens[i].tag_, 57 | tokens[i].lemma_, 58 | tokens[i].ent_type_, 59 | )) 60 | 61 | # Set special option for non-entity tag: '' vs 'O' in spaCy 62 | return Tokens(data, self.annotators, opts={'non_ent': ''}) 63 | -------------------------------------------------------------------------------- /drqa_tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Base tokenizer/tokens classes and utilities.""" 8 | 9 | import copy 10 | 11 | 12 | class Tokens(object): 13 | """A class to represent a list of tokenized text.""" 14 | TEXT = 0 15 | TEXT_WS = 1 16 | SPAN = 2 17 | POS = 3 18 | LEMMA = 4 19 | NER = 5 20 | 21 | def __init__(self, data, annotators, opts=None): 22 | self.data = data 23 | self.annotators = annotators 24 | self.opts = opts or {} 25 | 26 | def __len__(self): 27 | """The number of tokens.""" 28 | return len(self.data) 29 | 30 | def slice(self, i=None, j=None): 31 | """Return a view of the list of tokens from [i, j).""" 32 | new_tokens = copy.copy(self) 33 | new_tokens.data = self.data[i: j] 34 | return new_tokens 35 | 36 | def untokenize(self): 37 | """Returns the original text (with whitespace reinserted).""" 38 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 39 | 40 | def words(self, uncased=False): 41 | """Returns a list of the text of each token 42 | 43 | Args: 44 | uncased: lower cases text 45 | """ 46 | if uncased: 47 | return [t[self.TEXT].lower() for t in self.data] 48 | else: 49 | return [t[self.TEXT] for t in self.data] 50 | 51 | def offsets(self): 52 | """Returns a list of [start, end) character offsets of each token.""" 53 | return [t[self.SPAN] for t in self.data] 54 | 55 | def pos(self): 56 | """Returns a list of part-of-speech tags of each token. 57 | Returns None if this annotation was not included. 
58 | """ 59 | if 'pos' not in self.annotators: 60 | return None 61 | return [t[self.POS] for t in self.data] 62 | 63 | def lemmas(self): 64 | """Returns a list of the lemmatized text of each token. 65 | Returns None if this annotation was not included. 66 | """ 67 | if 'lemma' not in self.annotators: 68 | return None 69 | return [t[self.LEMMA] for t in self.data] 70 | 71 | def entities(self): 72 | """Returns a list of named-entity-recognition tags of each token. 73 | Returns None if this annotation was not included. 74 | """ 75 | if 'ner' not in self.annotators: 76 | return None 77 | return [t[self.NER] for t in self.data] 78 | 79 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 80 | """Returns a list of all ngrams from length 1 to n. 81 | 82 | Args: 83 | n: upper limit of ngram length 84 | uncased: lower cases text 85 | filter_fn: user function that takes in an ngram list and returns 86 | True or False to keep or not keep the ngram 87 | as_string: return the ngram as a string vs list 88 | """ 89 | def _skip(gram): 90 | if not filter_fn: 91 | return False 92 | return filter_fn(gram) 93 | 94 | words = self.words(uncased) 95 | ngrams = [(s, e + 1) 96 | for s in range(len(words)) 97 | for e in range(s, min(s + n, len(words))) 98 | if not _skip(words[s:e + 1])] 99 | 100 | # Concatenate into strings 101 | if as_strings: 102 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 103 | 104 | return ngrams 105 | 106 | def entity_groups(self): 107 | """Group consecutive entity tokens with the same NER tag.""" 108 | entities = self.entities() 109 | if not entities: 110 | return None 111 | non_ent = self.opts.get('non_ent', 'O') 112 | groups = [] 113 | idx = 0 114 | while idx < len(entities): 115 | ner_tag = entities[idx] 116 | # Check for entity tag 117 | if ner_tag != non_ent: 118 | # Chomp the sequence 119 | start = idx 120 | while (idx < len(entities) and entities[idx] == ner_tag): 121 | idx += 1 122 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 123 | else: 124 | idx += 1 125 | return groups 126 | 127 | 128 | class Tokenizer(object): 129 | """Base tokenizer class. 130 | Tokenizers implement tokenize, which should return a Tokens class. 
131 | """ 132 | def tokenize(self, text): 133 | raise NotImplementedError 134 | 135 | def shutdown(self): 136 | pass 137 | 138 | def __del__(self): 139 | self.shutdown() 140 | -------------------------------------------------------------------------------- /filter_subset_wiki.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import csv 4 | import argparse 5 | 6 | def main(args): 7 | with open(args.data_path, "r") as f: 8 | data = json.load(f) 9 | seen_doc_titles = set() 10 | for dp in data: 11 | seen_doc_titles |= set([ctx["title"] for ctx in dp["positive_ctxs"][:5]]) 12 | print ("Consider {} seen docs".format(len(seen_doc_titles))) 13 | 14 | rows = [] 15 | with open(args.db_path, "r") as tsvfile: 16 | reader = csv.reader(tsvfile, delimiter='\t') 17 | for doc_id, doc_text, title in reader: 18 | # file format: doc_id, doc_text, title 19 | if doc_id != 'id': 20 | rows.append((doc_id, doc_text, title)) 21 | orig_n_passages = len(rows) 22 | rows = [row for row in rows if row[2] in seen_doc_titles] 23 | print ("Reducing # of passages from {} to {}".format(orig_n_passages, len(rows))) 24 | 25 | with open(args.db_path.replace(".tsv", "_subset.tsv"), "w") as f: 26 | for row in rows: 27 | f.write("{}\t{}\t{}\n".format(row[0], row[1], row[2])) 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--data_path', type=str, default="/checkpoint/sewonmin/dpr/data/retriever/nq-train.json") 32 | parser.add_argument('--db_path', type=str, default="/checkpoint/sewonmin/dpr/data/wikipedia_split/psgs_w100.tsv") 33 | 34 | args = parser.parse_args() 35 | main(args) 36 | 37 | -------------------------------------------------------------------------------- /inference_tfidf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | #sys.path.append('/private/home/sewonmin/EfficientQA-baselines/DPR') 4 | #from dense_retriever import validate, save_results 5 | import drqa_retriever as retriever 6 | 7 | 8 | def main(args): 9 | questions = [] 10 | question_answers = [] 11 | 12 | for ds_item in parse_qa_csv_file(args.qa_file): 13 | question, answers = ds_item 14 | questions.append(question) 15 | question_answers.append(answers) 16 | 17 | top_ids_and_scores = [] 18 | for question in questions: 19 | psg_ids, scores = ranker.closest_docs(question, args.n_docs) 20 | top_ids_and_scores.append((psg_ids, scores)) 21 | 22 | all_passages = load_passages(args.db_path) 23 | 24 | if len(all_passages) == 0: 25 | raise RuntimeError('No passages data found. 
Please specify ctx_file param properly.') 26 | 27 | questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers, 28 | args.match) 29 | 30 | if args.out_file: 31 | save_results(all_passages, questions, question_answers, top_ids_and_scores, questions_doc_hits, args.out_file) 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--qa_file', required=True, type=str, default=None) 36 | parser.add_argument('--dpr_path', type=str, default="../DPR") 37 | parser.add_argument('--db_path', type=str, default="/checkpoint/sewonmin/dpr/data/wikipedia_split/psgs_w100_seen_only.tsv") 38 | parser.add_argument('--tfidf_path', type=str, default="/checkpoint/sewonmin/dpr/drqa_retrieval_seen_only/db-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz") 39 | parser.add_argument('--out_file', type=str, default=None) 40 | parser.add_argument('--match', type=str, default='string', choices=['regex', 'string']) 41 | parser.add_argument('--n-docs', type=int, default=100) 42 | parser.add_argument('--validation_workers', type=int, default=16) 43 | args = parser.parse_args() 44 | 45 | sys.path.append(args.dpr_path) 46 | from dense_retriever import parse_qa_csv_file, load_passages, validate, save_results 47 | 48 | ranker = retriever.get_class('tfidf')(tfidf_path=args.tfidf_path) 49 | 50 | main(args) 51 | 52 | 53 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-learn 3 | regex 4 | tqdm 5 | scipy 6 | nltk 7 | elasticsearch 8 | pexpect==4.2.1 9 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | RETRIEVAL=$1 4 | 5 | tfidf_index="random" 6 | dpr_index="random" 7 | dpr_retrieval_checkpoint="random" 8 | n_paragraphs="100" 9 | 10 | wget https://raw.githubusercontent.com/efficientqa/nq-open/master/NQ-open.dev.jsonl 11 | #python3 ../DPR/data/download_data.py --resource data.retriever.qas.nq --output_dir ${base_dir} 12 | python3 ../DPR/data/download_data.py --resource data.wikipedia_split --output_dir ${base_dir} 13 | 14 | if [ $RETRIEVAL = "tfidf-full" ] 15 | then 16 | python3 ../DPR/data/download_data.py --resource indexes.tfidf.nq.full --output_dir ${base_dir} # DrQA index 17 | python3 ../DPR/data/download_data.py --resource checkpoint.reader.nq-tfidf.hf-bert-base --output_dir ${base_dir} # reader checkpoint 18 | tfidf_index="${base_dir}/indexes/tfidf/nq/full.npz" 19 | reader_checkpoint="${base_dir}/checkpoint/reader/nq-tfidf/hf-bert-base.cp" 20 | retrieval_type="tfidf" 21 | db_name="psgs_w100.tsv" 22 | elif [ $RETRIEVAL = "tfidf-subset" ] 23 | then 24 | python3 ../DPR/data/download_data.py --resource data.retriever.nq-train --output_dir ${base_dir} 25 | python3 filter_subset_wiki.py --db_path ${base_dir}/data/wikipedia_split/psgs_w100.tsv --data_path ${base_dir}/data/retriever/nq-train.json 26 | python3 ../DPR/data/download_data.py --resource indexes.tfidf.nq.subset --output_dir ${base_dir} # DrQA index 27 | python3 ../DPR/data/download_data.py --resource checkpoint.reader.nq-tfidf-subset.hf-bert-base --output_dir ${base_dir} # reader checkpoint 28 | tfidf_index="${base_dir}/indexes/tfidf/nq/subset.npz" 29 | reader_checkpoint="${base_dir}/checkpoint/reader/nq-tfidf-subset/hf-bert-base.cp" 30 | retrieval_type="tfidf" 31 | 
db_name="psgs_w100_subset.tsv" 32 | elif [ $RETRIEVAL = "dpr-full" ] 33 | then 34 | python3 ../DPR/data/download_data.py --resource checkpoint.retriever.single.nq.bert-base-encoder --output_dir ${base_dir} # retrieval checkpoint 35 | python3 ../DPR/data/download_data.py --resource indexes.single.nq.full --output_dir ${base_dir} # DPR index 36 | python3 ../DPR/data/download_data.py --resource checkpoint.reader.nq-single.hf-bert-base --output_dir ${base_dir} # reader checkpoint 37 | dpr_retrieval_checkpoint="${base_dir}/checkpoint/retriever/single/nq/bert-base-encoder.cp" 38 | dpr_index="${base_dir}/indexes/single/nq/full" 39 | reader_checkpoint="${base_dir}/checkpoint/reader/nq-single/hf-bert-base.cp" 40 | retrieval_type="dpr" 41 | db_name="psgs_w100.tsv" 42 | n_paragraphs="40" 43 | elif [ $RETRIEVAL = "dpr-subset" ] 44 | then 45 | python3 ../DPR/data/download_data.py --resource data.retriever.nq-train --output_dir ${base_dir} 46 | python3 filter_subset_wiki.py --db_path ${base_dir}/data/wikipedia_split/psgs_w100.tsv --data_path ${base_dir}/data/retriever/nq-train.json 47 | python3 ../DPR/data/download_data.py --resource checkpoint.retriever.single.nq.bert-base-encoder --output_dir ${base_dir} # retrieval checkpoint 48 | python3 ../DPR/data/download_data.py --resource indexes.single.nq.subset --output_dir ${base_dir} # DPR index 49 | python3 ../DPR/data/download_data.py --resource checkpoint.reader.nq-single-subset.hf-bert-base --output_dir ${base_dir} # reader checkpoint 50 | dpr_retrieval_checkpoint="${base_dir}/checkpoint/retriever/single/nq/bert-base-encoder.cp" 51 | dpr_index="${base_dir}/indexes/single/nq/subset" 52 | reader_checkpoint="${base_dir}/checkpoint/reader/nq-single-subset/hf-bert-base.cp" 53 | retrieval_type="dpr" 54 | db_name="psgs_w100_subset.tsv" 55 | n_paragraphs="40" 56 | fi 57 | python3 run_inference.py \ 58 | --qa_file NQ-open.dev.jsonl \ 59 | --retrieval_type ${retrieval_type} \ 60 | --db_path ${base_dir}/data/wikipedia_split/${db_name} \ 61 | --tfidf_path ${tfidf_index} \ 62 | --dpr_model_file ${dpr_retrieval_checkpoint} \ 63 | --dense_index_path ${dpr_index} \ 64 | --model_file ${reader_checkpoint} \ 65 | --dev_batch_size 64 \ 66 | --pretrained_model_cfg bert-base-uncased --encoder_model_type hf_bert --do_lower_case \ 67 | --sequence_length 350 --eval_top_docs 10 20 40 50 80 100 --passages_per_question_predict ${n_paragraphs} \ 68 | --prediction_results_file ${RETRIEVAL}_test_predictions.json 69 | -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | import os 8 | import sys 9 | import argparse 10 | import csv 11 | import json 12 | import numpy as np 13 | 14 | sys.path.append("../DPR") 15 | from dense_retriever import parse_qa_csv_file, load_passages, validate, save_results 16 | from dpr.options import add_encoder_params, setup_args_gpu, print_args, set_encoder_params_from_state, \ 17 | add_tokenizer_params, add_cuda_params, add_training_params, add_reader_preprocessing_params 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | 22 | # general params 23 | parser.add_argument('--qa_file', required=True, type=str, default=None) 24 | parser.add_argument('--retrieval_type', type=str, default='tfidf', 25 | choices=['tfidf', 'dpr']) 26 | parser.add_argument('--dpr_model_file', type=str, default="/private/home/sewonmin/EfficientQA-baselines/DP") 27 | parser.add_argument('--db_path', type=str, default="/checkpoint/sewonmin/dpr/data/wikipedia_split/psgs_w100_seen_only.tsv") 28 | 29 | # retrieval specific params 30 | parser.add_argument('--dense_index_path', type=str, default="") 31 | parser.add_argument('--tfidf_path', type=str, default="/checkpoint/sewonmin/dpr/drqa_retrieval_seen_only/db-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz") 32 | parser.add_argument('--match', type=str, default='string', choices=['regex', 'string']) 33 | parser.add_argument('--n-docs', type=int, default=100) 34 | #parser.add_argument('--batch_size', type=int, default=32, help="Batch size for question encoder forward pass") 35 | parser.add_argument('--index_buffer', type=int, default=50000, 36 | help="Temporary memory data buffer size (in samples) for indexer") 37 | parser.add_argument("--hnsw_index", action='store_true', help='If enabled, use inference time efficient HNSW index') 38 | parser.add_argument("--save_or_load_index", action='store_true', default=True, help='If enabled, save index') 39 | 40 | # reader specific params 41 | add_encoder_params(parser) 42 | add_training_params(parser) 43 | add_tokenizer_params(parser) 44 | add_reader_preprocessing_params(parser) 45 | 46 | 47 | parser.add_argument("--max_n_answers", default=10, type=int, 48 | help="Max amount of answer spans to marginalize per single passage") 49 | parser.add_argument('--passages_per_question', type=int, default=2, 50 | help="Total amount of positive and negative passages per question") 51 | parser.add_argument('--passages_per_question_predict', type=int, default=50, 52 | help="Total amount of positive and negative passages per question for evaluation") 53 | parser.add_argument("--max_answer_length", default=10, type=int, 54 | help="The maximum length of an answer that can be generated. 
This is needed because the start " 55 | "and end predictions are not conditioned on one another.") 56 | parser.add_argument('--eval_top_docs', nargs='+', type=int, 57 | help="top retrival passages thresholds to analyze prediction results for") 58 | parser.add_argument('--checkpoint_file_name', type=str, default='dpr_reader') 59 | parser.add_argument('--prediction_results_file', type=str, help='path to a file to write prediction results to') 60 | 61 | 62 | args = parser.parse_args() 63 | 64 | questions = [] 65 | question_answers = [] 66 | if args.qa_file.endswith(".csv"): 67 | for ds_item in parse_qa_csv_file(args.qa_file): 68 | question, answers = ds_item 69 | questions.append(question) 70 | question_answers.append(answers) 71 | else: 72 | with open(args.qa_file, "r") as f: 73 | for line in f: 74 | d = json.loads(line) 75 | questions.append(d["question"]) 76 | if "answer" not in d: 77 | d["answer"] = "random" 78 | question_answers.append(d["answer"]) 79 | if args.retrieval_type=="tfidf": 80 | import drqa_retriever as retriever 81 | ranker = retriever.get_class('tfidf')(tfidf_path=args.tfidf_path) 82 | top_ids_and_scores = [] 83 | for question in questions: 84 | psg_ids, scores = ranker.closest_docs(question, args.n_docs) 85 | top_ids_and_scores.append((psg_ids, scores)) 86 | else: 87 | from dpr.models import init_biencoder_components 88 | from dpr.utils.data_utils import Tensorizer 89 | from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint 90 | from dpr.indexer.faiss_indexers import DenseIndexer, DenseHNSWFlatIndexer, DenseFlatIndexer 91 | from dense_retriever import DenseRetriever 92 | 93 | saved_state = load_states_from_checkpoint(args.dpr_model_file) 94 | set_encoder_params_from_state(saved_state.encoder_params, args) 95 | tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True) 96 | encoder = encoder.question_model 97 | setup_args_gpu(args) 98 | encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu, 99 | args.local_rank, 100 | args.fp16) 101 | encoder.eval() 102 | 103 | # load weights from the model file 104 | model_to_load = get_model_obj(encoder) 105 | prefix_len = len('question_model.') 106 | question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if 107 | key.startswith('question_model.')} 108 | model_to_load.load_state_dict(question_encoder_state) 109 | vector_size = model_to_load.get_out_size() 110 | 111 | index_buffer_sz = args.index_buffer 112 | if args.hnsw_index: 113 | index = DenseHNSWFlatIndexer(vector_size) 114 | index_buffer_sz = -1 # encode all at once 115 | else: 116 | index = DenseFlatIndexer(vector_size) 117 | 118 | retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index) 119 | retriever.index.deserialize_from(args.dense_index_path) 120 | 121 | questions_tensor = retriever.generate_question_vectors(questions) 122 | top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs) 123 | 124 | 125 | all_passages = load_passages(args.db_path) 126 | 127 | retrieval_file = "tmp_{}.json".format(str(np.random.randint(0, 100000)).zfill(6)) 128 | questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, 129 | 1, args.match) 130 | 131 | save_results(all_passages, 132 | questions, 133 | question_answers, #["" for _ in questions], 134 | top_ids_and_scores, 135 | questions_doc_hits, #[[False for _ in range(args.n_docs)] for _n in questions], 136 | 
retrieval_file) 137 | setup_args_gpu(args) 138 | #print_args(args) 139 | args.dev_file = retrieval_file 140 | 141 | #from IPython import embed; embed() 142 | from train_reader import ReaderTrainer 143 | 144 | class MyReaderTrainer(ReaderTrainer): 145 | def _save_predictions(self, out_file, prediction_results): 146 | with open(out_file, 'w', encoding="utf-8") as output: 147 | save_results = [] 148 | for r in prediction_results: 149 | save_results.append({ 150 | 'question': r.id, 151 | 'prediction': r.predictions[args.passages_per_question_predict].prediction_text 152 | }) 153 | output.write(json.dumps(save_results, indent=4) + "\n") 154 | 155 | trainer = MyReaderTrainer(args) 156 | trainer.validate() 157 | 158 | os.remove(retrieval_file) 159 | for i in range(args.num_workers): 160 | os.remove(retrieval_file.replace(".json", ".{}.pkl".format(i))) 161 | 162 | --------------------------------------------------------------------------------
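For a quick sanity check of a TF-IDF index outside the full pipeline, the retrieval call that `inference_tfidf.py` and `run_inference.py` share can be exercised on its own. A minimal sketch, assuming an index has already been built with `build_tfidf.py` and that the snippet is run from the `retrieval-based-baselines` directory; the index path and example question are illustrative:

```python
# Minimal sketch of the TF-IDF retrieval step used by inference_tfidf.py and
# run_inference.py. The index path and question below are illustrative.
import drqa_retriever as retriever

tfidf_path = "tfidf/db-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz"  # hypothetical index location
ranker = retriever.get_class('tfidf')(tfidf_path=tfidf_path)

question = "who wrote the declaration of independence"
psg_ids, scores = ranker.closest_docs(question, 10)  # top-10 passage ids with their TF-IDF scores
for psg_id, score in zip(psg_ids, scores):
    print(psg_id, score)
```

The returned ids index into the passage DB given by `--db_path`; in the full scripts, `load_passages`, `validate`, and `save_results` (imported from the DPR repo) map them back to passage text and answer hits.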
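Note that `run.sh` reads `${base_dir}` but never assigns it, so the variable has to be set in the environment before the script is invoked, and the script expects the DPR repo to be checked out as a sibling directory (it calls `../DPR/data/download_data.py`). A minimal invocation sketch; the data directory path is illustrative:

```bash
# Minimal sketch for driving run.sh; the data directory path is illustrative.
# run.sh reads ${base_dir} but never sets it, so export it first.
export base_dir=/path/to/efficientqa_data    # hypothetical storage location
cd retrieval-based-baselines                 # run.sh expects ../DPR to exist
bash run.sh tfidf-full                       # or: tfidf-subset | dpr-full | dpr-subset
```

The single positional argument selects one of the four branches in the script, which in turn decides which index, reader checkpoint, and passage DB are downloaded and passed to `run_inference.py`.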