├── utils.py
├── convert_collection_to_memmap.py
├── LICENSE
├── convert_text_to_tokenized.py
├── .gitignore
├── modeling.py
├── retrieve.py
├── multi_retrieve.py
├── dataset.py
├── precompute.py
├── README.md
├── ms_marco_eval.py
└── train.py

/utils.py:
--------------------------------------------------------------------------------
import os
import re
import random
from collections import defaultdict
import subprocess


def generate_rank(input_path, output_path):
    # Group (score, para_id) pairs by query id, then write ranks in descending
    # score order. Shuffling first makes tie-breaking random.
    score_dict = defaultdict(list)
    for line in open(input_path):
        query_id, para_id, score = line.split("\t")
        score_dict[int(query_id)].append((float(score), int(para_id)))
    with open(output_path, "w") as outFile:
        for query_id, para_lst in score_dict.items():
            random.shuffle(para_lst)
            para_lst = sorted(para_lst, key=lambda x: x[0], reverse=True)
            for rank_idx, (score, para_id) in enumerate(para_lst):
                outFile.write("{}\t{}\t{}\n".format(query_id, para_id, rank_idx + 1))


def eval_results(run_file_path,
                 eval_script="./ms_marco_eval.py",
                 qrels="./data/msmarco-passage/qrels.dev.small.tsv"):
    # Run the official MS MARCO evaluation script and parse MRR@10 from its output.
    assert os.path.exists(eval_script) and os.path.exists(qrels)
    result = subprocess.check_output(['python', eval_script, qrels, run_file_path])
    match = re.search(r'MRR @10: ([\d.]+)', result.decode('utf-8'))
    mrr = float(match.group(1))
    return mrr
--------------------------------------------------------------------------------
/convert_collection_to_memmap.py:
--------------------------------------------------------------------------------
import os
import json
import argparse
import numpy as np
from tqdm import tqdm

def cvt_collection_to_memmap(args):
    collection_size = sum(1 for _ in open(args.tokenized_collection))
    max_seq_length = 512
    token_ids = np.memmap(f"{args.output_dir}/token_ids.memmap", dtype='int32',
        mode='w+', shape=(collection_size, max_seq_length))
    pids = np.memmap(f"{args.output_dir}/pids.memmap", dtype='int32',
        mode='w+', shape=(collection_size,))
    lengths = np.memmap(f"{args.output_dir}/lengths.memmap", dtype='int32',
        mode='w+', shape=(collection_size,))

    for idx, line in enumerate(tqdm(open(args.tokenized_collection),
            desc="collection", total=collection_size)):
        data = json.loads(line)
        assert int(data['id']) == idx
        pids[idx] = idx
        # Truncate first so the stored length never exceeds max_seq_length;
        # otherwise the slice assignment below would fail for long passages.
        ids = data['ids'][:max_seq_length]
        lengths[idx] = len(ids)
        token_ids[idx, :len(ids)] = ids


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tokenized_collection", type=str,
                        default="./data/tokenize/collection.tokenize.json")
    parser.add_argument("--output_dir", type=str, default="./data/collection_memmap")
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    cvt_collection_to_memmap(args)
--------------------------------------------------------------------------------
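The three memmaps store, per passage, its id, its token count, and one fixed-width row of token ids. A minimal read-back sketch (paths assume the defaults above; the `512` must match the `max_seq_length` used at conversion time):

```python
import numpy as np

# Open read-only; shapes must mirror what cvt_collection_to_memmap wrote.
lengths = np.memmap("./data/collection_memmap/lengths.memmap", dtype='int32', mode='r')
token_ids = np.memmap("./data/collection_memmap/token_ids.memmap", dtype='int32',
                      mode='r', shape=(len(lengths), 512))
# Token ids of passage 42, with the zero padding stripped:
print(token_ids[42, :lengths[42]].tolist())
```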
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2020, jingtaozhan
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/convert_text_to_tokenized.py:
--------------------------------------------------------------------------------
import os
import json
import argparse
from tqdm import tqdm
from transformers import BertTokenizer

def tokenize_file(tokenizer, input_file, output_file):
    total_size = sum(1 for _ in open(input_file))
    with open(output_file, 'w') as outFile:
        for line in tqdm(open(input_file), total=total_size,
                desc=f"Tokenize: {os.path.basename(input_file)}"):
            seq_id, text = line.split("\t")
            tokens = tokenizer.tokenize(text)
            ids = tokenizer.convert_tokens_to_ids(tokens)
            outFile.write(json.dumps({"id": seq_id, "ids": ids}))
            outFile.write("\n")


def tokenize_queries(args, tokenizer):
    # The replication instructions in the README need the small dev/eval query
    # sets; "dev", "eval", and "train" can be added here when required.
    for mode in ["dev.small", "eval.small"]:
        query_output = f"{args.output_dir}/queries.{mode}.json"
        tokenize_file(tokenizer, f"{args.msmarco_dir}/queries.{mode}.tsv", query_output)


def tokenize_collection(args, tokenizer):
    collection_output = f"{args.output_dir}/collection.tokenize.json"
    tokenize_file(tokenizer, f"{args.msmarco_dir}/collection.tsv", collection_output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--msmarco_dir", type=str, default="./data/msmarco-passage")
    parser.add_argument("--output_dir", type=str, default="./data/tokenize")
    parser.add_argument("--tokenize_queries", action="store_true")
    parser.add_argument("--tokenize_collection", action="store_true")
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    if args.tokenize_queries:
        tokenize_queries(args, tokenizer)
    if args.tokenize_collection:
        tokenize_collection(args, tokenizer)
--------------------------------------------------------------------------------
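Each output line is one JSON record holding the original id and the WordPiece token ids. A quick way to spot-check the output, assuming the small dev queries have been tokenized as above (producing `queries.dev.small.json`):

```python
import json
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
with open("./data/tokenize/queries.dev.small.json") as f:
    record = json.loads(next(f))
# Round-trip the ids back to tokens to verify the tokenization.
print(record["id"], tokenizer.convert_ids_to_tokens(record["ids"]))
```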
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

data/
.DS_Store
.vscode/
--------------------------------------------------------------------------------
/modeling.py:
--------------------------------------------------------------------------------
import torch
import logging
from torch import nn
from transformers.modeling_bert import (BertModel, BertPreTrainedModel)
logger = logging.getLogger(__name__)


def _average_query_doc_embeddings(sequence_output, token_type_ids, valid_mask):
    # Mean-pool the query positions (token_type 0) and the doc positions
    # (token_type 1) separately, ignoring padding.
    query_flags = (token_type_ids == 0) * (valid_mask == 1)
    doc_flags = (token_type_ids == 1) * (valid_mask == 1)

    query_lengths = torch.sum(query_flags, dim=-1)
    query_lengths = torch.clamp(query_lengths, 1, None)
    doc_lengths = torch.sum(doc_flags, dim=-1)
    doc_lengths = torch.clamp(doc_lengths, 1, None)

    query_embeddings = torch.sum(sequence_output * query_flags[:, :, None], dim=1)
    query_embeddings = query_embeddings / query_lengths[:, None]
    doc_embeddings = torch.sum(sequence_output * doc_flags[:, :, None], dim=1)
    doc_embeddings = doc_embeddings / doc_lengths[:, None]
    return query_embeddings, doc_embeddings


def _mask_both_directions(valid_mask, token_type_ids):
    # Position i may attend to position j only if j is valid and both positions
    # belong to the same segment, so the query and the document are encoded
    # independently within a single forward pass.
    assert valid_mask.dim() == 2
    attention_mask = valid_mask[:, None, :]

    type_attention_mask = torch.abs(token_type_ids[:, :, None] - token_type_ids[:, None, :])
    attention_mask = attention_mask - type_attention_mask
    attention_mask = torch.clamp(attention_mask, 0, None)
    return attention_mask


class RepBERT_Train(BertPreTrainedModel):
    def __init__(self, config):
        super(RepBERT_Train, self).__init__(config)
        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, token_type_ids, valid_mask,
                position_ids, labels=None):
        attention_mask = _mask_both_directions(valid_mask, token_type_ids)

        sequence_output = self.bert(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids)[0]

        query_embeddings, doc_embeddings = _average_query_doc_embeddings(
            sequence_output, token_type_ids, valid_mask
        )

        # In-batch similarity matrix: entry (i, j) scores query i against doc j.
        similarities = torch.matmul(query_embeddings, doc_embeddings.T)

        output = (similarities, query_embeddings, doc_embeddings)
        if labels is not None:
            # labels[i] lists the in-batch doc positions relevant to query i,
            # padded with -1, as MultiLabelMarginLoss expects.
            loss_fct = nn.MultiLabelMarginLoss()
            loss = loss_fct(similarities, labels)
            output = loss, *output
        return output


def _average_sequence_embeddings(sequence_output, valid_mask):
    flags = valid_mask == 1
    lengths = torch.sum(flags, dim=-1)
    lengths = torch.clamp(lengths, 1, None)
    sequence_embeddings = torch.sum(sequence_output * flags[:, :, None], dim=1)
    sequence_embeddings = sequence_embeddings / lengths[:, None]
    return sequence_embeddings
class RepBERT(BertPreTrainedModel):
    """Inference-time encoder: embeds either queries or documents, selected by
    config.encode_type, by mean-pooling the BERT outputs."""

    def __init__(self, config):
        super(RepBERT, self).__init__(config)
        self.bert = BertModel(config)
        self.init_weights()

        if config.encode_type == "doc":
            self.token_type_func = torch.ones_like
        elif config.encode_type == "query":
            self.token_type_func = torch.zeros_like
        else:
            raise NotImplementedError()

    def forward(self, input_ids, valid_mask):
        token_type_ids = self.token_type_func(input_ids)
        sequence_output = self.bert(input_ids,
                                    attention_mask=valid_mask,
                                    token_type_ids=token_type_ids)[0]

        text_embeddings = _average_sequence_embeddings(
            sequence_output, valid_mask
        )

        return text_embeddings
--------------------------------------------------------------------------------
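A toy check of the two-direction masking, assuming `modeling.py` and its dependencies are importable: with two query tokens, two valid doc tokens, and one padding position, every row of the resulting mask only admits same-segment, non-padding positions.

```python
import torch
from modeling import _mask_both_directions

token_type_ids = torch.tensor([[0, 0, 1, 1, 1]])  # query, query, doc, doc, doc
valid_mask = torch.tensor([[1, 1, 1, 1, 0]])      # last doc position is padding
print(_mask_both_directions(valid_mask, token_type_ids)[0])
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 1, 1, 0]])
```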
/retrieve.py:
--------------------------------------------------------------------------------
import os
import argparse
import numpy as np
import torch
from tqdm import tqdm
from queue import PriorityQueue
from utils import generate_rank

def get_embed_memmap(memmap_dir, dim):
    embedding_path = f"{memmap_dir}/embedding.memmap"
    id_path = f"{memmap_dir}/ids.memmap"
    # Tensor doesn't support non-writeable numpy arrays,
    # so we open the memmaps in copy-on-write mode.
    id_memmap = np.memmap(id_path, dtype='int32', mode="c")
    embedding_memmap = np.memmap(embedding_path, dtype='float32',
        mode="c", shape=(len(id_memmap), dim))
    return embedding_memmap, id_memmap


def allrank(args):
    doc_embedding_memmap, doc_id_memmap = get_embed_memmap(
        args.doc_embedding_dir, args.embedding_dim)
    assert np.all(doc_id_memmap == list(range(len(doc_id_memmap))))

    query_embedding_memmap, query_id_memmap = get_embed_memmap(
        args.query_embedding_dir, args.embedding_dim)
    qid2pos = {identity: i for i, identity in enumerate(query_id_memmap)}
    # One bounded priority queue per query keeps the current top-`hit` results.
    results_dict = {qid: PriorityQueue(maxsize=args.hit) for qid in query_id_memmap}

    for doc_begin_index in tqdm(range(0, len(doc_id_memmap), args.per_gpu_doc_num), desc="doc"):
        doc_end_index = doc_begin_index + args.per_gpu_doc_num
        doc_ids = doc_id_memmap[doc_begin_index:doc_end_index]
        doc_embeddings = doc_embedding_memmap[doc_begin_index:doc_end_index]
        doc_embeddings = torch.from_numpy(doc_embeddings).to(args.device)
        for qid in tqdm(query_id_memmap, desc="query"):
            query_embedding = query_embedding_memmap[qid2pos[qid]]
            query_embedding = torch.from_numpy(query_embedding)
            query_embedding = query_embedding.to(args.device)

            # Inner products between the query and every doc in this shard.
            all_scores = torch.sum(query_embedding * doc_embeddings, dim=-1)

            k = min(args.hit, len(doc_embeddings))
            top_scores, top_indices = torch.topk(all_scores, k, largest=True, sorted=True)
            top_scores, top_indices = top_scores.cpu(), top_indices.cpu()
            top_doc_ids = doc_ids[top_indices.numpy()]
            cur_q_queue = results_dict[qid]
            for score, docid in zip(top_scores, top_doc_ids):
                score, docid = score.item(), docid.item()
                if cur_q_queue.full():
                    lowest_score, lowest_docid = cur_q_queue.get_nowait()
                    if lowest_score >= score:
                        # Candidates arrive in descending score order, so no
                        # later candidate from this shard can enter the top-k.
                        cur_q_queue.put_nowait((lowest_score, lowest_docid))
                        break
                    else:
                        cur_q_queue.put_nowait((score, docid))
                else:
                    cur_q_queue.put_nowait((score, docid))

    score_path = f"{args.output_path}.score"
    with open(score_path, 'w') as outputfile:
        for qid, docqueue in results_dict.items():
            while not docqueue.empty():
                score, docid = docqueue.get_nowait()
                outputfile.write(f"{qid}\t{docid}\t{score}\n")
    generate_rank(score_path, args.output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--per_gpu_doc_num", default=1800000, type=int)
    parser.add_argument("--hit", type=int, default=1000)
    parser.add_argument("--embedding_dim", type=int, default=768)
    parser.add_argument("--output_path", type=str,
        default="./data/retrieve/repbert.dev.small.top1k.tsv")
    parser.add_argument("--doc_embedding_dir", type=str,
        default="./data/precompute/doc_embedding")
    parser.add_argument("--query_embedding_dir", type=str,
        default="./data/precompute/query_dev.small_embedding")
    args = parser.parse_args()

    print(args)

    # Setup CUDA, GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    assert args.n_gpu == 1

    args.device = device

    if not os.path.exists(os.path.dirname(args.output_path)):
        os.makedirs(os.path.dirname(args.output_path))

    with torch.no_grad():
        allrank(args)
--------------------------------------------------------------------------------
/multi_retrieve.py:
--------------------------------------------------------------------------------
import os
import math
import argparse
import numpy as np
from tqdm import tqdm
import traceback
from functools import wraps
from queue import PriorityQueue
from multiprocessing import Pool, Manager
from retrieve import get_embed_memmap
from timeit import default_timer as timer


def raise_immediately(func):
    # Exceptions inside Pool workers are otherwise swallowed until pool.join();
    # this wrapper prints the traceback as soon as the worker fails.
    @wraps(func)
    def ret_func(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except:
            print(traceback.format_exc())
            raise
    return ret_func


@raise_immediately
def writer(args, finish_queue_lst):
    _, query_id_memmap = get_embed_memmap(
        args.query_embedding_dir, args.embedding_dim)
    with open(args.output_path, 'w') as outFile:
        for qid in query_id_memmap:
            # Merge the per-shard top-k lists and keep the global top-`hit`.
            score_docid_lst = []
            for q in finish_queue_lst:
                score_docid_lst = score_docid_lst + q.get()
            score_docid_lst = sorted(score_docid_lst, reverse=True)
            for rank_idx, (score, para_id) in enumerate(score_docid_lst[:args.hit]):
                outFile.write(f"{qid}\t{para_id}\t{rank_idx+1}\n")


@raise_immediately
def allrank(gpu_queue, doc_begin_index, doc_end_index, finish_queue):
    # NOTE: this worker reads the global `args`, inherited from the parent
    # process (requires the "fork" start method). torch is imported here,
    # after CUDA_VISIBLE_DEVICES is set, so the worker only sees one GPU.
    import os
    import torch
    gpuid = gpu_queue.get()
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpuid}"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    assert torch.cuda.device_count() == 1

    query_embedding_memmap, query_id_memmap = get_embed_memmap(
        args.query_embedding_dir, args.embedding_dim)
    qid2pos = {identity: i for i, identity in enumerate(query_id_memmap)}

    doc_embedding_memmap, doc_id_memmap = get_embed_memmap(
        args.doc_embedding_dir, args.embedding_dim)
    assert np.all(doc_id_memmap == list(range(len(doc_id_memmap))))
    doc_embeddings = doc_embedding_memmap[doc_begin_index:doc_end_index]
    doc_ids = doc_id_memmap[doc_begin_index:doc_end_index]

    doc_embeddings = torch.from_numpy(doc_embeddings).to(device)
    results_dict = {qid: PriorityQueue(maxsize=args.hit) for qid in query_id_memmap}

    for qid in tqdm(query_id_memmap, desc=f"{gpuid}"):
        query_embedding = query_embedding_memmap[qid2pos[qid]]
        query_embedding = torch.from_numpy(query_embedding)
        query_embedding = query_embedding.to(device)

        all_scores = torch.sum(query_embedding * doc_embeddings, dim=-1)

        k = min(args.hit, len(doc_embeddings))
        top_scores, top_indices = torch.topk(all_scores, k, largest=True, sorted=True)
        top_scores, top_indices = top_scores.cpu(), top_indices.cpu()
        top_doc_ids = doc_ids[top_indices.numpy()]
        cur_q_queue = results_dict[qid]
        for score, docid in zip(top_scores, top_doc_ids):
            score, docid = score.item(), docid.item()
            if cur_q_queue.full():
                lowest_score, lowest_docid = cur_q_queue.get_nowait()
                if lowest_score >= score:
                    cur_q_queue.put_nowait((lowest_score, lowest_docid))
                    break
                else:
                    cur_q_queue.put_nowait((score, docid))
            else:
                cur_q_queue.put_nowait((score, docid))
        # Hand this query's shard-local top-k to the writer process.
        finish_queue.put(cur_q_queue.queue)
    doc_embeddings, all_scores, query_embedding, top_scores, top_indices = None, None, None, None, None
    torch.cuda.empty_cache()
    gpu_queue.put_nowait(gpuid)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--gpus", nargs="+", type=int, required=True)
    parser.add_argument("--per_gpu_doc_num", type=int, default=None)
    parser.add_argument("--hit", type=int, default=1000)
    parser.add_argument("--embedding_dim", type=int, default=768)
    parser.add_argument("--output_path", type=str,
        default="./data/retrieve/repbert.dev.small.top1k.tsv")
    parser.add_argument("--doc_embedding_dir", type=str,
        default="./data/precompute/doc_embedding")
    parser.add_argument("--query_embedding_dir", type=str,
        default="./data/precompute/query_dev.small_embedding")

    args = parser.parse_args()

    doc_size = len(get_embed_memmap(args.doc_embedding_dir, args.embedding_dim)[1])
    if args.per_gpu_doc_num is None:
        args.per_gpu_doc_num = math.ceil(doc_size / len(args.gpus))

    num_rounds = math.ceil(doc_size / args.per_gpu_doc_num)
    doc_arguments = []
    for i in range(num_rounds):
        doc_begin_index = int(doc_size * i / num_rounds)
        doc_end_index = int(doc_size * (i + 1) / num_rounds)
        doc_arguments.append((doc_begin_index, doc_end_index))

    manager = Manager()
    finished_queue_lst = [manager.Queue() for _ in range(num_rounds)]
    gpu_queue = manager.Queue()
    for gpu in args.gpus:
        gpu_queue.put_nowait(gpu)

    pool = Pool(num_rounds + 1)
    start = timer()
    for finish_queue, (doc_begin_index, doc_end_index) in zip(finished_queue_lst, doc_arguments):
        pool.apply_async(allrank,
            args=(gpu_queue, doc_begin_index, doc_end_index, finish_queue))
    pool.apply_async(writer, args=(args, finished_queue_lst))
    pool.close()
    pool.join()
    end = timer()
    print(end - start)
--------------------------------------------------------------------------------
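The writer merges the per-shard top-k lists into a global top-k per query. The same merge, reduced to its core with `heapq` (a standalone sketch, not code from this repo):

```python
import heapq

def merge_topk(per_shard_results, k):
    """per_shard_results: iterable of lists of (score, docid) pairs."""
    heap = []  # min-heap holding the best k pairs seen so far
    for shard in per_shard_results:
        for item in shard:
            if len(heap) < k:
                heapq.heappush(heap, item)
            elif item > heap[0]:
                heapq.heapreplace(heap, item)
    return sorted(heap, reverse=True)

print(merge_topk([[(0.9, 7), (0.2, 3)], [(0.5, 1)]], k=2))  # [(0.9, 7), (0.5, 1)]
```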
/dataset.py:
--------------------------------------------------------------------------------
import json
import torch
import logging
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from transformers import BertTokenizer
from torch.utils.data import Dataset
logger = logging.getLogger(__name__)


class CollectionDataset:
    def __init__(self, collection_memmap_dir):
        self.pids = np.memmap(f"{collection_memmap_dir}/pids.memmap", dtype='int32')
        self.lengths = np.memmap(f"{collection_memmap_dir}/lengths.memmap", dtype='int32')
        self.collection_size = len(self.pids)
        self.token_ids = np.memmap(f"{collection_memmap_dir}/token_ids.memmap",
            dtype='int32', shape=(self.collection_size, 512))

    def __len__(self):
        return self.collection_size

    def __getitem__(self, item):
        assert self.pids[item] == item
        return self.token_ids[item, :self.lengths[item]].tolist()


def load_queries(tokenize_dir, mode):
    queries = dict()
    for line in tqdm(open(f"{tokenize_dir}/queries.{mode}.json"), desc="queries"):
        data = json.loads(line)
        queries[int(data['id'])] = data['ids']
    return queries


def load_querydoc_pairs(msmarco_dir, mode):
    qrels = defaultdict(set)
    qids, pids, labels = [], [], []
    if mode == "train":
        # Each triple yields two consecutive examples: the positive passage
        # (label 1) immediately followed by the negative passage (label 0).
        for line in tqdm(open(f"{msmarco_dir}/qidpidtriples.train.small.tsv"),
                desc="load train triples"):
            qid, pos_pid, neg_pid = line.split("\t")
            qid, pos_pid, neg_pid = int(qid), int(pos_pid), int(neg_pid)
            qids.append(qid)
            pids.append(pos_pid)
            labels.append(1)
            qids.append(qid)
            pids.append(neg_pid)
            labels.append(0)
        for line in open(f"{msmarco_dir}/qrels.train.tsv"):
            qid, _, pid, _ = line.split()
            qrels[int(qid)].add(int(pid))
    else:
        for line in open(f"{msmarco_dir}/top1000.{mode}"):
            qid, pid, _, _ = line.split("\t")
            qids.append(int(qid))
            pids.append(int(pid))
    qrels = dict(qrels)
    if not mode == "train":
        labels, qrels = None, None
    return qids, pids, labels, qrels


class MSMARCODataset(Dataset):
    def __init__(self, mode, msmarco_dir,
            collection_memmap_dir, tokenize_dir,
            max_query_length=20, max_doc_length=256):

        self.collection = CollectionDataset(collection_memmap_dir)
        self.queries = load_queries(tokenize_dir, mode)
        self.qids, self.pids, self.labels, self.qrels = load_querydoc_pairs(msmarco_dir, mode)
        self.mode = mode
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.cls_id = self.tokenizer.cls_token_id
        self.sep_id = self.tokenizer.sep_token_id
        self.max_query_length = max_query_length
        self.max_doc_length = max_doc_length

    def __len__(self):
        return len(self.qids)

    def __getitem__(self, item):
        qid, pid = self.qids[item], self.pids[item]
        query_input_ids, doc_input_ids = self.queries[qid], self.collection[pid]
        query_input_ids = query_input_ids[:self.max_query_length]
        query_input_ids = [self.cls_id] + query_input_ids + [self.sep_id]
        doc_input_ids = doc_input_ids[:self.max_doc_length]
        doc_input_ids = [self.cls_id] + doc_input_ids + [self.sep_id]

        ret_val = {
            "query_input_ids": query_input_ids,
            "doc_input_ids": doc_input_ids,
            "qid": qid,
            "docid": pid
        }
        if self.mode == "train":
            ret_val["rel_docs"] = self.qrels[qid]
        return ret_val


def pack_tensor_2D(lstlst, default, dtype, length=None):
    # Right-pad a list of lists with `default` into a (batch, length) tensor.
    batch_size = len(lstlst)
    length = length if length is not None else max(len(l) for l in lstlst)
    tensor = default * torch.ones((batch_size, length), dtype=dtype)
    for i, l in enumerate(lstlst):
        tensor[i, :len(l)] = torch.tensor(l, dtype=dtype)
    return tensor


def get_collate_function(mode):
    def collate_function(batch):
        input_ids_lst = [x["query_input_ids"] + x["doc_input_ids"] for x in batch]
        token_type_ids_lst = [[0] * len(x["query_input_ids"]) + [1] * len(x["doc_input_ids"])
            for x in batch]
        valid_mask_lst = [[1] * len(input_ids) for input_ids in input_ids_lst]
        # Positions restart from 0 at the document, so the query and the doc
        # are encoded as two independent sequences packed into one input.
        position_ids_lst = [list(range(len(x["query_input_ids"]))) +
            list(range(len(x["doc_input_ids"]))) for x in batch]
        data = {
            "input_ids": pack_tensor_2D(input_ids_lst, default=0, dtype=torch.int64),
            "token_type_ids": pack_tensor_2D(token_type_ids_lst, default=0, dtype=torch.int64),
            "valid_mask": pack_tensor_2D(valid_mask_lst, default=0, dtype=torch.int64),
            "position_ids": pack_tensor_2D(position_ids_lst, default=0, dtype=torch.int64),
        }
        qid_lst = [x['qid'] for x in batch]
        docid_lst = [x['docid'] for x in batch]
        if mode == "train":
            # labels[i] lists the in-batch positions of documents relevant to
            # query i, padded with -1 (the format nn.MultiLabelMarginLoss expects).
            labels = [[j for j in range(len(docid_lst)) if docid_lst[j] in x['rel_docs']]
                for x in batch]
            data['labels'] = pack_tensor_2D(labels, default=-1, dtype=torch.int64, length=len(batch))
        return data, qid_lst, docid_lst
    return collate_function


def _test_dataset():
    # The paths below are the repository defaults.
    dataset = MSMARCODataset("train", "./data/msmarco-passage",
        "./data/collection_memmap", "./data/tokenize")
    for data in dataset:
        tokens = dataset.tokenizer.convert_ids_to_tokens(data["query_input_ids"])
        print(tokens)
        tokens = dataset.tokenizer.convert_ids_to_tokens(data["doc_input_ids"])
        print(tokens)
        print(data['qid'], data['docid'], data['rel_docs'])
        print()
        k = input()
        if k == "q":
            break


def _test_collate_func():
    from torch.utils.data import DataLoader, SequentialSampler
    eval_dataset = MSMARCODataset("train", "./data/msmarco-passage",
        "./data/collection_memmap", "./data/tokenize")
    train_sampler = SequentialSampler(eval_dataset)
    collate_fn = get_collate_function(mode="train")
    dataloader = DataLoader(eval_dataset, batch_size=26,
        num_workers=4, collate_fn=collate_fn, sampler=train_sampler)
    for batch, qidlst, pidlst in tqdm(dataloader):
        pass


if __name__ == "__main__":
    _test_collate_func()
--------------------------------------------------------------------------------
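`pack_tensor_2D` right-pads ragged lists into a single tensor; the `labels` built in the collate function use the same helper, so each row lists the in-batch relevant positions padded with `-1`. A small sketch, assuming `dataset.py` is importable:

```python
import torch
from dataset import pack_tensor_2D

print(pack_tensor_2D([[1, 2, 3], [4, 5]], default=0, dtype=torch.int64))
# tensor([[1, 2, 3],
#         [4, 5, 0]])

# Batch of 3: doc 0 is relevant to query 0, docs 1 and 2 to query 1,
# nothing in-batch to query 2 (MultiLabelMarginLoss target format):
print(pack_tensor_2D([[0], [1, 2], []], default=-1, dtype=torch.int64, length=3))
# tensor([[ 0, -1, -1],
#         [ 1,  2, -1],
#         [-1, -1, -1]])
```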
/precompute.py:
--------------------------------------------------------------------------------
import os
import torch
import logging
import argparse
import numpy as np
from tqdm import tqdm
from timeit import default_timer as timer
from transformers import BertTokenizer, BertConfig
from torch.utils.data import DataLoader, Dataset
from dataset import load_queries, CollectionDataset, pack_tensor_2D
from modeling import RepBERT

logger = logging.getLogger(__name__)
logging.basicConfig(format='%(asctime)s-%(levelname)s-%(name)s- %(message)s',
                    datefmt='%d %H:%M:%S',
                    level=logging.INFO)


def create_embed_memmap(ids, memmap_dir, dim):
    if not os.path.exists(memmap_dir):
        os.makedirs(memmap_dir)
    embedding_path = f"{memmap_dir}/embedding.memmap"
    id_path = f"{memmap_dir}/ids.memmap"
    embed_open_mode = "r+" if os.path.exists(embedding_path) else "w+"
    id_open_mode = "r+" if os.path.exists(id_path) else "w+"
    logger.warning(f"Open Mode: embedding-{embed_open_mode} ids-{id_open_mode}")

    embedding_memmap = np.memmap(embedding_path, dtype='float32',
        mode=embed_open_mode, shape=(len(ids), dim))
    id_memmap = np.memmap(id_path, dtype='int32',
        mode=id_open_mode, shape=(len(ids),))
    id_memmap[:] = ids
    # Reopen read-only so the ids cannot be modified afterwards.
    id_memmap = np.memmap(id_path, dtype='int32', shape=(len(ids),))
    return embedding_memmap, id_memmap


class MSMARCO_QueryDataset(Dataset):
    def __init__(self, tokenize_dir, msmarco_dir, task, max_query_length):
        # msmarco_dir is kept for interface symmetry; queries are read from tokenize_dir.
        self.max_query_length = max_query_length
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.queries = load_queries(tokenize_dir, task)
        self.qids = list(self.queries.keys())
        self.task = task
        self.cls_id = tokenizer.cls_token_id
        self.sep_id = tokenizer.sep_token_id
        self.all_ids = self.qids

    def __len__(self):
        return len(self.qids)

    def __getitem__(self, item):
        qid = self.qids[item]
        query_input_ids = self.queries[qid]
        query_input_ids = query_input_ids[:self.max_query_length]
        query_input_ids = [self.cls_id] + query_input_ids + [self.sep_id]
        ret_val = {
            "input_ids": query_input_ids,
            "id": qid
        }
        return ret_val


class MSMARCO_DocDataset(Dataset):
    def __init__(self, collection_memmap_dir, max_doc_length):
        self.max_doc_length = max_doc_length
        self.collection = CollectionDataset(collection_memmap_dir)
        self.pids = self.collection.pids
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.cls_id = tokenizer.cls_token_id
        self.sep_id = tokenizer.sep_token_id
        self.all_ids = self.collection.pids

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, item):
        pid = self.pids[item]
        doc_input_ids = self.collection[pid]
        doc_input_ids = doc_input_ids[:self.max_doc_length]
        doc_input_ids = [self.cls_id] + doc_input_ids + [self.sep_id]

        ret_val = {
            "input_ids": doc_input_ids,
            "id": pid
        }
        return ret_val


def get_collate_function():
    def collate_function(batch):
        input_ids_lst = [x["input_ids"] for x in batch]
        valid_mask_lst = [[1] * len(input_ids) for input_ids in input_ids_lst]
        data = {
            "input_ids": pack_tensor_2D(input_ids_lst, default=0, dtype=torch.int64),
            "valid_mask": pack_tensor_2D(valid_mask_lst, default=0, dtype=torch.int64),
        }
        id_lst = [x['id'] for x in batch]
        return data, id_lst
    return collate_function
def generate_embeddings(args, model, task):
    if task == "doc":
        dataset = MSMARCO_DocDataset(args.collection_memmap_dir, args.max_doc_length)
        memmap_dir = args.doc_embedding_dir
    else:
        query_str, mode = task.split("_")
        assert query_str == "query"
        dataset = MSMARCO_QueryDataset(args.tokenize_dir, args.msmarco_dir, mode, args.max_query_length)
        memmap_dir = args.query_embedding_dir
    embedding_memmap, ids_memmap = create_embed_memmap(
        dataset.all_ids, memmap_dir, model.config.hidden_size)
    id2pos = {identity: i for i, identity in enumerate(ids_memmap)}

    batch_size = args.per_gpu_batch_size * max(1, args.n_gpu)
    collate_fn = get_collate_function()
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

    # multi-gpu eval
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # Eval!
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", batch_size)

    start = timer()
    for batch, ids in tqdm(dataloader, desc="Evaluating"):
        model.eval()
        with torch.no_grad():
            batch = {k: v.to(args.device) for k, v in batch.items()}
            output = model(**batch)
            sequence_embeddings = output.detach().cpu().numpy()
            poses = [id2pos[identity] for identity in ids]
            embedding_memmap[poses] = sequence_embeddings
    end = timer()
    print(task, "time:", end - start)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--load_model_path", type=str, required=True)
    parser.add_argument("--task", choices=["query_dev.small", "query_eval.small", "doc"],
                        required=True)
    parser.add_argument("--output_dir", type=str, default="./data/precompute")

    parser.add_argument("--msmarco_dir", type=str, default="./data/msmarco-passage")
    parser.add_argument("--collection_memmap_dir", type=str, default="./data/collection_memmap")
    parser.add_argument("--tokenize_dir", type=str, default="./data/tokenize")
    parser.add_argument("--max_query_length", type=int, default=20)
    parser.add_argument("--max_doc_length", type=int, default=256)
    parser.add_argument("--per_gpu_batch_size", default=100, type=int)
    args = parser.parse_args()

    args.doc_embedding_dir = f"{args.output_dir}/doc_embedding"
    args.query_embedding_dir = f"{args.output_dir}/{args.task}_embedding"

    # Setup CUDA, GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    args.device = device

    # Setup logging
    logger.warning("Device: %s, n_gpu: %s", device, args.n_gpu)

    config = BertConfig.from_pretrained(args.load_model_path)
    # The same checkpoint encodes either queries or documents; encode_type
    # selects which token_type ids RepBERT uses.
    if "query" in args.task:
        config.encode_type = "query"
    else:
        config.encode_type = "doc"
    model = RepBERT.from_pretrained(args.load_model_path, config=config)
    model.to(args.device)

    logger.info(args)
    generate_embeddings(args, model, args.task)
--------------------------------------------------------------------------------
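Once both embedding sets are on disk, relevance is just an inner product. A sketch scoring one dev query against a slice of the collection (it assumes the precompute defaults above; `get_embed_memmap` comes from `retrieve.py`):

```python
import numpy as np
from retrieve import get_embed_memmap

doc_emb, doc_ids = get_embed_memmap("./data/precompute/doc_embedding", 768)
query_emb, query_ids = get_embed_memmap("./data/precompute/query_dev.small_embedding", 768)

scores = doc_emb[:10000] @ query_emb[0]   # inner products as relevance scores
best = np.argsort(-scores)[:10]
print(query_ids[0], doc_ids[best], scores[best])
```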
/README.md:
--------------------------------------------------------------------------------
# RepBERT

* 🔥**News 2021-10: Our full paper, [Learning Discrete Representations via Constrained Clustering for Effective and Efficient Dense Retrieval](https://arxiv.org/abs/2110.05789)\[[code](https://github.com/jingtaozhan/RepCONC)\], was accepted by WSDM'22. It presents RepCONC and achieves a state-of-the-art first-stage retrieval effectiveness-efficiency tradeoff.**

* 🔥**News 2021-8: Our full paper, [Jointly Optimizing Query Encoder and Product Quantization to Improve Retrieval Performance](https://arxiv.org/abs/2108.00644)\[[code](https://github.com/jingtaozhan/JPQ)\], was accepted by CIKM'21. It presents JPQ and greatly improves the efficiency of Dense Retrieval.**

* 🔥**News 2021-4: Our full paper, [Optimizing Dense Retrieval Model Training with Hard Negatives](https://arxiv.org/abs/2104.08051)\[[code](https://github.com/jingtaozhan/DRhard)\], was accepted by SIGIR'21. It provides a theoretical analysis of different negative sampling strategies and greatly improves the effectiveness of Dense Retrieval with hard negative sampling.**

RepBERT is currently the state-of-the-art first-stage retrieval technique on the [MS MARCO Passage Ranking task](https://microsoft.github.io/msmarco/). It represents documents and queries with fixed-length contextualized embeddings, and their inner products are used as relevance scores. Its efficiency is comparable to bag-of-words methods. For more details, check out our paper:

+ Zhan et al. [RepBERT: Contextualized Text Embeddings for First-Stage Retrieval.](https://arxiv.org/abs/2006.15498)


MS MARCO Passage Ranking Leaderboard (Jun 28th 2020) | Category | Eval MRR@10 | Latency
:------------------------------------ | :------------: | :------: | ------:
[BM25 + BERT](https://github.com/nyu-dl/dl4marco-bert) from [(Nogueira and Cho, 2019)](https://arxiv.org/abs/1901.04085) | Cascade | 0.358 | 3400 ms
RepBERT (this code) | First-Stage | 0.294 | 80 ms
BiLSTM + Co-Attention + self attention based document scorer [(Alaparthi et al., 2019)](https://arxiv.org/abs/1906.06056) (best non-BERT) | Cascade | 0.291 | -
[docTTTTTquery](https://github.com/castorini/docTTTTTquery) [(Nogueira et al., 2019)](https://cs.uwaterloo.ca/~jimmylin/publications/Nogueira_Lin_2019_docTTTTTquery.pdf) | First-Stage | 0.272 | 64 ms
[DeepCT](https://github.com/AdeDZY/DeepCT) [(Dai and Callan, 2019)](https://github.com/AdeDZY/DeepCT) | First-Stage | 0.239 | 55 ms
[doc2query](https://github.com/nyu-dl/dl4ir-doc2query) [(Nogueira et al., 2019)](https://github.com/nyu-dl/dl4ir-doc2query) | First-Stage | 0.218 | 90 ms
[BM25 (Anserini)](https://github.com/castorini/anserini/blob/master/docs/experiments-msmarco-passage.md) | First-Stage | 0.186 | 50 ms

## Data and Trained Models

We make the following data available for download:

+ `repbert.dev.small.top1k.tsv`: 6,980,000 pairs of dev set queries and retrieved passages (1,000 passages per query). In this tsv file, the first column is the query id, the second column is the passage id, and the third column is the rank of the passage.
+ `repbert.eval.small.top1k.tsv`: 6,837,000 pairs of eval set queries and retrieved passages (1,000 passages per query). The columns are the same as above.
+ `repbert.ckpt-350000.zip`: Trained BERT base model used to represent queries and passages. It contains two files, namely `config.json` and `pytorch_model.bin`.

Download and verify the above files from the below table:

File | Size | MD5 | Download
:----|-----:|:----|:-----
`repbert.dev.small.top1k.tsv` | 127 MB | `0d08617b62a777c3c8b2d42ca5e89a8e` | [[Google Drive](https://drive.google.com/file/d/1MrrwDmTZOiFx3qjfPxi4lDSdQk1tR5C6/view?usp=sharing)]
`repbert.eval.small.top1k.tsv` | 125 MB | `b56a79138f215292d674f58c694d5206` | [[Google Drive](https://drive.google.com/file/d/1twRGEJZFZc4zYa75q8UFEz9ZS2oh0oyE/view?usp=sharing)]
`repbert.ckpt-350000.zip` | 386 MB | `b59a574f53c92de6a4ddd4b3fbef784a` | [[Google Drive](https://drive.google.com/file/d/1xhwy_nvRWSNyJ2V7uP3FC5zVwj1Xmylv/view?usp=sharing)]
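You can check the MD5 sums against the table before unpacking; a quick sketch in Python (any standard `md5sum` utility works just as well):

```python
import hashlib

def md5(path, chunk_size=1 << 20):
    h = hashlib.md5()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            h.update(block)
    return h.hexdigest()

assert md5("repbert.ckpt-350000.zip") == "b59a574f53c92de6a4ddd4b3fbef784a"
```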
## Replicating Results with Provided Trained Model

We provide instructions on how to replicate the RepBERT retrieval results using the provided trained model.

First, make sure you have installed [🤗 Transformers](https://github.com/huggingface/transformers):

```bash
pip install transformers
git clone https://github.com/jingtaozhan/RepBERT-Index
cd RepBERT-Index
```

Next, download `collectionandqueries.tar.gz` from [MSMARCO-Passage-Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking). It contains passages, queries, and qrels.

```bash
mkdir data
cd data
wget https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
mkdir msmarco-passage
tar xvfz collectionandqueries.tar.gz -C msmarco-passage
```

To confirm, `collectionandqueries.tar.gz` should have an MD5 checksum of `31644046b18952c1386cd4564ba2ae69`.

To avoid duplicated effort between training and testing, we tokenize the queries and passages in advance. This takes a while (about 3-4 hours). We also convert the tokenized passages to a numpy memmap array, which greatly reduces the memory overhead at run time.

```bash
python convert_text_to_tokenized.py --tokenize_queries --tokenize_collection
python convert_collection_to_memmap.py
```

Please download the provided model `repbert.ckpt-350000.zip`, put it in `./data`, and unzip it. You should see two files in the directory `./data/ckpt-350000`, namely `pytorch_model.bin` and `config.json`.

Next, you need to precompute the representations of passages and queries.

```bash
python precompute.py --load_model_path ./data/ckpt-350000 --task doc
python precompute.py --load_model_path ./data/ckpt-350000 --task query_dev.small
python precompute.py --load_model_path ./data/ckpt-350000 --task query_eval.small
```

At last, you can retrieve the passages for the queries in the dev set (or eval set). `multi_retrieve.py` uses the GPUs specified by the `--gpus` argument (space-separated ids) and distributes the passage representations evenly across them. If your CUDA memory is limited, you can use `--per_gpu_doc_num` to specify the number of passages placed on each GPU.

```bash
python multi_retrieve.py --query_embedding_dir ./data/precompute/query_dev.small_embedding --output_path ./data/retrieve/repbert.dev.small.top1k.tsv --hit 1000 --gpus 0 1 2 3 4
python ms_marco_eval.py ./data/msmarco-passage/qrels.dev.small.tsv ./data/retrieve/repbert.dev.small.top1k.tsv
```

You can also retrieve the passages with only one GPU:

```bash
export CUDA_VISIBLE_DEVICES=0
python retrieve.py --query_embedding_dir ./data/precompute/query_dev.small_embedding --output_path ./data/retrieve/repbert.dev.small.top1k.tsv --hit 1000 --per_gpu_doc_num 1800000
python ms_marco_eval.py ./data/msmarco-passage/qrels.dev.small.tsv ./data/retrieve/repbert.dev.small.top1k.tsv
```

The results should be:

```
#####################
MRR @10: 0.3038783713103188
QueriesRanked: 6980
#####################
```
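The same MRR@10 check can be run from Python through the small helper in `utils.py`, which shells out to `ms_marco_eval.py` with the default qrels path and parses the printed score:

```python
from utils import eval_results

# Defaults point at ./ms_marco_eval.py and ./data/msmarco-passage/qrels.dev.small.tsv.
mrr = eval_results("./data/retrieve/repbert.dev.small.top1k.tsv")
print(mrr)  # expected: about 0.304 on the small dev set
```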
## Train RepBERT

Next, download `qidpidtriples.train.full.tsv.gz` from [MSMARCO-Passage-Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking).

```bash
cd ./data/msmarco-passage
wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.tsv.gz
```

Extract it and use the `shuf` command to sample a smaller training file (about 10% of the full triples):

```bash
gunzip qidpidtriples.train.full.tsv.gz
shuf ./qidpidtriples.train.full.tsv -o ./qidpidtriples.train.small.tsv -n 26991900
```

Start training. Note that the evaluation performed during training is a reranking evaluation: it rescores the candidates in `top1000.dev` (available from the same MS MARCO page, placed in `./data/msmarco-passage`) rather than retrieving from the full collection.

```bash
python ./train.py --task train --evaluate_during_training
```

--------------------------------------------------------------------------------
/ms_marco_eval.py:
--------------------------------------------------------------------------------
"""
This module computes evaluation metrics for the MS MARCO dataset on the ranking task.
Command line:
python ms_marco_eval.py <path_to_reference_file> <path_to_candidate_file>
Creation Date : 06/12/2018
Last Modified : 4/09/2019
Authors : Daniel Campos, Rutger van Haasteren
"""
import sys

from collections import Counter

MaxMRRRank = 10

def load_reference_from_stream(f):
    """Load reference relevant passages
    Args:f (stream): stream to load.
    Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
    """
    qids_to_relevant_passageids = {}
    for l in f:
        try:
            l = l.strip().split('\t')
            qid = int(l[0])
            if qid not in qids_to_relevant_passageids:
                qids_to_relevant_passageids[qid] = []
            qids_to_relevant_passageids[qid].append(int(l[2]))
        except:
            raise IOError('"%s" is not valid format' % l)
    return qids_to_relevant_passageids

def load_reference(path_to_reference):
    """Load reference relevant passages
    Args:path_to_reference (str): path to a file to load.
    Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
    """
    with open(path_to_reference, 'r') as f:
        qids_to_relevant_passageids = load_reference_from_stream(f)
    return qids_to_relevant_passageids
def load_candidate_from_stream(f):
    """Load candidate data from a stream.
    Args:f (stream): stream to load.
    Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids (int) ranked by relevance and importance
    """
    qid_to_ranked_candidate_passages = {}
    for l in f:
        try:
            l = l.strip().split('\t')
            qid = int(l[0])
            pid = int(l[1])
            rank = int(l[2])
            if qid not in qid_to_ranked_candidate_passages:
                # By default, all PIDs in the list of 1000 are 0.
                # Only override those that are given.
                qid_to_ranked_candidate_passages[qid] = [0] * 1000
            qid_to_ranked_candidate_passages[qid][rank - 1] = pid
        except:
            raise IOError('"%s" is not valid format' % l)
    return qid_to_ranked_candidate_passages

def load_candidate(path_to_candidate):
    """Load candidate data from a file.
    Args:path_to_candidate (str): path to file to load.
    Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids (int) ranked by relevance and importance
    """
    with open(path_to_candidate, 'r') as f:
        qid_to_ranked_candidate_passages = load_candidate_from_stream(f)
    return qid_to_ranked_candidate_passages

def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Perform quality checks on the dictionaries
    Args:
    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
        Dict as read in with load_reference or load_reference_from_stream
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
    Returns:
        bool,str: Boolean whether allowed, message to be shown in case of a problem
    """
    message = ''
    allowed = True

    # Create sets of the QIDs for the submitted and reference queries
    candidate_set = set(qids_to_ranked_candidate_passages.keys())
    ref_set = set(qids_to_relevant_passageids.keys())

    # Check that we do not have multiple passages per query
    for qid in qids_to_ranked_candidate_passages:
        # Remove all zeros (padding) from the candidates before flagging duplicates
        duplicate_pids = set(
            [item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1])

        if len(duplicate_pids - set([0])) > 0:
            message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
                qid=qid, pid=list(duplicate_pids)[0])
            allowed = False

    return allowed, message

def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute MRR metric
    Args:
    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
        Dict as read in with load_reference or load_reference_from_stream
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
    Returns:
        dict: dictionary of metrics {'MRR @10': <score>, 'QueriesRanked': <count>}
    """
    all_scores = {}
    MRR = 0
    qids_with_relevant_passages = 0
    ranking = []
    for qid in qids_to_ranked_candidate_passages:
        if qid in qids_to_relevant_passageids:
            ranking.append(0)
            target_pid = qids_to_relevant_passageids[qid]
            candidate_pid = qids_to_ranked_candidate_passages[qid]
            for i in range(0, MaxMRRRank):
                if candidate_pid[i] in target_pid:
                    MRR += 1 / (i + 1)
                    ranking.pop()
                    ranking.append(i + 1)
                    break
    if len(ranking) == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    MRR = MRR / len(qids_to_relevant_passageids)
    all_scores['MRR @10'] = MRR
    all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
    return all_scores
def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
    """Compute MRR metric
    Args:
    p_path_to_reference_file (str): path to reference file.
        Reference file should contain lines in the TREC qrels format:
            QUERYID\tITER\tPASSAGEID\tLABEL
        Note QUERYID can repeat on different lines with different PASSAGEIDs.
    p_path_to_candidate_file (str): path to candidate file.
        Candidate file should contain lines in the following format:
            QUERYID\tPASSAGEID\tRANK
    Returns:
        dict: dictionary of metrics {'MRR @10': <score>, 'QueriesRanked': <count>}
    """
    qids_to_relevant_passageids = load_reference(path_to_reference)
    qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
    if perform_checks:
        allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
        if message != '':
            print(message)

    return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)

def main():
    """Command line:
    python ms_marco_eval.py <path_to_reference_file> <path_to_candidate_file>
    """
    path_to_reference = sys.argv[1]
    path_to_candidate = sys.argv[2]
    metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
    print('#####################')
    for metric in sorted(metrics):
        print('{}: {}'.format(metric, metrics[metric]))
    print('#####################')

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
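A tiny worked example of the metric (a sketch, assuming this script is importable as a module): if the first relevant passage for a query appears at rank 3, it contributes 1/3 to MRR@10.

```python
from ms_marco_eval import compute_metrics

qrels = {7: [42]}                        # query 7 has one relevant passage, pid 42
ranking = {7: [10, 11, 42] + [0] * 997}  # pid 42 retrieved at rank 3
print(compute_metrics(qrels, ranking))   # {'MRR @10': 0.333..., 'QueriesRanked': 1}
```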
/train.py:
--------------------------------------------------------------------------------
import os
import torch
import random
import time
import logging
import argparse
import numpy as np
from tqdm import tqdm, trange
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (BertConfig, AdamW, get_linear_schedule_with_warmup)

from modeling import RepBERT_Train
from dataset import MSMARCODataset, get_collate_function
from utils import generate_rank, eval_results

logger = logging.getLogger(__name__)
logging.basicConfig(format='%(asctime)s-%(levelname)s-%(name)s- %(message)s',
                    datefmt='%d %H:%M:%S',
                    level=logging.INFO)

def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def save_model(model, output_dir, save_name, args):
    save_dir = os.path.join(output_dir, save_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(save_dir)
    torch.save(args, os.path.join(save_dir, 'training_args.bin'))


def train(args, model):
    """ Train the model """
    tb_writer = SummaryWriter(args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_dataset = MSMARCODataset("train", args.msmarco_dir,
        args.collection_memmap_dir, args.tokenize_dir,
        args.max_query_length, args.max_doc_length)

    # NOTE: must be sequential! The triples file is ordered Pos, Neg, Pos, Neg, ...
    train_sampler = SequentialSampler(train_dataset)
    collate_fn = get_collate_function(mode="train")
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
        batch_size=args.train_batch_size, num_workers=args.data_num_workers,
        collate_fn=collate_fn)

    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # multi-gpu training
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility
    for epoch_idx, _ in enumerate(train_iterator):
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, (batch, _, _) in enumerate(epoch_iterator):

            batch = {k: v.to(args.device) for k, v in batch.items()}
            model.train()
            outputs = model(**batch)
            loss = outputs[0]  # model outputs are always tuples (see modeling.py)

            if args.n_gpu > 1:
                loss = loss.mean()  # average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                if args.evaluate_during_training and (global_step % args.training_eval_steps == 0):
                    mrr = evaluate(args, model, mode="dev", prefix="step_{}".format(global_step))
                    tb_writer.add_scalar('dev/MRR@10', mrr, global_step)
                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    cur_loss = (tr_loss - logging_loss) / args.logging_steps
                    tb_writer.add_scalar('train/loss', cur_loss, global_step)
                    logging_loss = tr_loss
                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    save_model(model, args.model_save_dir, 'ckpt-{}'.format(global_step), args)


def evaluate(args, model, mode, prefix):
    eval_output_dir = args.eval_save_dir
    if not os.path.exists(eval_output_dir):
        os.makedirs(eval_output_dir)

    eval_dataset = MSMARCODataset(mode, args.msmarco_dir,
        args.collection_memmap_dir, args.tokenize_dir,
        args.max_query_length, args.max_doc_length)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    collate_fn = get_collate_function(mode=mode)
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.eval_batch_size,
        num_workers=args.data_num_workers, collate_fn=collate_fn)

    # multi-gpu eval
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    output_file_path = f"{eval_output_dir}/{prefix}.{mode}.score.tsv"
    with open(output_file_path, 'w') as outputfile:
        for batch, qids, docids in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            with torch.no_grad():
                batch = {k: v.to(args.device) for k, v in batch.items()}
                outputs = model(**batch)
                # Each example pairs query i with doc i, so only the diagonal
                # of the in-batch similarity matrix is needed.
                scores = torch.diagonal(outputs[0]).detach().cpu().numpy()
                assert len(qids) == len(docids) == len(scores)
                for qid, docid, score in zip(qids, docids, scores):
                    outputfile.write(f"{qid}\t{docid}\t{score}\n")

    rank_output = f"{eval_output_dir}/{prefix}.{mode}.rank.tsv"
    generate_rank(output_file_path, rank_output)

    if mode == "dev":
        mrr = eval_results(rank_output)
        return mrr


def run_parse_args():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task", choices=["train", "dev", "eval"], required=True)
    parser.add_argument("--output_dir", type=str, default="./data/train")

    parser.add_argument("--msmarco_dir", type=str, default="./data/msmarco-passage")
    parser.add_argument("--collection_memmap_dir", type=str, default="./data/collection_memmap")
    parser.add_argument("--tokenize_dir", type=str, default="./data/tokenize")
    parser.add_argument("--max_query_length", type=int, default=20)
    parser.add_argument("--max_doc_length", type=int, default=256)

    ## Other parameters
    parser.add_argument("--eval_ckpt", type=int, default=None)
    parser.add_argument("--per_gpu_eval_batch_size", default=26, type=int)
    parser.add_argument("--per_gpu_train_batch_size", default=26, type=int)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)

    parser.add_argument("--no_cuda", action='store_true')
    parser.add_argument('--seed', type=int, default=42)

    parser.add_argument("--evaluate_during_training", action="store_true")
    parser.add_argument("--training_eval_steps", type=int, default=5000)

    parser.add_argument("--save_steps", type=int, default=5000)
    parser.add_argument("--logging_steps", type=int, default=100)
    parser.add_argument("--data_num_workers", default=0, type=int)

    parser.add_argument("--learning_rate", default=3e-6, type=float)
parser.add_argument("--weight_decay", default=0.01, type=float) 203 | parser.add_argument("--warmup_steps", default=10000, type=int) 204 | parser.add_argument("--adam_epsilon", default=1e-8, type=float) 205 | parser.add_argument("--max_grad_norm", default=1.0, type=float) 206 | parser.add_argument("--num_train_epochs", default=1, type=int) 207 | 208 | args = parser.parse_args() 209 | 210 | time_stamp = time.strftime("%b-%d_%H:%M:%S", time.localtime()) 211 | args.log_dir = f"{args.output_dir}/log/{time_stamp}" 212 | args.model_save_dir = f"{args.output_dir}/models" 213 | args.eval_save_dir = f"{args.output_dir}/eval_results" 214 | return args 215 | 216 | 217 | def main(): 218 | args = run_parse_args() 219 | logger.info(args) 220 | 221 | # Setup CUDA, GPU 222 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 223 | args.n_gpu = torch.cuda.device_count() 224 | 225 | args.device = device 226 | 227 | # Setup logging 228 | logger.warning("Device: %s, n_gpu: %s", device, args.n_gpu) 229 | 230 | # Set seed 231 | set_seed(args) 232 | 233 | if args.task == "train": 234 | load_model_path = f"bert-base-uncased" 235 | else: 236 | assert args.eval_ckpt is not None 237 | load_model_path = f"{args.model_save_dir}/ckpt-{args.eval_ckpt}" 238 | 239 | 240 | config = BertConfig.from_pretrained(load_model_path) 241 | model = RepBERT_Train.from_pretrained(load_model_path, config=config) 242 | model.to(args.device) 243 | 244 | logger.info("Training/evaluation parameters %s", args) 245 | # Evaluation 246 | if args.task == "train": 247 | train(args, model) 248 | else: 249 | result = evaluate(args, model, args.task, prefix=f"ckpt-{args.eval_ckpt}") 250 | print(result) 251 | 252 | 253 | 254 | if __name__ == "__main__": 255 | main() 256 | --------------------------------------------------------------------------------