├── README.md
├── __init__.py
├── densify
│   ├── __init__.py
│   ├── densify_corpus.py
│   ├── densify_query.py
│   └── output_vector.py
├── docs
│   ├── aggretriever
│   │   ├── beir-eval.md
│   │   └── msmarco-passage-train-eval.md
│   └── dhr
│       ├── beir-eval.md
│       ├── densify_exp.md
│       └── msmarco-passage-train-eval.md
├── fig
│   ├── aggretriever.png
│   ├── aggretriever_teaser.png
│   ├── densification.png
│   └── single_model_fusion.png
├── retrieval
│   ├── __init__.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   └── custom_metrics.py
│   ├── gip_retrieval.py
│   ├── index.py
│   ├── merge.result.py
│   ├── quantize_index.py
│   ├── rcap_eval.py
│   └── util.py
└── tevatron
    ├── Aggretriever
    │   ├── __init__.py
    │   ├── modeling.py
    │   └── utils.py
    ├── ColBERT
    │   └── modeling.py
    ├── DHR
    │   ├── __init__.py
    │   ├── modeling.py
    │   └── utils.py
    ├── Dense
    │   ├── __init__.py
    │   └── modeling.py
    ├── __init__.py
    ├── arguments.py
    ├── data.py
    ├── datasets
    │   ├── __init__.py
    │   ├── beir
    │   │   ├── __init__.py
    │   │   ├── encode_and_retrieval.py
    │   │   ├── preprocess.py
    │   │   └── sentence_bert.py
    │   ├── dataset.py
    │   └── preprocessor.py
    ├── driver
    │   ├── __init__.py
    │   ├── encode.py
    │   ├── eval.py
    │   ├── jax_encode.py
    │   ├── jax_train.py
    │   └── train.py
    ├── faiss_retriever
    │   ├── __init__.py
    │   ├── __main__.py
    │   ├── reducer.py
    │   └── retriever.py
    ├── loss.py
    ├── preprocessor
    │   ├── __init__.py
    │   └── preprocessor_tsv.py
    ├── tevax
    │   ├── __init__.py
    │   ├── loss.py
    │   └── training.py
    ├── trainer.py
    └── utils
        ├── __init__.py
        ├── convert_from_dpr.py
        ├── data_reader.py
        ├── format
        │   ├── __init__.py
        │   └── convert_result_to_trec.py
        ├── metrics.py
        ├── tokenize_corpus.py
        └── tokenize_query.py

/README.md:
--------------------------------------------------------------------------------
1 | # Dense Hybrid Retrieval
2 | In this repo, we introduce two approaches to training transformers that capture semantic and lexical text representations for robust dense passage retrieval.
3 | 1. *[Aggretriever: A Simple Approach to Aggregate Textual Representation for Robust Dense Passage Retrieval](https://arxiv.org/abs/2208.00511)* Sheng-Chieh Lin, Minghan Li and Jimmy Lin. (accepted to TACL)
4 | 2. *[A Dense Representation Framework for Lexical and Semantic Matching](https://dl.acm.org/doi/10.1145/3582426)* Sheng-Chieh Lin and Jimmy Lin. (TOIS, in press)
5 | 
6 | This repo contains three parts: (1) densify, (2) training (tevatron) and (3) retrieval.
7 | Our training code is mainly from [Tevatron](https://github.com/texttron/tevatron) with minor revisions.
8 | 
9 | ## Requirements
10 | ```
11 | pip install "torch>=1.7.0"
12 | pip install transformers==4.15.0
13 | pip install pyserini
14 | pip install beir
15 | ```
16 | 
17 | ## Huggingface Checkpoints
18 | Model | Initialization | MARCO Dev | BEIR (13 public datasets) | Huggingface Path | Document
19 | |---|---|---|---|---|---
20 | DeLADE+[CLS] plus | [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) | 37.1 | 49.8 | [jacklin/DeLADE-CLS-P](https://huggingface.co/jacklin/DeLADE-CLS-P) | [Read Me](https://github.com/castorini/dhr/tree/main/docs/dhr)
21 | DeLADE+[CLS] | [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) | 35.7 | 48.5 | [jacklin/DeLADE-CLS](https://huggingface.co/jacklin/DeLADE-CLS) | [Read Me](https://github.com/castorini/dhr/tree/main/docs/dhr)
22 | Aggretriever | [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) | 34.1 | 46.0 | [jacklin/DistilBERT-AGG](https://huggingface.co/jacklin/DistilBERT-AGG) | [Read Me](https://github.com/castorini/dhr/tree/main/docs/aggretriever)
23 | 
24 | # Aggretriever
25 | 
26 | 
27 | 
28 | In this paper, we introduce a simple approach to aggregating token-level information into a single-vector dense representation. We provide instructions for model training and evaluation on the MS MARCO passage ranking dataset in this [document](https://github.com/castorini/dhr/blob/main/docs/aggretriever/msmarco-passage-train-eval.md), and instructions for evaluation on the BEIR datasets in this [document](https://github.com/castorini/dhr/blob/main/docs/aggretriever/beir-eval.md).
29 | 
30 | # A Dense Representation Framework for Lexical and Semantic Matching
31 | In this paper, we introduce a unified representation framework for lexical and semantic matching. We first describe how to use our framework to conduct retrieval with high-dimensional (lexical) representations, then how to combine them with single-vector dense (semantic) representations for hybrid search.
32 | ## Dense Lexical Retrieval
33 | 
34 | 
35 | 
36 | We can densify any existing lexical matching model and conduct lexical matching on GPUs. In the [document](https://github.com/jacklin64/DHR/blob/main/docs/densify_exp.md), we demonstrate how to conduct BM25 and uniCOIL end-to-end retrieval under our framework. A detailed description can be found in our [paper](https://arxiv.org/pdf/2112.04666.pdf).
37 | 
38 | ## Dense Hybrid Retrieval
39 | With the densified lexical representations, we can easily conduct lexical and semantic hybrid retrieval using independent neural models. A document for hybrid retrieval is coming soon.
40 | 
41 | ## Dense Hybrid Representation Model
42 | 
43 | 
44 | 
45 | In our paper, we propose a single-model fusion approach: we jointly train the lexical and semantic components of one transformer, and at inference time we combine the densified lexical representations and the dense semantic representations into dense hybrid representations. Instead of training the models yourself, you can also download our trained [DeLADE-CLS-P](https://huggingface.co/jacklin/DeLADE-CLS-P), [DeLADE-CLS](https://huggingface.co/jacklin/DeLADE-CLS) and [DeLADE](https://huggingface.co/jacklin/DeLADE) checkpoints and directly perform inference on the MS MARCO passage dataset (see [document](https://github.com/jacklin64/DHR/blob/main/docs/dhr/msmarco-passage-train-eval.md)) or the BEIR datasets (see [document](https://github.com/jacklin64/DHR/blob/main/docs/dhr/beir-eval.md)).
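For intuition, the densification used throughout this repo (see `densify/densify_corpus.py`) buckets each vocabulary weight of a sparse vector into a fixed-width value array plus a slice-index array, so a dimension only contributes to a score when the query and document hold the same vocabulary slice there. Below is a minimal, simplified sketch of that mapping — the actual implementation additionally skips a model-specific number of leading vocabulary ids, and the toy vocabulary here is made up:

```python
import numpy as np

def densify_sketch(vector, token2id, dim):
    """vector: {token: weight} sparse lexical representation."""
    value = np.zeros(dim, dtype=np.float16)  # densified term weights
    index = np.zeros(dim, dtype=np.int16)    # which vocab "slice" occupies each slot
    for token, weight in vector.items():
        token_id = token2id[token]
        slot, slice_id = token_id % dim, token_id // dim
        if weight > value[slot]:             # on a collision, keep the larger weight
            value[slot], index[slot] = weight, slice_id
    return value, index

# Toy usage: scoring is a "gated" inner product -- a slot only matches
# when the query and document store the same vocabulary slice in it.
vocab = {'dog': 5, 'cat': 10}
q_val, q_idx = densify_sketch({'dog': 2.0}, vocab, dim=4)
d_val, d_idx = densify_sketch({'dog': 1.5, 'cat': 0.7}, vocab, dim=4)
score = float(((q_idx == d_idx) * q_val * d_val).sum())  # 2.0 * 1.5 = 3.0
```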
46 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/__init__.py -------------------------------------------------------------------------------- /densify/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/densify/__init__.py -------------------------------------------------------------------------------- /densify/densify_corpus.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | import glob 4 | import numpy as np 5 | import gzip 6 | import json 7 | import argparse 8 | from pyserini.index import IndexReader 9 | from multiprocessing import Pool, Manager, Queue 10 | from transformers import AutoModelForMaskedLM, AutoTokenizer 11 | import multiprocessing 12 | import os 13 | from tqdm import tqdm 14 | logger = logging.getLogger(__name__) 15 | logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) 16 | 17 | omission_num = \ 18 | {'bm25': 472, 19 | 'deepimpact': 502, 20 | 'unicoil': 570, 21 | 'splade': 570} 22 | 23 | whole_word_matching = \ 24 | {'bm25': True, 25 | 'deepimpact': True, 26 | 'unicoil': False, 27 | 'splade': False} 28 | 29 | def densify(data, dim, whole_word_matching, token2id, args): 30 | value = np.zeros((dim), dtype=np.float16) 31 | if whole_word_matching: 32 | index = np.zeros((dim), dtype=np.int16) 33 | else: 34 | index = np.zeros((dim), dtype=np.int8) 35 | collision_num = 0 36 | for i, (token, weight) in enumerate(data['vector'].items()): 37 | token_id = token2id[token] 38 | if token_id < omission_num[args.model]: 39 | continue 40 | else: 41 | slice_num = (token_id - omission_num[args.model])%dim 42 | index_num = (token_id - omission_num[args.model])//dim 43 | if value[slice_num]==0: 44 | value[slice_num] = weight 45 | index[slice_num] = index_num 46 | else: 47 | # collision 48 | collision_num += 1 49 | if value[slice_num] < weight: 50 | value[slice_num] = weight 51 | index[slice_num] = index_num 52 | return value, index, collision_num 53 | 54 | 55 | def vectorize_and_densify(files, file_type, dim, whole_word_matching, token2id, output_path, args): 56 | data_num = 0 57 | logger.info('count line number') 58 | for file in files: 59 | if file_type == 'jsonl.gz': 60 | f = gzip.open(file, "rb") 61 | else: 62 | f = open(file, 'r') 63 | for line in f: 64 | data_num+=1 65 | f.close() 66 | 67 | logger.info('initialize numpy array with {}X{}'.format(data_num, dim)) 68 | value_encoded = np.zeros((data_num, dim), dtype=np.float16) 69 | if whole_word_matching: 70 | index_encoded = np.zeros((data_num, dim), dtype=np.int16) 71 | else: 72 | index_encoded = np.zeros((data_num, dim), dtype=np.int8) 73 | docids =[] 74 | total_collision_num = 0 75 | counter = 0 76 | for file in files: 77 | if file_type == 'jsonl.gz': 78 | f = gzip.open(file, "rb") 79 | else: 80 | f = open(file, 'r') 81 | for i, line in tqdm(enumerate(f), desc=f"densify {file}"): 82 | data = json.loads(line) 83 | docids.append(data['id']) 84 | value, index, collision_num = densify(data, dim, whole_word_matching, token2id, args) 85 | total_collision_num += collision_num 86 | value_encoded[counter] = value 87 | index_encoded[counter] = index 88 | counter += 1 89 | f.close() 90 | 91 | 
print('Total {} collisions with {} passages'.format(total_collision_num, data_num))
92 |     with open(output_path, 'wb') as f_out:
93 |         pickle.dump([value_encoded, index_encoded, docids], f_out, protocol=4)
94 | 
95 | 
96 | def get_files(directory):
97 |     files = glob.glob(os.path.join(directory, '*.json'))
98 |     if len(files) == 0:
99 |         files = glob.glob(os.path.join(directory, '*.jsonl.gz'))
100 |         file_type = 'jsonl.gz'
101 |     else:
102 |         file_type = 'json'
103 |     if len(files) == 0:
104 |         raise ValueError('There are no json or jsonl.gz files in {}'.format(directory))
105 |     return files, file_type
106 | 
107 | def main():
108 |     parser = argparse.ArgumentParser(description='Densify corpus')
109 |     parser.add_argument('--model', required=True, help='bm25, deepimpact, unicoil or splade')
110 |     parser.add_argument('--tokenizer', required=False, default="bert-base-uncased", help='anserini index path or transformer tokenizer')
111 |     parser.add_argument('--vector_dir', required=True, help='directory with json files')
112 |     parser.add_argument('--output_dir', required=True, help='output pickle directory')
113 |     parser.add_argument('--output_dims', type=int, required=False, default=768)
114 |     parser.add_argument('--num_workers', type=int, required=False, default=None)
115 |     parser.add_argument('--prefix', required=True, help='index name prefix')
116 |     args = parser.parse_args()
117 |     args.model = args.model.lower()  # the docs export e.g. MODEL=BM25/uniCOIL; the keys below are lowercase
118 | 
119 |     token2id = {}
120 |     if (args.model == 'bm25') or (args.model == 'deepimpact'):
121 |         tokenizer = IndexReader(args.tokenizer)
122 |         for idx, token in tqdm(enumerate(tokenizer.terms()), desc=f"read index terms"):
123 |             token2id[token.term] = idx
124 |     elif (args.model == 'unicoil') or (args.model == 'splade'):
125 |         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
126 |         token2id = tokenizer.vocab
127 |     else:
128 |         raise ValueError('We cannot handle your input model')
129 | 
130 | 
131 |     if not os.path.exists(args.output_dir):
132 |         os.mkdir(args.output_dir)
133 | 
134 |     densified_vector_dir = os.path.join(args.output_dir, f'encoding')
135 |     if not os.path.exists(densified_vector_dir):
136 |         os.mkdir(densified_vector_dir)
137 | 
138 |     files, file_type = get_files(args.vector_dir)
139 | 
140 |     total_num_files = len(files)
141 |     if args.num_workers is None:
142 |         args.num_workers = total_num_files
143 |         num_files_per_worker = 1
144 |     else:
145 |         num_files_per_worker = total_num_files//args.num_workers
146 |         if (total_num_files%args.num_workers) != 0:
147 |             args.num_workers += 1
148 | 
149 |     pool = Pool(args.num_workers)
150 |     for i in range(args.num_workers):
151 |         start = i*num_files_per_worker
152 |         output_path = os.path.join(densified_vector_dir, f"{args.prefix}.split{i}.pt")
153 | 
154 |         if i == (args.num_workers-1):
155 |             pool.apply_async(vectorize_and_densify, (files[start:], file_type, args.output_dims, whole_word_matching[args.model], token2id, output_path, args))
156 |         else:
157 |             pool.apply_async(vectorize_and_densify, (files[start:(start+num_files_per_worker)], file_type, args.output_dims, whole_word_matching[args.model], token2id, output_path, args))
158 | 
159 |         # for debug
160 |         # vectorize_and_densify(files[start:(start+num_files_per_worker)], file_type, args.output_dims, whole_word_matching[args.model], token2id, output_path, args)
161 | 
162 |     pool.close()
163 |     pool.join()
164 | 
165 | 
166 | 
167 | if __name__ == '__main__':
168 |     main()
169 | 
--------------------------------------------------------------------------------
/densify/densify_query.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | import numpy as np
4 | import json
5 | import argparse
6 | from collections import defaultdict
7 | from pyserini.index import IndexReader
8 | from pyserini.analysis import Analyzer, get_lucene_analyzer
9 | import os
10 | from tqdm import tqdm
11 | from transformers import AutoModelForMaskedLM, AutoTokenizer
12 | from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder
13 | from .densify_corpus import densify, whole_word_matching
14 | 
15 | 
16 | def main():
17 |     parser = argparse.ArgumentParser(
18 |         description='Densify queries')
19 |     parser.add_argument('--model', required=True, help='bm25, deepimpact, unicoil or splade')
20 |     parser.add_argument('--tokenizer', required=False, default="bert-base-uncased", help='anserini index path or transformer tokenizer')
21 |     parser.add_argument('--query_path', required=True, help='query tsv file')
22 |     parser.add_argument('--output_dims', type=int, required=False, default=768)
23 |     parser.add_argument('--output_dir', required=True, help='output pickle directory')
24 |     parser.add_argument('--prefix', required=True, help='index name prefix')
25 |     args = parser.parse_args()
26 | 
27 | 
28 |     if not os.path.exists(args.output_dir):
29 |         os.mkdir(args.output_dir)
30 | 
31 |     args.output_dir = os.path.join(args.output_dir, f'encoding')
32 |     if not os.path.exists(args.output_dir):
33 |         os.mkdir(args.output_dir)
34 | 
35 |     densified_vector_dir = os.path.join(args.output_dir, f"queries")
36 |     if not os.path.exists(densified_vector_dir):
37 |         os.mkdir(densified_vector_dir)
38 | 
39 | 
40 |     args.model = args.model.lower()
41 |     token2id = {}
42 |     if (args.model == 'bm25') or (args.model == 'deepimpact'):
43 |         analyzer = Analyzer(get_lucene_analyzer())
44 |         tokenizer = IndexReader(args.tokenizer)
45 |         for idx, token in tqdm(enumerate(tokenizer.terms()), desc=f"read index terms"):
46 |             token2id[token.term] = idx
47 |         if args.model == 'bm25':
48 |             analyze = True
49 |         else:
50 |             analyze = False
51 |         query_encoder = None
52 |     elif (args.model == 'unicoil') or (args.model == 'splade'):
53 |         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
54 |         token2id = tokenizer.vocab
55 |         if args.model == 'unicoil':
56 |             query_encoder = UniCoilQueryEncoder('castorini/unicoil-msmarco-passage')
57 |         else:
58 |             # no SPLADE query encoder is wired up here; fail fast instead of hitting a NameError below
59 |             raise NotImplementedError('SPLADE query encoding is not supported by this script yet')
60 |     else:
61 |         raise ValueError('We cannot handle your input --model')
62 | 
63 | 
64 |     f = open(args.query_path, 'r')
65 |     print('count line number')
66 |     data_num = 0
67 |     for line in f:
68 |         data_num+=1
69 |     f.close()
70 | 
71 |     print('initialize numpy array with {}X{}'.format(data_num, args.output_dims))
72 |     value_encoded = np.zeros((data_num, args.output_dims), dtype=np.float16)
73 |     index_encoded = np.zeros((data_num, args.output_dims), dtype=np.int16)
74 | 
75 |     qids = []
76 |     data = {}
77 |     total_collision_num = 0
78 |     with open(args.query_path, 'r') as f:
79 |         for i, line in enumerate(f):
80 |             qid, query = line.strip().split('\t')
81 |             if query_encoder is None:
82 |                 if analyze:
83 |                     analyzed_query_terms = analyzer.analyze(query)
84 |                 else:
85 |                     analyzed_query_terms = query.split(' ')
86 |                 # use tf as term weight
87 |                 vector = defaultdict(int)
88 |                 for analyzed_query_term in analyzed_query_terms:
89 |                     vector[analyzed_query_term] += 1
90 |             else:
91 |                 vector = query_encoder.encode(query)
92 | 
93 |             data['vector'] = vector
94 | 
95 | qids.append(qid) 96 | value, index, collision_num = densify(data, args.output_dims, whole_word_matching[args.model] , token2id, args) 97 | total_collision_num += collision_num 98 | value_encoded[i] = value 99 | index_encoded[i] = index 100 | 101 | 102 | 103 | print('Total {} collisions with {} queries'.format(total_collision_num, i+1)) 104 | file_name = args.prefix + '.' + (args.query_path).split('/')[-1].replace('tsv','pt') 105 | output_path = os.path.join(densified_vector_dir, file_name) 106 | with open(output_path, 'wb') as f_out: 107 | pickle.dump([value_encoded, index_encoded, qids], f_out, protocol=4) 108 | 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /densify/output_vector.py: -------------------------------------------------------------------------------- 1 | from pyserini.search import SimpleSearcher 2 | from pyserini.index import IndexReader 3 | import json 4 | from tqdm import tqdm 5 | import argparse 6 | import itertools 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser( 9 | description='Extract text contents from anserini index') 10 | parser.add_argument('--index_path', required=True, help='anserini index path') 11 | parser.add_argument('--output_path', required=True, help='Output file in the anserini jsonl format.') 12 | parser.add_argument('--tf_only' , action='store_true') 13 | args = parser.parse_args() 14 | 15 | index_reader = IndexReader(args.index_path) 16 | searcher = SimpleSearcher(args.index_path) 17 | total_num_docs = searcher.num_docs 18 | 19 | # term_dict = {} 20 | # for idx, term in tqdm(enumerate(index_reader.terms()), desc=f"read index terms"): 21 | # term_dict[term.term] = idx 22 | 23 | fout = open(args.output_path, 'w') 24 | for i in tqdm(range(total_num_docs), total=total_num_docs, desc=f"compute bm25 vector"): 25 | docid = searcher.doc(i).docid() 26 | tf = index_reader.get_document_vector(docid) 27 | vector = {} 28 | for term in tf: 29 | vector[term] = index_reader.compute_bm25_term_weight(docid, term, analyzer=None) 30 | output_dict = {'id': docid, 'vector': vector} 31 | fout.write(json.dumps(output_dict) + '\n') 32 | fout.close() -------------------------------------------------------------------------------- /docs/aggretriever/beir-eval.md: -------------------------------------------------------------------------------- 1 | # BEIR Evaluation 2 | ## Evaluation with Sentence Transformer 3 | We use [BEIR](https://github.com/beir-cellar/beir) API to conduct brute-force search. 4 | ``` 5 | git clone https://huggingface.co/jacklin/DistilBERT-AGG 6 | export MODEL_DIR=DistilBERT-AGG 7 | export CUDA_VISIBLE_DEVICES=0 8 | export MODEL=AGG 9 | export AGGDIM=640 10 | export CORPUS=scifact 11 | python -m tevatron.datasets.beir.encode_and_retrieval --dataset ${CORPUS} --model_name_or_path ${MODEL_DIR} --model ${MODEL} --agg_dim ${AGGDIM} 12 | ``` 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /docs/aggretriever/msmarco-passage-train-eval.md: -------------------------------------------------------------------------------- 1 | # Training and Inference on MSMARCO Passage ranking 2 | In the following, we describe how to train, encode and retrieve with Aggretriever on MS MARCO passage-v1. 3 | 1. [MS MARCO Passage-v1 Data Preparation](#msmarco_data_prep) 4 | 1. [Training](#training) 5 | 1. [Generate Passage and Query Embeddings](#generate_embeddings) 6 | 1. [End-To-End Retrieval](#retrieval) 7 | 1. 
[Evaluation](#evaluation)
8 | 
9 | ## Data Preparation
10 | We first preprocess the corpus, development queries and official training data in json format. Each passage in the corpus is a line with the format: `{"text_id": passage_id, "text": [vocab_ids]}`. Similarly, each query in the development set is a line with the format: `{"text_id": query_id, "text": [vocab_ids]}`. As for the training data, we rearrange the official training data in the format: `{"query": [vocab_ids], "positive_pids": [positive_passage_id0, positive_passage_id1, ...], "negative_pids": [negative_passage_id0, negative_passage_id1, ...]}`. Note that we use the string type for passage and query ids. You can also download our preprocessed data from the huggingface hub: [official_train](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_corpus), [queries](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_queries) and [corpus](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_corpus).
11 | 
12 | ## Training
13 | The script below reproduces the Aggretriever training in our paper. Here we use distilbert-base-uncased as an example; you can switch to any backbone via `--model_name_or_path`.
14 | ```shell=bash
15 | export CUDA_VISIBLE_DEVICES=0
16 | export MODEL=AGG
17 | export CLSDIM=128
18 | export AGGDIM=640
19 | export MODEL_DIR=${MODEL}_CLS${CLSDIM}XAGG${AGGDIM}
20 | export DATA_DIR=need_your_assignment
21 | 
22 | python -m tevatron.driver.train \
23 |   --output_dir ${MODEL_DIR} \
24 |   --train_dir ${DATA_DIR}/official_train \
25 |   --corpus_dir ${DATA_DIR}/corpus \
26 |   --model_name_or_path distilbert-base-uncased \
27 |   --do_train \
28 |   --save_steps 20000 \
29 |   --fp16 \
30 |   --per_device_train_batch_size 8 \
31 |   --learning_rate 5e-6 \
32 |   --q_max_len 32 \
33 |   --p_max_len 128 \
34 |   --num_train_epochs 3 \
35 |   --add_pooler \
36 |   --model ${MODEL} \
37 |   --projection_out_dim ${CLSDIM} \
38 |   --agg_dim ${AGGDIM} \
39 |   --train_n_passages 8 \
40 |   --dataloader_num_workers 8
41 | ```
42 | 
43 | ## Inference MSMARCO Passage for Retrieval
44 | ```
45 | export CUDA_VISIBLE_DEVICES=0
46 | export CORPUS=msmarco-passage
47 | export SPLIT=dev.small
48 | export INDEX_DIR=${MODEL_DIR}/encoding
49 | export DATA_DIR=need_your_assignment
50 | 
51 | # Corpus
52 | for i in $(seq -f "%02g" 0 10)
53 | do
54 |   echo '============= Inference doc.split '${i}' ============='
55 |   srun --gres=gpu:p100:1 --mem=16G --cpus-per-task=2 --time=1:40:00 \
56 |   python -m tevatron.driver.encode \
57 |     --output_dir ${MODEL_DIR} \
58 |     --model_name_or_path ${MODEL_DIR} \
59 |     --add_pooler \
60 |     --projection_out_dim ${CLSDIM} \
61 |     --agg_dim ${AGGDIM} \
62 |     --model ${MODEL} \
63 |     --fp16 \
64 |     --p_max_len 128 \
65 |     --per_device_eval_batch_size 128 \
66 |     --encode_in_path ${DATA_DIR}/corpus/split${i}.json \
67 |     --encoded_save_path ${INDEX_DIR}/${CORPUS}.split${i}.pt &
68 | done
69 | wait  # make sure all background encoding jobs finish before merging
70 | 
71 | # Merge index
72 | python -m retrieval.index \
73 |   --index_path ${INDEX_DIR} \
74 |   --index_prefix ${CORPUS}
75 | mkdir ${INDEX_DIR}/index
76 | mv ${INDEX_DIR}/${CORPUS}.index.pt ${INDEX_DIR}/index/
77 | 
78 | # Queries
79 | for SPLIT in dev.small
80 | do
81 |   mkdir ${INDEX_DIR}/queries
82 |   python -m tevatron.driver.encode \
83 |     --output_dir ${MODEL_DIR} \
84 |     --model_name_or_path ${MODEL_DIR} \
85 |     --fp16 \
86 |     --q_max_len 32 \
87 |     --model ${MODEL} \
88 |     --encode_is_qry \
89 |     --add_pooler \
90 |     --projection_out_dim ${CLSDIM} \
91 |     --agg_dim ${AGGDIM} \
92 |     --per_device_eval_batch_size 128 \
93 |     --encode_in_path ${DATA_DIR}/queries/queries.${SPLIT}.json \
94 |     --encoded_save_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt
95 | done
96 | ```
97 | 
98 | ## End-to-End Retrieval
99 | ```
100 | # IP retrieval
101 | for shrad in 0
102 | do
103 |   echo 'run shrad'$shrad
104 |   python -m retrieval.gip_retrieval \
105 |     --query_emb_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt \
106 |     --index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
107 |     --topk 1000 \
108 |     --total_shrad 1 \
109 |     --shrad $shrad \
110 |     --IP \
111 |     --use_gpu
112 | done
113 | ```
114 | 
115 | ## Evaluation
116 | The run file, result.trec, is in the trec format, so you can directly evaluate the result using pyserini.
117 | ```
118 | python -m pyserini.eval.trec_eval -c -M 10 -m recip_rank ${QREL_PATH} result.trec
119 | python -m pyserini.eval.trec_eval -c -m recall.1000 ${QREL_PATH} result.trec
120 | ```
121 | 
122 | 
--------------------------------------------------------------------------------
/docs/dhr/beir-eval.md:
--------------------------------------------------------------------------------
1 | # BEIR Evaluation
2 | We provide two scripts for BEIR evaluation, using the model [DeLADE-CLS-P](https://huggingface.co/jacklin/DeLADE-CLS-P) and the dataset trec-covid as an example.
3 | 1. [Evaluation with GIP Retrieval](#evaluation_with_gip)
4 | 1. [Evaluation with Sentence Transformer](#evaluation_with_sentence_transformer)
5 | 
6 | ## Evaluation with GIP Retrieval
7 | We first download our model and the beir dataset.
8 | ```
9 | git clone https://huggingface.co/jacklin/DeLADE-CLS-P
10 | export MODEL_DIR=DeLADE-CLS-P
11 | export CORPUS=trec-covid
12 | export SPLIT=test
13 | python -m tevatron.datasets.beir.preprocess --dataset ${CORPUS}
14 | ```
15 | Then we tokenize the queries and corpus.
16 | ```
17 | python -m tevatron.utils.tokenize_corpus \
18 |   --corpus_path ./dataset/${CORPUS}/corpus/collection.json \
19 |   --output_dir ./dataset/${CORPUS}/tokenized_data/corpus \
20 |   --corpus_domain beir \
21 |   --tokenize --encode --num_workers 10
22 | 
23 | python -m tevatron.utils.tokenize_query \
24 |   --qry_file ./dataset/${CORPUS}/queries/queries.${SPLIT}.tsv \
25 |   --output_dir ./dataset/${CORPUS}/tokenized_data/queries
26 | 
27 | ```
28 | Following the [inference scripts](https://github.com/castorini/DHR/blob/main/docs/msmarco-passage-train-eval.md#inference-msmarco-passage-for-retrieval) for msmarco-passage data, we run inference, GIP retrieval and evaluation on the BEIR dataset.
29 | ```
30 | export CUDA_VISIBLE_DEVICES=0
31 | export MODEL=DHR # change to DLR if you use a DLR model
32 | export CLSDIM=128
33 | export DLRDIM=768
34 | export CORPUS=trec-covid
35 | export SPLIT=test
36 | export INDEX_DIR=${MODEL_DIR}/encoding${DLRDIM}
37 | export DATA_DIR=./dataset/${CORPUS}/tokenized_data
38 | 
39 | # Corpus
40 | for file in ${DATA_DIR}/corpus/split*.json
41 | do
42 |   i=$(echo $file | rev | cut -c -7 | rev | cut -c -2)
43 |   echo "===========inference ${file}==========="
44 |   python -m tevatron.driver.encode \
45 |     --output_dir ${MODEL_DIR} \
46 |     --model_name_or_path ${MODEL_DIR} \
47 |     --projection_out_dim ${CLSDIM} \
48 |     --dlr_out_dim ${DLRDIM} \
49 |     --model ${MODEL} \
50 |     --add_pooler \
51 |     --combine_cls \
52 |     --fp16 \
53 |     --p_max_len 512 \
54 |     --per_device_eval_batch_size 32 \
55 |     --encode_in_path ${file} \
56 |     --encoded_save_path ${INDEX_DIR}/${CORPUS}.split${i}.pt
57 | done
58 | 
59 | # Merge index
60 | python -m retrieval.index \
61 |   --index_path ${INDEX_DIR} \
62 |   --index_prefix ${CORPUS}
63 | mkdir ${INDEX_DIR}/index
64 | mv ${INDEX_DIR}/${CORPUS}.index.pt ${INDEX_DIR}/index/
65 | 
66 | # Queries
67 | mkdir ${INDEX_DIR}/queries
68 | python -m tevatron.driver.encode \
69 |   --output_dir ${MODEL_DIR} \
70 |   --model_name_or_path ${MODEL_DIR} \
71 |   --fp16 \
72 |   --q_max_len 512 \
73 |   --model ${MODEL} \
74 |   --encode_is_qry \
75 |   --combine_cls \
76 |   --add_pooler \
77 |   --projection_out_dim ${CLSDIM} \
78 |   --dlr_out_dim ${DLRDIM} \
79 |   --per_device_eval_batch_size 128 \
80 |   --encode_in_path ${DATA_DIR}/queries/queries.${SPLIT}.json \
81 |   --encoded_save_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt
82 | 
83 | ```
84 | ```
85 | # GIP retrieval
86 | for shrad in 0
87 | do
88 |   echo 'run shrad'$shrad
89 |   python -m retrieval.gip_retrieval \
90 |     --query_emb_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt \
91 |     --emb_dim ${DLRDIM} \
92 |     --index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
93 |     --topk 1000 \
94 |     --total_shrad 1 \
95 |     --shrad $shrad \
96 |     --theta 0.3 \
97 |     --rerank \
98 |     --use_gpu \
99 |     --combine_cls
100 | done
101 | ```
102 | ```
103 | # Evaluation
104 | python -m pyserini.eval.trec_eval -c -mndcg_cut.10 -mrecall.100 ./dataset/${CORPUS}/qrels/qrels.${SPLIT}.tsv result.trec
105 | python -m retrieval.rcap_eval --qrel_file_path ./dataset/${CORPUS}/qrels/qrels.${SPLIT}.tsv --run_file_path result.trec
106 | 
107 | ```
108 | ## Evaluation with Sentence Transformer
109 | The second script directly uses the [BEIR](https://github.com/beir-cellar/beir) API to conduct brute-force search. No densification is performed before retrieval; thus, the results are slightly different from the numbers reported in our paper. Note that, for this script, we currently only support our DHR models, [DeLADE-CLS](https://huggingface.co/jacklin/DeLADE-CLS) and [DeLADE-CLS-P](https://huggingface.co/jacklin/DeLADE-CLS-P).
110 | ```
111 | git clone https://huggingface.co/jacklin/DeLADE-CLS-P
112 | export MODEL_DIR=DeLADE-CLS-P
113 | python -m tevatron.datasets.beir.encode_and_retrieval --dataset trec-covid --model_name_or_path ${MODEL_DIR}
114 | ```
115 | 
116 | 
117 | 
--------------------------------------------------------------------------------
/docs/dhr/densify_exp.md:
--------------------------------------------------------------------------------
1 | # Densify Sparse Vectors
2 | This document demonstrates how to densify existing sparse lexical retrievers for dense search. We use [pyserini](https://github.com/castorini/pyserini) to extract the sparse vectors from the models, and we show how to densify BM25 on the MS MARCO passage ranking dataset.
3 | 1. [Densifying BM25](#densifying_bm25)
4 | 1. [Densifying uniCOIL](#densifying_uniCOIL)
5 | 
6 | # Densifying BM25
7 | ## Data Preparation
8 | Following the [instructions](https://github.com/castorini/anserini/blob/master/docs/experiments-msmarco-passage.md), we first download the MS MARCO passage collection and query files. Then, we convert collection.tsv into json files in $COLLECTION_PATH for pyserini indexing, and put the queries.dev.small.tsv file into $Q_DIR.
9 | ```shell=bash
10 | export COLLECTION_PATH=need_your_assignment
11 | export INDEX_PATH=need_your_assignment
12 | export VECTOR_DIR=need_your_assignment
13 | export Q_DIR=need_your_assignment
14 | export MODEL=BM25
15 | export DLRDIM=768
16 | export CORPUS=msmarco-passage
17 | export DLR_PATH=${MODEL}_DIM${DLRDIM}
18 | export SPLIT=dev.small
19 | ```
20 | ## Output BM25 Vectors from the Index
21 | We first index the json corpus using BM25.
22 | ```shell=bash
23 | python -m pyserini.index.lucene \
24 |   --collection JsonVectorCollection \
25 |   --input ${COLLECTION_PATH} \
26 |   --index ${INDEX_PATH} \
27 |   --generator DefaultLuceneDocumentGenerator \
28 |   --threads 12 \
29 |   --storeDocvectors --storeRaw --optimize
30 | ```
31 | Then, we output the sparse vectors to a json file and split it into multiple files for multi-processing in the next step.
32 | ```shell=bash
33 | python -m densify.output_vector \
34 |   --index_path ${INDEX_PATH} \
35 |   --output_path ${VECTOR_DIR}/split.json
36 | 
37 | split -a 2 -dl 1000000 --additional-suffix=.json ${VECTOR_DIR}/split.json ${VECTOR_DIR}/split
38 | rm ${VECTOR_DIR}/split.json
39 | ```
40 | ## Sparse vector densification
41 | We now densify the corpus and queries.
42 | ```shell=bash
43 | python -m densify.densify_corpus \
44 |   --model ${MODEL} \
45 |   --prefix ${CORPUS} \
46 |   --tokenizer ${INDEX_PATH} \
47 |   --vector_dir ${VECTOR_DIR} \
48 |   --output_dir ${DLR_PATH} \
49 |   --output_dims ${DLRDIM}
50 | 
51 | python -m densify.densify_query \
52 |   --model bm25 \
53 |   --prefix ${CORPUS} \
54 |   --tokenizer ${INDEX_PATH} \
55 |   --query_path ${Q_DIR}/queries.${SPLIT}.tsv \
56 |   --output_dir ${DLR_PATH} \
57 |   --output_dims ${DLRDIM}
58 | ```
59 | ## BM25 search on GPU
60 | We then merge the index and start DLR search.
61 | ```shell=bash
62 | # Merge index
63 | python -m retrieval.index \
64 |   --index_path ${DLR_PATH}/encoding \
65 |   --index_prefix ${CORPUS}
66 | 
67 | mkdir ${DLR_PATH}/encoding/index
68 | mv ${DLR_PATH}/encoding/${CORPUS}.index.pt ${DLR_PATH}/encoding/index/
69 | 
70 | # Search
71 | python -m retrieval.gip_retrieval \
72 |   --query_emb_path ${DLR_PATH}/encoding/queries/queries.${CORPUS}.${SPLIT}.pt \
73 |   --emb_dim ${DLRDIM} \
74 |   --index_path ${DLR_PATH}/encoding/index/${CORPUS}.index.pt \
75 |   --theta 1 \
76 |   --rerank \
77 |   --use_gpu
78 | ```
79 | 
80 | # Densifying uniCOIL
81 | ## Data Preparation
82 | Following the [instructions](https://github.com/castorini/pyserini/blob/master/docs/experiments-unicoil.md), we download the pre-encoded uniCOIL passage collection.
83 | ```shell=bash
84 | wget https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/data/msmarco-passage-unicoil.tar -P collections/
85 | 
86 | tar xvf collections/msmarco-passage-unicoil.tar -C collections/
87 | ```
88 | ```shell=bash
89 | export MODEL=uniCOIL
90 | export DLRDIM=768
91 | export CORPUS=msmarco-passage
92 | export VECTOR_DIR=./collections/msmarco-passage-unicoil-b8
93 | export DLR_PATH=${MODEL}_DIM${DLRDIM}
94 | export SPLIT=dev.small
95 | ```
96 | ## Sparse vector densification
97 | We now densify the corpus and queries.
98 | ```shell=bash
99 | python -m densify.densify_corpus \
100 |   --model ${MODEL} \
101 |   --prefix ${CORPUS} \
102 |   --vector_dir ${VECTOR_DIR} \
103 |   --output_dir ${DLR_PATH} \
104 |   --output_dims ${DLRDIM}
105 | 
106 | python -m densify.densify_query \
107 |   --model ${MODEL} \
108 |   --prefix ${CORPUS} \
109 |   --query_path ${Q_DIR}/queries.${SPLIT}.tsv \
110 |   --output_dir ${DLR_PATH} \
111 |   --output_dims ${DLRDIM}
112 | ```
113 | 
114 | We then merge the index and start DLR search.
115 | ```shell=bash
116 | # Merge index
117 | python -m retrieval.index \
118 |   --index_path ${DLR_PATH}/encoding \
119 |   --index_prefix ${CORPUS}
120 | 
121 | mkdir ${DLR_PATH}/encoding/index
122 | mv ${DLR_PATH}/encoding/${CORPUS}.index.pt ${DLR_PATH}/encoding/index/
123 | 
124 | # Search
125 | python -m retrieval.gip_retrieval \
126 |   --query_emb_path ${DLR_PATH}/encoding/queries/queries.${CORPUS}.${SPLIT}.pt \
127 |   --emb_dim ${DLRDIM} \
128 |   --index_path ${DLR_PATH}/encoding/index/${CORPUS}.index.pt \
129 |   --theta 1 \
130 |   --rerank \
131 |   --use_gpu
132 | ```
--------------------------------------------------------------------------------
/docs/dhr/msmarco-passage-train-eval.md:
--------------------------------------------------------------------------------
1 | # Training and Inference on MSMARCO Passage ranking
2 | In the following, we describe how to train, encode and retrieve with DHR on MS MARCO passage-v1.
3 | 1. [MS MARCO Passage-v1 Data Preparation](#msmarco_data_prep)
4 | 1. [Training](#training)
5 | 1. [Generate Passage and Query Embeddings](#generate_embeddings)
6 | 1. [End-To-End Retrieval](#retrieval)
7 |     1. [Retrieval on GPU](#retrieval_on_gpu)
8 |     1. [Retrieval on CPU](#retrieval_on_cpu)
9 | 1. [Evaluation](#evaluation)
10 | 
11 | 
12 | ## MS MARCO Passage-v1 Data Preparation
13 | We first preprocess the corpus, development queries and official training data in json format. Each passage in the corpus is a line with the format: `{"text_id": passage_id, "text": [vocab_ids]}`. Similarly, each query in the development set is a line with the format: `{"text_id": query_id, "text": [vocab_ids]}`. As for the training data, we rearrange the official training data in the format: `{"query": [vocab_ids], "positive_pids": [positive_passage_id0, positive_passage_id1, ...], "negative_pids": [negative_passage_id0, negative_passage_id1, ...]}`. Note that we use the string type for passage and query ids. You can also download our preprocessed data from the huggingface hub: [official_train](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_corpus), [queries](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_queries) and [corpus](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_corpus).
14 | 
15 | ## Training
16 | The script below reproduces the DHR (DLR) training in our paper. You can simply switch ${MODEL} from DHR to DLR, and the `--combine_cls` option will be turned off automatically.
17 | ```shell=bash
18 | export CUDA_VISIBLE_DEVICES=0
19 | export MODEL=DHR
20 | export CLSDIM=128
21 | export DLRDIM=768
22 | export MODEL_DIR=${MODEL}_CLS${CLSDIM}
23 | export DATA_DIR=need_your_assignment
24 | 
25 | python -m tevatron.driver.train \
26 |   --output_dir ${MODEL_DIR} \
27 |   --train_dir ${DATA_DIR}/official_train \
28 |   --corpus_dir ${DATA_DIR}/corpus \
29 |   --model_name_or_path distilbert-base-uncased \
30 |   --do_train \
31 |   --save_steps 20000 \
32 |   --fp16 \
33 |   --per_device_train_batch_size 24 \
34 |   --learning_rate 7e-6 \
35 |   --q_max_len 32 \
36 |   --p_max_len 150 \
37 |   --num_train_epochs 6 \
38 |   --add_pooler \
39 |   --model ${MODEL} \
40 |   --projection_out_dim ${CLSDIM} \
41 |   --train_n_passages 8 \
42 |   --dataloader_num_workers 8 \
43 |   --combine_cls
44 | ```
45 | 
46 | ## Generate Passage and Query Embeddings
47 | ```
48 | export CUDA_VISIBLE_DEVICES=0
49 | export MODEL=DHR # DHR for DeLADE+[CLS]; DLR for DeLADE
50 | export CLSDIM=128
51 | export DLRDIM=768
52 | export MODEL_DIR=${MODEL}_CLS${CLSDIM}
53 | export CORPUS=msmarco-passage
54 | export SPLIT=dev.small
55 | export INDEX_DIR=${MODEL_DIR}/encoding${DLRDIM}
56 | export DATA_DIR=need_your_assignment
57 | 
58 | # Corpus
59 | for i in $(seq -f "%02g" 0 10)
60 | do
61 |   echo '============= Inference doc.split '${i}' ============='
62 |   srun --gres=gpu:p100:1 --mem=16G --cpus-per-task=2 --time=1:40:00 \
63 |   python -m tevatron.driver.encode \
64 |     --output_dir ${MODEL_DIR} \
65 |     --model_name_or_path ${MODEL_DIR} \
66 |     --add_pooler \
67 |     --projection_out_dim ${CLSDIM} \
68 |     --dlr_out_dim ${DLRDIM} \
69 |     --combine_cls \
70 |     --model ${MODEL} \
71 |     --fp16 \
72 |     --p_max_len 150 \
73 |     --per_device_eval_batch_size 128 \
74 |     --encode_in_path ${DATA_DIR}/corpus/split${i}.json \
75 |     --encoded_save_path ${INDEX_DIR}/${CORPUS}.split${i}.pt &
76 | done
77 | wait  # make sure all background encoding jobs finish before merging
78 | 
79 | # Merge index
80 | python -m retrieval.index \
81 |   --index_path ${INDEX_DIR} \
82 |   --index_prefix ${CORPUS}
83 | mkdir ${INDEX_DIR}/index
84 | mv ${INDEX_DIR}/${CORPUS}.index.pt ${INDEX_DIR}/index/
85 | 
86 | # Queries
87 | for SPLIT in dev.small
88 | do
89 |   mkdir ${INDEX_DIR}/queries
90 |   python -m tevatron.driver.encode \
91 |     --output_dir ${MODEL_DIR} \
92 |     --model_name_or_path ${MODEL_DIR} \
93 |     --fp16 \
94 |     --q_max_len 32 \
95 |     --model ${MODEL} \
96 |     --encode_is_qry \
97 |     --combine_cls \
98 |     --add_pooler \
99 |     --projection_out_dim ${CLSDIM} \
100 |     --dlr_out_dim ${DLRDIM} \
101 |     --per_device_eval_batch_size 128 \
102 |     --encode_in_path ${DATA_DIR}/queries/queries.${SPLIT}.json \
103 |     --encoded_save_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt
104 | done
105 | ```
106 | 
107 | ## End-to-end Retrieval
108 | ### Retrieval on GPU
109 | If you want to use a GPU for retrieval, we suggest using our two-stage retrieval implementation.
110 | ```
111 | # GIP retrieval
112 | for shrad in 0
113 | do
114 |   echo 'run shrad'$shrad
115 |   python -m retrieval.gip_retrieval \
116 |     --query_emb_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt \
117 |     --emb_dim ${DLRDIM} \
118 |     --index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
119 |     --topk 1000 \
120 |     --total_shrad 1 \
121 |     --shrad $shrad \
122 |     --theta 0.3 \
123 |     --rerank \
124 |     --use_gpu \
125 |     --combine_cls
126 | done
127 | ```
128 | 
129 | ### Retrieval on CPU
130 | If you only have a CPU, we suggest first quantizing the index and then using our two-stage retrieval implementation.
131 | ```
132 | # index quantization
133 | python -m retrieval.quantize_index \
134 |   --index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
135 |   --output_index_path ${INDEX_DIR}/index/${CORPUS}.pq64.faiss.index \
136 |   --quantized_dim 64
137 | 
138 | # GIP retrieval
139 | python -m retrieval.gip_retrieval \
140 |   --query_emb_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt \
141 |   --index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
142 |   --faiss_pq_index_path ${INDEX_DIR}/index/${CORPUS}.pq64.faiss.index \
143 |   --emb_dim ${DLRDIM} \
144 |   --topk 1000 \
145 |   --lamda 1 \
146 |   --batch 1 \
147 |   --PQIP \
148 |   --rerank
149 | ```
150 | 
151 | ## Evaluation
152 | The run file, result.trec, is in the trec format, so you can directly evaluate the result using pyserini.
153 | ```
154 | python -m pyserini.eval.trec_eval -c -M 10 -m recip_rank ${QREL_PATH} result.trec
155 | python -m pyserini.eval.trec_eval -c -m recall.1000 ${QREL_PATH} result.trec
156 | ```
157 | 
158 | 
159 | 
--------------------------------------------------------------------------------
/fig/aggretriever.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/fig/aggretriever.png
--------------------------------------------------------------------------------
/fig/aggretriever_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/fig/aggretriever_teaser.png
--------------------------------------------------------------------------------
/fig/densification.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/fig/densification.png
--------------------------------------------------------------------------------
/fig/single_model_fusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/fig/single_model_fusion.png
--------------------------------------------------------------------------------
/retrieval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/retrieval/__init__.py
--------------------------------------------------------------------------------
/retrieval/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/retrieval/evaluation/__init__.py
--------------------------------------------------------------------------------
/retrieval/evaluation/custom_metrics.py:
--------------------------------------------------------------------------------
1 | ## copied from https://github.com/beir-cellar/beir/blob/main/beir/retrieval/custom_metrics.py
2 | import logging
3 | from typing import List, Dict, Union, Tuple
4 | 
5 | def mrr(qrels: Dict[str, Dict[str, int]], 
6 |         results: Dict[str, Dict[str, float]], 
7 |         k_values: List[int]) -> Tuple[Dict[str, float]]:
8 |     
9 |     MRR = {}
10 |     
11 |     for k in k_values:
12 |         MRR[f"MRR@{k}"] = 0.0
13 |     
14 |     k_max, top_hits = max(k_values), {}
15 |     logging.info("\n")
16 |     
17 |     for query_id, doc_scores in results.items():
18 | 
top_hits[query_id] = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max] 19 | 20 | for query_id in top_hits: 21 | query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]) 22 | for k in k_values: 23 | for rank, hit in enumerate(top_hits[query_id][0:k]): 24 | if hit[0] in query_relevant_docs: 25 | MRR[f"MRR@{k}"] += 1.0 / (rank + 1) 26 | break 27 | 28 | for k in k_values: 29 | MRR[f"MRR@{k}"] = round(MRR[f"MRR@{k}"]/len(qrels), 5) 30 | logging.info("MRR@{}: {:.4f}".format(k, MRR[f"MRR@{k}"])) 31 | 32 | return MRR 33 | 34 | def recall_cap(qrels: Dict[str, Dict[str, int]], 35 | results: Dict[str, Dict[str, float]], 36 | k_values: List[int]) -> Tuple[Dict[str, float]]: 37 | 38 | capped_recall = {} 39 | 40 | for k in k_values: 41 | capped_recall[f"R_cap@{k}"] = 0.0 42 | 43 | k_max = max(k_values) 44 | logging.info("\n") 45 | 46 | for query_id, doc_scores in results.items(): 47 | top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max] 48 | query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0] 49 | for k in k_values: 50 | retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0] 51 | denominator = min(len(query_relevant_docs), k) 52 | capped_recall[f"R_cap@{k}"] += (len(retrieved_docs) / denominator) 53 | 54 | for k in k_values: 55 | capped_recall[f"R_cap@{k}"] = round(capped_recall[f"R_cap@{k}"]/len(qrels), 5) 56 | logging.info("R_cap@{}: {:.4f}".format(k, capped_recall[f"R_cap@{k}"])) 57 | 58 | return capped_recall 59 | 60 | 61 | def hole(qrels: Dict[str, Dict[str, int]], 62 | results: Dict[str, Dict[str, float]], 63 | k_values: List[int]) -> Tuple[Dict[str, float]]: 64 | 65 | Hole = {} 66 | 67 | for k in k_values: 68 | Hole[f"Hole@{k}"] = 0.0 69 | 70 | annotated_corpus = set() 71 | for _, docs in qrels.items(): 72 | for doc_id, score in docs.items(): 73 | annotated_corpus.add(doc_id) 74 | 75 | k_max = max(k_values) 76 | logging.info("\n") 77 | 78 | for _, scores in results.items(): 79 | top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[0:k_max] 80 | for k in k_values: 81 | hole_docs = [row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus] 82 | Hole[f"Hole@{k}"] += len(hole_docs) / k 83 | 84 | for k in k_values: 85 | Hole[f"Hole@{k}"] = round(Hole[f"Hole@{k}"]/len(qrels), 5) 86 | logging.info("Hole@{}: {:.4f}".format(k, Hole[f"Hole@{k}"])) 87 | 88 | return Hole 89 | 90 | def top_k_accuracy( 91 | qrels: Dict[str, Dict[str, int]], 92 | results: Dict[str, Dict[str, float]], 93 | k_values: List[int]) -> Tuple[Dict[str, float]]: 94 | 95 | top_k_acc = {} 96 | 97 | for k in k_values: 98 | top_k_acc[f"Accuracy@{k}"] = 0.0 99 | 100 | k_max, top_hits = max(k_values), {} 101 | logging.info("\n") 102 | 103 | for query_id, doc_scores in results.items(): 104 | top_hits[query_id] = [item[0] for item in sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]] 105 | 106 | for query_id in top_hits: 107 | query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]) 108 | for k in k_values: 109 | for relevant_doc_id in query_relevant_docs: 110 | if relevant_doc_id in top_hits[query_id][0:k]: 111 | top_k_acc[f"Accuracy@{k}"] += 1.0 112 | break 113 | 114 | for k in k_values: 115 | top_k_acc[f"Accuracy@{k}"] = round(top_k_acc[f"Accuracy@{k}"]/len(qrels), 5) 116 | logging.info("Accuracy@{}: {:.4f}".format(k, top_k_acc[f"Accuracy@{k}"])) 117 | 118 | return top_k_acc 
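A quick, hypothetical usage example for the metrics above (toy qrels/run dictionaries in the nested `{query_id: {doc_id: ...}}` format BEIR produces; this snippet is illustrative and not part of the repo):

```python
# Toy sanity check for the custom metrics above. The dictionaries are
# hypothetical; any BEIR-style qrels/run pair works the same way.
from retrieval.evaluation.custom_metrics import mrr, recall_cap

qrels = {'q1': {'d1': 1, 'd2': 0}}                    # d1 is the only relevant doc
results = {'q1': {'d1': 0.9, 'd2': 0.8, 'd3': 0.1}}   # retrieval scores

print(mrr(qrels, results, k_values=[1, 10]))          # {'MRR@1': 1.0, 'MRR@10': 1.0}
print(recall_cap(qrels, results, k_values=[10]))      # {'R_cap@10': 1.0}
```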
--------------------------------------------------------------------------------
/retrieval/gip_retrieval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import glob
4 | import numpy as np
5 | from progressbar import *  # provides Percentage, Bar, Timer, ETA, FileTransferSpeed used in faiss_search
6 | from tqdm import tqdm
7 | from multiprocessing import Pool, Manager
8 | import pickle5 as pickle
9 | import torch
10 | import torch.nn as nn
11 | import time
12 | import faiss
13 | 
14 | def faiss_search(query_embs, corpus_embs, batch=1, topk=1000):
15 |     print('start faiss index')
16 |     query_embs = np.concatenate([query_embs,query_embs], axis=1)
17 |     corpus_embs = np.concatenate([corpus_embs,corpus_embs], axis=1)
18 | 
19 |     dimension = query_embs.shape[1]
20 |     res = faiss.StandardGpuResources()
21 |     res.noTempMemory()
22 |     # res.setTempMemory(1000 * 1024 * 1024) # 1G GPU memory for serving query
23 |     flat_config = faiss.GpuIndexFlatConfig()
24 |     flat_config.device = 0
25 |     flat_config.useFloat16=True
26 |     index = faiss.GpuIndexFlatIP(res, dimension, flat_config)
27 | 
28 |     print("Load index to GPU...")
29 |     index.add(corpus_embs)
30 | 
31 |     Distance = []
32 |     Index = []
33 |     print("Search with batch size %d"%(batch))
34 |     widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(),
35 |                ' ', ETA(), ' ', FileTransferSpeed()]
36 |     pbar = ProgressBar(widgets=widgets, maxval=query_embs.shape[0]//batch).start()
37 |     start_time = time.time()
38 | 
39 |     for i in range(query_embs.shape[0]//batch):
40 |         D,I=index.search(query_embs[i*batch:(i+1)*batch], topk)
41 | 
42 | 
43 |         Distance.append(D)
44 |         Index.append(I)
45 |         pbar.update(i + 1)
46 | 
47 | 
48 |     D,I=index.search(query_embs[(i+1)*batch:], topk)
49 | 
50 | 
51 |     Distance.append(D)
52 |     Index.append(I)
53 | 
54 |     time_per_query = (time.time() - start_time)/query_embs.shape[0]
55 |     print('Retrieving {} queries ({:0.3f} s/query)'.format(query_embs.shape[0], time_per_query))
56 |     Distance = np.concatenate(Distance, axis=0)
57 |     Index = np.concatenate(Index, axis=0)
58 |     return Distance, Index
59 | 
60 | def IP_retrieval(qids, query_embs, corpus_embs, args):
61 | 
62 |     description = 'Brute force IP search'
63 | 
64 | 
65 |     all_results = {}
66 |     all_scores = {}
67 | 
68 |     start_time = time.time()
69 |     total_num_idx = 0
70 |     for i, (query_emb) in tqdm(enumerate(query_embs), total=len(query_embs), desc=description):
71 | 
72 | 
73 | 
74 |         scores = torch.einsum('ij,j->i',(corpus_embs, query_emb))
75 |         sort_candidates = torch.argsort(scores, descending=True)[:args.topk]
76 |         sort_scores = scores[sort_candidates]
77 | 
78 |         all_scores[qids[i]]=sort_scores.cpu().tolist()
79 |         all_results[qids[i]]=sort_candidates.cpu().tolist()
80 | 
81 |     average_num_idx = total_num_idx/query_embs.shape[0]
82 |     time_per_query = (time.time() - start_time)/query_embs.shape[0]
83 |     print('Retrieving {} queries ({:0.3f} s/query), average number of index use {}'.format(query_embs.shape[0], time_per_query, average_num_idx))
84 | 
85 |     return all_results, all_scores
86 | 
87 | 
88 | def GIP_retrieval(qids, query_embs, query_arg_idxs, corpus_embs, corpus_arg_idxs, args):
89 |     if args.brute_force:
90 |         args.theta = 0
91 |         description = 'Brute force GIP search'
92 |     else:
93 |         if not args.IP:
94 |             if args.rerank:
95 |                 description = 'GIP (\u03F4={}) retrieval w/ GIP rerank'.format(args.theta)
96 |             else:
97 |                 description = 'GIP (\u03F4={}) retrieval w/o GIP rerank'.format(args.theta)
98 |         else:
99 |             if args.rerank:
100 |                 description = 'IP retrieval w/ GIP rerank'
101 |             else:
102 |                 description = 'IP retrieval w/o GIP rerank'
103 | 
104 |     all_results = {}
105 |     all_scores = {}
106 | 
107 |     start_time = time.time()
108 |     total_num_idx = 0
109 | 
110 |     cls_dim = query_embs.shape[1] - args.emb_dim
111 |     if cls_dim > 0:
112 |         query_arg_idxs = torch.nn.functional.pad(query_arg_idxs, (0, cls_dim), mode='constant', value=1)
113 |         corpus_arg_idxs = torch.nn.functional.pad(corpus_arg_idxs, (0, cls_dim), mode='constant', value=1)
114 | 
115 |     for i, (query_emb, query_arg_idx) in tqdm(enumerate(zip(query_embs, query_arg_idxs)), total=len(query_embs), desc=description):
116 | 
117 |         if args.theta==0:
118 |             total_num_idx += args.emb_dim
119 |             candidate_sparse_embs = ((corpus_arg_idxs==query_arg_idx)*corpus_embs)
120 |             scores = torch.einsum('ij,j->i',(candidate_sparse_embs, query_emb))
121 |             del candidate_sparse_embs
122 | 
123 |             sort_idx = torch.topk(scores, args.topk, dim=0).indices
124 |             sort_candidates = sort_idx
125 |             sort_scores = scores[sort_idx]
126 |             torch.cuda.empty_cache()
127 | 
128 |         else:
129 | 
130 |             num_idx = int((query_emb > args.theta).sum())
131 |             important_idx = torch.topk(query_emb, num_idx, dim=0).indices.tolist()
132 | 
133 |             if not args.IP:
134 |                 # Approximate GIP
135 |                 candidate_sparse_embs = ( (corpus_arg_idxs[:,important_idx]==query_arg_idx[important_idx]) * corpus_embs[:,important_idx] )
136 |                 partial_scores = torch.einsum('ij,j->i',(candidate_sparse_embs, query_emb[important_idx]))
137 |             else:
138 |                 # IP as an approximation
139 |                 partial_scores = torch.einsum('ij,j->i',(corpus_embs, query_emb))
140 | 
141 |             if args.rerank:
142 |                 candidates = torch.topk(partial_scores, args.agip_topk, dim=0).indices
143 | 
144 |                 candidate_sparse_embs = ((corpus_arg_idxs[candidates,:]==query_arg_idx)*corpus_embs[candidates])
145 | 
146 |                 scores = torch.einsum('ij,j->i',(candidate_sparse_embs, query_emb))
147 | 
148 |                 sort_idx = torch.topk(scores, args.topk, dim=0).indices
149 |                 sort_candidates = candidates[sort_idx]
150 |                 sort_scores = scores[sort_idx]
151 | 
152 |                 del important_idx, candidates, candidate_sparse_embs, scores, sort_idx
153 |                 torch.cuda.empty_cache()
154 |             else:
155 |                 sort_candidates = torch.topk(partial_scores, args.topk, dim=0).indices
156 |                 sort_scores = partial_scores[sort_candidates]
157 | 
158 |         all_scores[qids[i]]=sort_scores.cpu().tolist()
159 |         all_results[qids[i]]=sort_candidates.cpu().tolist()
160 | 
161 |     average_num_idx = total_num_idx/query_embs.shape[0]
162 |     time_per_query = (time.time() - start_time)/query_embs.shape[0]
163 |     print('Retrieving {} queries ({:0.3f} s/query), average number of index use {}'.format(query_embs.shape[0], time_per_query, average_num_idx))
164 | 
165 |     return all_results, all_scores
166 | 
167 | def PQ_IP_retrieval(qids, query_embs, query_arg_idxs, corpus_embs, corpus_arg_idxs, args):
168 |     assert args.faiss_pq_index_path is not None, 'you did not specify your PQ index through --faiss_pq_index_path'
169 |     print('Load PQ index ...')
170 |     faiss_index = faiss.read_index(args.faiss_pq_index_path)
171 | 
172 |     if args.rerank:
173 |         description = 'IP (Product Quantization) search w/ GIP rerank'
174 |     else:
175 |         description = 'IP (Product Quantization) search w/o GIP rerank'
176 | 
177 |     all_results = {}
178 |     all_scores = {}
179 | 
180 |     cls_dim = query_embs.shape[1] - query_arg_idxs.shape[1]
181 |     if cls_dim > 0:
182 |         query_arg_idxs = torch.nn.functional.pad(query_arg_idxs, (0, cls_dim), mode='constant', value=1)
183 |         corpus_arg_idxs = torch.nn.functional.pad(corpus_arg_idxs, (0, cls_dim), mode='constant', value=1)
184 | 
185 |     if len(query_embs)%args.batch == 0:
186 |         total_batch
= len(query_embs)//args.batch
187 |     else:
188 |         total_batch = len(query_embs)//args.batch + 1
189 | 
190 |     start_time = time.time()
191 |     for i in tqdm(range(total_batch), total=total_batch, desc=description):
192 | 
193 |         if i == (total_batch -1):
194 |             batch_query_embs = query_embs[i*args.batch:]
195 |             batch_query_arg_idxs = query_arg_idxs[i*args.batch:]
196 |             batch_qids = qids[i*args.batch:]
197 |         else:
198 |             batch_query_embs = query_embs[i*args.batch:(i+1)*args.batch]
199 |             batch_query_arg_idxs = query_arg_idxs[i*args.batch:(i+1)*args.batch]
200 |             batch_qids = qids[i*args.batch:(i+1)*args.batch]
201 | 
202 |         scores, candidates = faiss_index.search(batch_query_embs.numpy(), args.agip_topk)
203 | 
204 | 
205 |         for i, (qid, query_emb, query_arg_idx, candidate) in enumerate(zip(batch_qids, batch_query_embs, batch_query_arg_idxs, candidates)):
206 |             if args.rerank:
207 |                 candidate_sparse_embs = ((corpus_arg_idxs[candidate,:]==query_arg_idx)*corpus_embs[candidate])
208 |                 scores = torch.einsum('ij,j->i',(candidate_sparse_embs, query_emb))
209 | 
210 |                 sort_idx = torch.topk(scores, args.topk, dim=0).indices
211 |                 sort_candidates = candidate[sort_idx]
212 |                 sort_scores = scores[sort_idx]
213 | 
214 |                 all_scores[qid] = sort_scores.tolist()
215 |                 all_results[qid] = sort_candidates.tolist()
216 | 
217 |                 # del candidates, candidate_sparse_embs, scores, sort_idx
218 |             else:
219 |                 # sort_candidates = torch.argsort(partial_scores, descending=True)[:args.topk]
220 |                 all_scores[qid] = scores[i, :args.topk].tolist()
221 |                 all_results[qid] = candidates[i, :args.topk].tolist()
222 | 
223 | 
224 |         torch.cuda.empty_cache()
225 | 
226 | 
227 | 
228 |     time_per_query = (time.time() - start_time)/query_embs.shape[0]
229 |     print('Retrieving {} queries ({:0.3f} s/query)'.format(query_embs.shape[0], time_per_query))
230 | 
231 |     return all_results, all_scores
232 | 
233 | def main():
234 |     parser = argparse.ArgumentParser()
235 |     parser.add_argument("--query_emb_path", type=str, required=True)
236 |     parser.add_argument("--index_path", type=str, required=True)
237 |     parser.add_argument("--faiss_pq_index_path", type=str, default=None)
238 |     parser.add_argument("--emb_dim", type=int, default=768, help='DLR dimension')
239 |     parser.add_argument("--theta", type=float, default=0.1)
240 |     parser.add_argument("--topk", type=int, default=1000)
241 |     parser.add_argument("--agip_topk", type=int, default=10000)
242 |     parser.add_argument("--combine_cls", action='store_true')
243 |     parser.add_argument("--IP", action='store_true')
244 |     parser.add_argument("--PQIP", action='store_true')
245 |     parser.add_argument("--batch", type=int, default=1)
246 |     parser.add_argument("--brute_force", action='store_true')
247 |     parser.add_argument("--use_gpu", action='store_true')
248 |     parser.add_argument("--rerank", action='store_true')
249 |     parser.add_argument("--lamda", type=float, default=1, help='weight for [CLS] for concatenation')
250 |     parser.add_argument("--total_shrad", type=int, default=1)
251 |     parser.add_argument("--shrad", type=int, default=0)
252 |     parser.add_argument("--run_name", type=str, default='h2oloo')
253 |     args = parser.parse_args()
254 | 
255 |     if not args.use_gpu:
256 |         if args.batch > 1:
257 |             torch.set_num_threads(72)
258 |         else:
259 |             torch.set_num_threads(1)
260 |     else:
261 |         torch.cuda.set_device(0)
262 | 
263 |     # load query embeddings
264 |     print('Load query embeddings ...')
265 |     with open(args.query_emb_path, 'rb') as f:
266 |         query_embs, query_arg_idxs, qids=pickle.load(f)
267 | 
268 |     if args.use_gpu:
269 |         query_embs =
torch.from_numpy(query_embs).cuda(0) 270 | try: 271 | query_arg_idxs = torch.from_numpy(query_arg_idxs).cuda(0) 272 | except: 273 | query_arg_idxs = None 274 | else: 275 | query_embs = torch.from_numpy(query_embs.astype(np.float32)) 276 | try: 277 | query_arg_idxs = torch.from_numpy(query_arg_idxs) 278 | except: 279 | query_arg_idxs = None 280 | 281 | cls_dim = query_embs.shape[1] - args.emb_dim 282 | if cls_dim > 0: 283 | query_embs[:,-cls_dim:] = args.lamda * query_embs[:,-cls_dim:] 284 | 285 | 286 | 287 | # load index 288 | print('Load index ...') 289 | with open(args.index_path, 'rb') as f: 290 | corpus_embs, corpus_arg_idxs, docids=pickle.load(f) 291 | 292 | doc_num_per_shrad = len(docids)//args.total_shrad 293 | if args.shrad==(args.total_shrad-1): 294 | corpus_embs = corpus_embs[doc_num_per_shrad*args.shrad:] 295 | try: 296 | corpus_arg_idxs = corpus_arg_idxs[doc_num_per_shrad*args.shrad:] 297 | except: 298 | corpus_arg_idxs = None 299 | docids = docids[doc_num_per_shrad*args.shrad:] 300 | else: 301 | corpus_embs = corpus_embs[doc_num_per_shrad*args.shrad:doc_num_per_shrad*(args.shrad+1)] 302 | try: 303 | corpus_arg_idxs = corpus_arg_idxs[doc_num_per_shrad*args.shrad:doc_num_per_shrad*(args.shrad+1)] 304 | except: 305 | corpus_arg_idxs = None 306 | docids = docids[doc_num_per_shrad*args.shrad:doc_num_per_shrad*(args.shrad+1)] 307 | 308 | if args.use_gpu: 309 | corpus_embs = torch.from_numpy(corpus_embs).cuda(0) 310 | if corpus_arg_idxs is not None: 311 | corpus_arg_idxs = torch.from_numpy(corpus_arg_idxs).cuda(0) 312 | else: 313 | corpus_embs = torch.from_numpy(corpus_embs.astype(np.float32)) 314 | if corpus_arg_idxs is not None: 315 | corpus_arg_idxs = torch.from_numpy(corpus_arg_idxs) 316 | # density = corpus_embs!=0 317 | # density = density.sum(axis=1) 318 | # print(torch.sum(density)/8841823/args.emb_dim) 319 | 320 | 321 | if query_arg_idxs is not None: 322 | if not args.PQIP: 323 | results, scores = GIP_retrieval(qids, query_embs, query_arg_idxs, corpus_embs, corpus_arg_idxs ,args) 324 | else: 325 | results, scores = PQ_IP_retrieval(qids, query_embs, query_arg_idxs, corpus_embs, corpus_arg_idxs ,args) 326 | else: 327 | results, scores = IP_retrieval(qids, query_embs, corpus_embs, args) 328 | 329 | if args.total_shrad==1: 330 | fout = open('result.trec', 'w') 331 | else: 332 | fout = open('result{}.trec'.format(args.shrad), 'w') 333 | for i, query_id in tqdm(enumerate(results), total=len(results), desc=f"write results"): 334 | result = results[query_id] 335 | score = scores[query_id] 336 | 337 | for rank, docidx in enumerate(result): 338 | 339 | docid = docids[docidx] 340 | if (docid!=query_id): 341 | fout.write('{} Q0 {} {} {} {}\n'.format(query_id, docid, rank+1, score[rank], args.run_name)) 342 | fout.close() 343 | 344 | print('finish') 345 | 346 | 347 | if __name__ == "__main__": 348 | main() -------------------------------------------------------------------------------- /retrieval/index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | # os.environ['OMP_NUM_THREADS'] = str(32) 5 | import numpy as np 6 | import math 7 | from progressbar import * 8 | # from util import load_tfrecords_and_index, read_id_dict, faiss_index 9 | from multiprocessing import Pool, Manager 10 | import pickle 11 | import torch 12 | import torch.nn as nn 13 | import time 14 | 15 | 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--index_prefix", type=str, 
default='msmarco-passage')
21 |     parser.add_argument("--emb_dim", type=int, default=768)
22 |     parser.add_argument("--index_path", type=str, required=True)
23 |     args = parser.parse_args()
24 | 
25 |     ## Merge index
26 |     corpus_files = glob.glob(os.path.join(args.index_path, args.index_prefix + '.split*.pt'))
27 | 
28 |     corpus_embs = []
29 |     corpus_arg_idxs = []
30 |     docids = []
31 |     for corpus_file in corpus_files:
32 |         with open(corpus_file, 'rb') as f:
33 |             print('Load index: {}...'.format(corpus_file))
34 |             corpus_emb, corpus_arg_idx, docid=pickle.load(f)
35 |             corpus_embs.append(corpus_emb)
36 |             corpus_arg_idxs.append(corpus_arg_idx)
37 |             docids += docid
38 | 
39 |     print('Merge index ...')
40 |     try:
41 |         corpus_arg_idxs = np.concatenate(corpus_arg_idxs, axis=0)
42 |     except:
43 |         corpus_arg_idxs = None  # no densified argmax indices stored (e.g., pure semantic embeddings)
44 |     corpus_embs = np.concatenate(corpus_embs, axis=0)
45 | 
46 |     with open(os.path.join(args.index_path, args.index_prefix + '.index.pt'), 'wb') as f:
47 |         pickle.dump([corpus_embs, corpus_arg_idxs, docids], f, protocol=4)
48 | 
49 | 
50 | 
51 | 
52 | if __name__ == "__main__":
53 |     main()
54 | 
-------------------------------------------------------------------------------- /retrieval/merge.result.py: --------------------------------------------------------------------------------
 1 | import argparse
 2 | import pickle
 3 | import glob
 4 | import os
 5 | import numpy as np
 6 | from collections import defaultdict
 7 | from progressbar import *
 8 | 
 9 | 
10 | 
11 | 
12 | 
13 | def main():
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument("--total_shrad", type=int, default=1)
16 |     parser.add_argument("--topk", type=int, default=1000)
17 |     parser.add_argument("--run_name", default='dhr')
18 | 
19 |     args = parser.parse_args()
20 |     results = defaultdict(list)
21 |     scores = defaultdict(list)
22 |     for shrad in range(args.total_shrad):
23 |         with open('result{}.trec'.format(shrad), 'r') as f:  # per-shard runs as written by gip_retrieval.py
24 |             for line in f:
25 |                 query_id, _, docid, rank, score, _ = line.strip().split(' ')
26 |                 score = float(score)
27 |                 results[query_id].append(docid)
28 |                 scores[query_id].append(score)
29 | 
30 | 
31 |     print('write results ...')
32 |     widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(),
33 |                ' ', ETA(), ' ', FileTransferSpeed()]
34 |     pbar = ProgressBar(widgets=widgets, maxval=10*len(results)).start()
35 |     fout = open('result.trec', 'w')
36 |     for i, query_id in enumerate(results):
37 |         score = scores[query_id]
38 |         result = results[query_id]
39 |         sort_idx = np.array(score).argsort()[::-1][:args.topk]
40 |         for rank, idx in enumerate(sort_idx):
41 |             fout.write('{} Q0 {} {} {} {}\n'.format(query_id, result[idx], rank+1, score[idx], args.run_name))
42 |         pbar.update(10 * i + 1)
43 |     fout.close()
44 | 
45 | 
46 | 
47 | if __name__ == "__main__":
48 |     main()
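# Example (illustrative; paths and shard counts are placeholders): merging
# per-shard runs produced by a sharded GIP search:
#   python retrieval/gip_retrieval.py ... --total_shrad 4 --shrad 0   # writes result0.trec
#   python retrieval/gip_retrieval.py ... --total_shrad 4 --shrad 3   # writes result3.trec
#   python retrieval/merge.result.py --total_shrad 4 --topk 1000      # writes result.trec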
-------------------------------------------------------------------------------- /retrieval/quantize_index.py: --------------------------------------------------------------------------------
 1 | import argparse
 2 | import os
 3 | import numpy as np
 4 | import pickle5 as pickle
 5 | import faiss
 6 | 
 7 | 
 8 | def main():
 9 |     parser = argparse.ArgumentParser()
10 |     parser.add_argument("--index_path", type=str, required=True)
11 |     parser.add_argument("--output_index_path", type=str, default=None)
12 |     parser.add_argument("--quantized_dim", type=int, default=64)
13 |     parser.add_argument("--n_bits", type=int, default=8)
14 |     args = parser.parse_args()
15 | 
16 |     if args.output_index_path is None:
17 |         # default to the directory containing the input index
18 |         index_dir = os.path.dirname(args.index_path)
19 |         args.output_index_path = os.path.join(index_dir, 'pq{}_index'.format(args.quantized_dim))
20 | 
21 |     # load index
22 |     print('Load index ...')
23 |     with open(args.index_path, 'rb') as f:
24 |         corpus_embs, corpus_arg_idxs, docids=pickle.load(f)
25 |     corpus_embs = corpus_embs.astype(np.float32)
26 | 
27 |     faiss.omp_set_num_threads(36)
28 |     print('build PQ index...')
29 |     index = faiss.IndexPQ(corpus_embs.shape[1], args.quantized_dim, args.n_bits, faiss.METRIC_INNER_PRODUCT)
30 |     index.verbose = True
31 | 
32 |     print('train PQ...')
33 |     index.train(corpus_embs)
34 |     print('build index...')
35 |     index.add(corpus_embs)
36 |     print('write index to {}'.format(args.output_index_path))
37 |     faiss.write_index(index, args.output_index_path)
38 |     print('finish')
39 | 
40 | 
41 | if __name__ == "__main__":
42 |     main()
43 | 
-------------------------------------------------------------------------------- /retrieval/rcap_eval.py: --------------------------------------------------------------------------------
 1 | import argparse
 2 | from .evaluation.custom_metrics import recall_cap
 3 | 
 4 | def main():
 5 |     parser = argparse.ArgumentParser()
 6 |     parser.add_argument("--qrel_file_path", type=str, required=True)
 7 |     parser.add_argument("--run_file_path", type=str, required=True)
 8 |     parser.add_argument("--cutoff", type=int, default=100, required=False)
 9 |     args = parser.parse_args()
10 | 
11 |     qrels = {}
12 |     with open(args.qrel_file_path, 'r') as f:
13 |         for line in f:
14 |             qid, _, docid, rel = line.strip().split('\t')
15 |             if qid not in qrels:
16 |                 qrels[qid] = {docid: int(rel)}
17 |             else:
18 |                 qrels[qid][docid] = int(rel)
19 | 
20 |     results = {}
21 |     with open(args.run_file_path, 'r') as f:
22 |         for line in f:
23 |             qid, _, docid, rank, score, _ = line.strip().split(' ')
24 |             if qid not in results:
25 |                 results[qid] = {docid: float(score)}
26 |             else:
27 |                 results[qid][docid] = float(score)
28 | 
29 |     print(recall_cap(qrels, results, [args.cutoff]))
30 | 
31 | if __name__ == "__main__":
32 |     main()
-------------------------------------------------------------------------------- /retrieval/util.py: --------------------------------------------------------------------------------
 1 | import os
 2 | import pickle
 3 | 
 4 | # import mkl
 5 | # mkl.set_num_threads(16)
 6 | import numpy as np
 7 | import tensorflow.compat.v1 as tf
import faiss  # needed by faiss_index() below
 8 | from numpy import linalg as LA
 9 | from progressbar import *
10 | from collections import defaultdict
11 | import glob
12 | from scipy.sparse import csc_matrix
13 | import gzip
14 | import json
15 | 
16 | def read_pickle(filename):
17 |     with open(filename, 'rb') as f:
18 |         Distance, Index=pickle.load(f)
19 |     return Distance, Index
20 | 
21 | 
22 | def read_id_dict(path):
23 |     if os.path.isdir(path):
24 |         files = glob.glob(os.path.join(path, '*.id'))
25 |     else:
26 |         files = [path]
27 | 
28 |     idx_to_id = {}
29 |     id_to_idx = {}
30 |     for file in files:
31 |         f = open(file, 'r')
32 |         for i, line in enumerate(f):
33 |             try:
34 |                 idx, Id =line.strip().split('\t')
35 |                 idx_to_id[int(idx)] = Id
36 |                 id_to_idx[Id] = int(idx)
37 |             except:
38 |                 Id = line.strip()
39 |                 idx_to_id[i] = Id
40 |                 # if len(Id.split(' '))==1:
41 | 
42 |                 # else:
43 |                 #     print(line+' has no id')
44 |     return idx_to_id, id_to_idx
45 | 
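# The *.id files map a row index in the embedding matrix to an external
# string id, one "<idx>\t<id>" pair per line (or a bare id per line, with the
# line number used as the index). write_result below resolves indices through
# these maps and writes runs in TREC format:
#   <qid> Q0 <docid> <rank> <score> <run_name>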
46 | def write_result(qidxs, Index, Score, file, idx_to_qid, idx_to_docid, topk=None, run_name='Faiss'):
47 |     print('write results...')
48 |     with open(file, 'w') as fout:
49 |         for i, qidx in enumerate(qidxs):
50 |             try:
51 |                 qid = idx_to_qid[qidx]
52 |             except:
53 |                 qid = qidx
54 |             if topk is None:
55 |                 docidxs=Index[i]
56 |                 scores=Score[i]
57 |                 for rank, docidx in enumerate(docidxs):
58 |                     try:
59 |                         docid = idx_to_docid[docidx]
60 |                     except:
61 |                         docid = docidx
62 |                     fout.write('{} Q0 {} {} {} {}\n'.format(qid, docid, rank + 1, scores[rank], run_name))
63 |             else:
64 |                 try:
65 |                     hit=min(topk, len(Index[i]))
66 |                 except:
67 |                     hit = topk  # Index[i] has no len(); fall back to topk
68 | 
69 |                 docidxs=Index[i]
70 |                 scores=Score[i]
71 |                 for rank, docidx in enumerate(docidxs[:hit]):
72 |                     try:
73 |                         docid = idx_to_docid[docidx]
74 |                     except:
75 |                         docid = docidx
76 |                     fout.write('{} Q0 {} {} {} {}\n'.format(qid, docid, rank + 1, scores[rank], run_name))
77 | 
78 | 
79 | def faiss_index(corpus_embs, docids, save_path, index_method):
80 | 
81 |     dimension=corpus_embs.shape[1]
82 |     print("Indexing ...")
83 |     if index_method is None or index_method=='flatip':
84 |         cpu_index = faiss.IndexFlatIP(dimension)
85 | 
86 |     elif index_method=='hsw':
87 |         cpu_index = faiss.IndexHNSWFlat(dimension, 256, faiss.METRIC_INNER_PRODUCT)
88 |         cpu_index.hnsw.efConstruction = 256
89 |     elif index_method=='quantize': # still looking for a better balance between efficiency and effectiveness
90 |         cpu_index = faiss.IndexHNSWPQ(dimension, 192, 256)
91 |         cpu_index.hnsw.efConstruction = 256
92 |         cpu_index.metric_type = faiss.METRIC_INNER_PRODUCT
93 |         # ncentroids = 1000
94 |         # code_size = dimension//4
95 |         # cpu_index = faiss.IndexIVFPQ(cpu_index, dimension, ncentroids, code_size, 8)
96 |         # cpu_index = faiss.IndexPQ(dimension, code_size, 8)
97 |         # cpu_index = faiss.index_factory(768, "OPQ128,IVF4096,PQ128", faiss.METRIC_INNER_PRODUCT)
98 |         # cpu_index = faiss.IndexIDMap(cpu_index)
99 |         # cpu_index = faiss.GpuIndexScalarQuantizer(dimension, faiss.ScalarQuantizer.QT_16bit_direct, faiss.METRIC_INNER_PRODUCT)
100 | 
101 | 
102 |     cpu_index.verbose = True
103 |     if index_method=='quantize':
104 |         print("Train index...")  # PQ codebooks must be trained before vectors are added
105 |         cpu_index.train(corpus_embs)
106 |     cpu_index.add(corpus_embs)
107 |     print("Save Index {}...".format(save_path))
108 |     faiss.write_index(cpu_index, save_path)
109 | 
110 | def save_pickle(corpus_embs, arg_idxs, docids, filename):
111 |     print('save pickle...')
112 |     with open(filename, 'wb') as f:
113 |         pickle.dump([corpus_embs, arg_idxs, docids], f, protocol=4)
114 | 
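# Corpus indexes are pickled as the triple [corpus_embs, arg_idxs, docids]
# (protocol=4, so files larger than 4 GB are supported). A minimal sketch of
# loading one back (illustrative; the file name depends on --index_prefix):
#   with open('msmarco-passage.index.pt', 'rb') as f:
#       corpus_embs, arg_idxs, docids = pickle.load(f)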
115 | def load_tfrecords_and_index(srcfiles, data_num, word_num, dim, data_type, add_cls, index=False, save_path=None, batch=10000):
116 |     def _parse_function(example_proto):
117 |         features = {'doc_emb': tf.FixedLenFeature([],tf.string) , #tf.FixedLenSequenceFeature([],tf.string, allow_missing=True),
118 |                     'argx_id_id': tf.FixedLenFeature([],tf.string) ,
119 |                     'docid': tf.FixedLenFeature([],tf.int64)}
120 |         parsed_features = tf.parse_single_example(example_proto, features)
121 |         arg_idx = tf.decode_raw(parsed_features['argx_id_id'], tf.uint8)
122 |         if data_type=='16':
123 |             corpus = tf.decode_raw(parsed_features['doc_emb'], tf.float16)
124 |         elif data_type=='32':
125 |             corpus = tf.decode_raw(parsed_features['doc_emb'], tf.float32)
126 |         docid = tf.cast(parsed_features['docid'], tf.int32)
127 |         return corpus, arg_idx, docid
128 |     print('Read embeddings...')
129 |     widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(),
130 |                ' ', ETA(), ' ', FileTransferSpeed()]
131 |     pbar = ProgressBar(widgets=widgets, maxval=10*data_num*len(srcfiles)).start()
132 |     with tf.Session() as sess:
133 |         docids=[]
134 |         if add_cls:
135 |             segment=2
136 |         else:
137 |             segment=1
138 |         # pre-allocate the full arrays so we avoid concatenating per-file chunks
139 |         arg_idxs = np.zeros((word_num*data_num*len(srcfiles) , dim), dtype=np.uint8)
140 |         if (data_type=='16'): # faiss indexing requires float32; float16 is stored here to save memory and cast before indexing
141 |             corpus_embs = np.zeros((word_num*data_num*len(srcfiles) , dim*segment), dtype=np.float16)
142 |         elif data_type=='32':
143 |             corpus_embs = np.zeros((word_num*data_num*len(srcfiles) , dim*segment), dtype=np.float32)
144 |         # else:
145 |         #     raise Exception('Please assign datatype 16 or 32 bits')
146 |         counter = 0
147 |         i = 0
148 | 
149 |         for srcfile in srcfiles:
150 |             try:
151 |                 dataset = tf.data.TFRecordDataset(srcfile) # load tfrecord file
152 |             except:
153 |                 print('Cannot find data')
154 |                 continue
155 |             dataset = dataset.map(_parse_function) # parse data into tensor
156 |             dataset = dataset.repeat(1)
157 |             dataset = dataset.batch(batch)
158 |             iterator = dataset.make_one_shot_iterator()
159 |             next_data = iterator.get_next()
160 | 
161 |             while True:
162 |                 try:
163 |                     corpus_emb, arg_idx, docid = sess.run(next_data)
164 | 
165 |                     corpus_emb = corpus_emb.reshape(-1, dim*segment)
166 | 
167 |                     sent_num = corpus_emb.shape[0]
168 |                     corpus_embs[counter:(counter+sent_num)] = corpus_emb
169 |                     arg_idxs[counter:(counter+sent_num)] = arg_idx
170 | 
171 |                     docids+=docid.tolist()
172 |                     counter+=sent_num
173 |                     pbar.update(10 * i + 1)
174 |                     i+=sent_num
175 |                 except tf.errors.OutOfRangeError:
176 |                     break
177 | 
178 |     docids = np.array(docids).reshape(-1)
179 |     corpus_embs = (corpus_embs[:len(docids)])
180 |     arg_idxs = (arg_idxs[:len(docids)])
181 |     mask = docids!=-1
182 |     docids = docids[mask]
183 |     corpus_embs = corpus_embs[mask]
184 |     arg_idxs = arg_idxs[mask]
185 |     if index:
186 |         save_pickle(corpus_embs, arg_idxs, docids, save_path)
187 |     else:
188 |         return corpus_embs, arg_idxs, docids
189 | 
190 | def load_jsonl_and_index(srcfiles, data_num, dim, vocab_dict, data_type, add_cls, index=False, save_path=None, batch=10000):
191 |     print('Count line...')
192 |     data_num = 0
193 |     for srcfile in srcfiles:
194 |         with gzip.open(srcfile, 'rb') as f:
195 |             for l in f:
196 |                 data_num+=1
197 |     print('Total {} lines'.format(data_num))
198 |     widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(),
199 |                ' ', ETA(), ' ', FileTransferSpeed()]
200 |     pbar = ProgressBar(widgets=widgets, maxval=10*data_num*len(srcfiles)).start()
201 |     docids=[]
202 |     if add_cls:
203 |         segment=2
204 |     else:
205 |         segment=1
206 |     # pre-allocate the full arrays so we avoid concatenating per-file chunks
207 |     arg_idxs = np.zeros((data_num , dim), dtype=np.uint8)
208 |     if (data_type=='16'): # faiss indexing requires float32; float16 is stored here to save memory and cast before indexing
209 |         corpus_embs = np.zeros((data_num , dim*segment), dtype=np.float16)
210 |     elif data_type=='32':
211 |         corpus_embs = np.zeros((data_num , dim*segment), dtype=np.float32)
212 |     # else:
213 |     #     raise Exception('Please assign datatype 16 or 32 bits')
214 |     counter = 0
215 |     i = 0
216 | 
217 | 
218 |     for srcfile in srcfiles:
219 | 
220 |         with gzip.open(srcfile, "rb") as f:
221 |             for line in f:
222 |                 data = json.loads(line.strip())
223 |                 embedding =np.zeros((30522), dtype=np.float16)
224 |                 for vocab, term_weight in data['vector'].items():
225 |                     embedding[vocab_dict[vocab]] = term_weight/100
226 | 
                    # densify: drop the first 570 (unused) vocab dims, fold the
                    # vocab-sized vector into slices of width dim, and keep each
                    # dimension's max value plus the index of its source slice
227 |                 embedding = np.reshape(embedding[570:],(-1, dim))
228 |                 corpus_emb = embedding.max(0)
229 |                 arg_idx = embedding.argmax(0)
230 |                 docid = int(data['id'])
231 | 
232 | 
233 |                 corpus_emb = corpus_emb.reshape(-1, dim*segment)
234 | 
235 |                 sent_num = corpus_emb.shape[0]
236 |                 corpus_embs[counter:(counter+sent_num)] = corpus_emb
237 |                 arg_idxs[counter:(counter+sent_num)] = arg_idx
238 | 239 | docids+=[docid] 240 | counter+=sent_num 241 | pbar.update(10 * i + 1) 242 | i+=sent_num 243 | 244 | 245 | docids = np.array(docids).reshape(-1) 246 | corpus_embs = (corpus_embs[:len(docids)]) 247 | arg_idxs = (arg_idxs[:len(docids)]) 248 | mask = docids!=-1 249 | docids = docids[mask] 250 | corpus_embs = corpus_embs[mask] 251 | arg_idxs = arg_idxs[mask] 252 | if index: 253 | save_pickle(corpus_embs, arg_idxs, docids, save_path) 254 | else: 255 | return corpus_embs, arg_idxs, docids 256 | 257 | def load_tfrecords_and_analyze(srcfiles, data_num, word_num, dim, data_type, batch=1): 258 | def _parse_function(example_proto): 259 | features = {#'doc_emb': tf.FixedLenFeature([],tf.string) , #tf.FixedLenSequenceFeature([],tf.string, allow_missing=True), 260 | 'id_p1': tf.FixedLenSequenceFeature([],tf.int64, allow_missing=True) , 261 | 'docid': tf.FixedLenFeature([],tf.int64)} 262 | parsed_features = tf.parse_single_example(example_proto, features) 263 | vocab_ids = tf.cast(parsed_features['id_p1'], tf.int32) 264 | docid = tf.cast(parsed_features['docid'], tf.int32) 265 | return vocab_ids, docid 266 | print('Read embeddings...') 267 | widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(), 268 | ' ', ETA(), ' ', FileTransferSpeed()] 269 | pbar = ProgressBar(widgets=widgets, maxval=10*data_num*len(srcfiles)).start() 270 | with tf.Session() as sess: 271 | docids=[] 272 | segment=1 273 | # else: 274 | # raise Exception('Please assign datatype 16 or 32 bits') 275 | counter = 0 276 | i = 0 277 | vocab_adj =np.zeros((30522,30522), dtype=np.uint32) 278 | vocab_freq = np.zeros((30522), dtype=np.uint32) 279 | for srcfile in srcfiles: 280 | try: 281 | dataset = tf.data.TFRecordDataset(srcfile) # load tfrecord file 282 | except: 283 | print('Cannot find data') 284 | continue 285 | dataset = dataset.map(_parse_function) # parse data into tensor 286 | dataset = dataset.repeat(1) 287 | dataset = dataset.batch(batch) 288 | iterator = dataset.make_one_shot_iterator() 289 | next_data = iterator.get_next() 290 | 291 | while True: 292 | try: 293 | vocab_ids, docid = sess.run(next_data) 294 | 295 | vocab_id_list = vocab_ids.squeeze().tolist() 296 | try: 297 | num_vocab_id = len(vocab_id_list) 298 | if num_vocab_id >1: 299 | for m in range(num_vocab_id): 300 | vocab_freq[vocab_id_list[m]]+=1 301 | for n in range(m+1,num_vocab_id,1): 302 | vocab_adj[vocab_id_list[m], vocab_id_list[n]]+=1 303 | 304 | except: 305 | vocab_freq[vocab_id_list]+=1 306 | 307 | 308 | 309 | pbar.update(10 * i + 1) 310 | i+=1 311 | # if i>=20000: 312 | # break 313 | except tf.errors.OutOfRangeError: 314 | break 315 | 316 | 317 | return vocab_freq, vocab_adj -------------------------------------------------------------------------------- /tevatron/Aggretriever/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/Aggretriever/__init__.py -------------------------------------------------------------------------------- /tevatron/Aggretriever/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | # remove_dim_dict = {768: -198, 640: -198, 512: 826, 256: 314, 128: 58} 6 | # remove_dim_dict1 = {768: 570, 640: 442, 512: 314, 256: 58, 128: 58} 7 | 8 | def cal_remove_dim(dims, vocab_size=30522): 9 | 10 | remove_dims = vocab_size % dims 11 | if remove_dims > 1000: # the first 1000 
tokens in BERT are useless 12 | remove_dims -= dims 13 | 14 | return remove_dims 15 | 16 | def aggregate(lexical_reps: Tensor, 17 | dims: int = 640, 18 | remove_dims: int = -198, 19 | full: bool = True 20 | ): 21 | 22 | if full: 23 | remove_dims = cal_remove_dim(dims*2) 24 | batch_size = lexical_reps.shape[0] 25 | if remove_dims >= 0: 26 | lexical_reps = lexical_reps[:, remove_dims:].view(batch_size, -1, dims*2) 27 | else: 28 | lexical_reps = torch.nn.functional.pad(lexical_reps, (0, -remove_dims), "constant", 0).view(batch_size, -1, dims*2) 29 | 30 | tok_reps, _ = lexical_reps.max(1) 31 | 32 | positive_tok_reps = tok_reps[:, 0:2*dims:2] 33 | negative_tok_reps = tok_reps[:, 1:2*dims:2] 34 | 35 | positive_mask = positive_tok_reps > negative_tok_reps 36 | negative_mask = positive_tok_reps <= negative_tok_reps 37 | tok_reps = positive_tok_reps * positive_mask - negative_tok_reps * negative_mask 38 | else: 39 | remove_dims = cal_remove_dim(dims) 40 | batch_size = lexical_reps.shape[0] 41 | lexical_reps = lexical_reps[:, remove_dims:].view(batch_size, -1, dims) 42 | tok_reps, index_reps = lexical_reps.max(1) 43 | 44 | return tok_reps 45 | 46 | -------------------------------------------------------------------------------- /tevatron/DHR/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/DHR/__init__.py -------------------------------------------------------------------------------- /tevatron/DHR/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | def densify(lexical_reps: Tensor, 6 | dims: int = 768, 7 | strategy: str = 'stride', 8 | remove_dims: int = 570 9 | ): 10 | 11 | if not (len(lexical_reps.shape)==2): 12 | raise ValueError( 'Input lexical representation shape should be 2 (batch, vocab), but the input shape is {}'.format( len(lexical_reps.shape) ) ) 13 | 14 | orig_dims = lexical_reps.shape[-1] 15 | if not ( (orig_dims-remove_dims)%dims==0 ): 16 | raise ValueError('Input lexical representation cannot be densified, please fix dims or remove_dims') 17 | 18 | # Todo: add other strategy 19 | batch_size = lexical_reps.shape[0] 20 | lexical_reps = lexical_reps[:, remove_dims:].view(batch_size, -1, dims) 21 | value_reps, index_reps = lexical_reps.max(1) 22 | return value_reps, index_reps 23 | -------------------------------------------------------------------------------- /tevatron/Dense/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/Dense/__init__.py -------------------------------------------------------------------------------- /tevatron/Dense/modeling.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import copy 4 | from dataclasses import dataclass 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | import torch.distributed as dist 10 | 11 | from transformers import AutoModel, PreTrainedModel, AutoModelForMaskedLM 12 | from transformers.modeling_outputs import ModelOutput 13 | 14 | 15 | from typing import Optional, Dict 16 | 17 | from ..arguments import ModelArguments, DataArguments, \ 18 | DenseTrainingArguments as TrainingArguments 19 | import logging 20 | 21 | logger = 
logging.getLogger(__name__) 22 | 23 | 24 | @dataclass 25 | class DenseOutput(ModelOutput): 26 | q_reps: Tensor = None 27 | p_reps: Tensor = None 28 | loss: Tensor = None 29 | scores: Tensor = None 30 | 31 | 32 | class LinearPooler(nn.Module): 33 | def __init__( 34 | self, 35 | input_dim: int = 768, 36 | output_dim: int = 768, 37 | tied=True, 38 | elementwise=False, 39 | name='pooler' 40 | ): 41 | super(LinearPooler, self).__init__() 42 | self.name = name 43 | self.elementwise=elementwise 44 | self.linear_q = nn.Linear(input_dim, output_dim) 45 | if tied: 46 | self.linear_p = self.linear_q 47 | else: 48 | self.linear_p = nn.Linear(input_dim, output_dim) 49 | 50 | self._config = {'input_dim': input_dim, 'output_dim': output_dim, 'tied': tied} 51 | 52 | def forward(self, q: Tensor = None, p: Tensor = None): 53 | if q is not None: 54 | return self.linear_q(q) 55 | elif p is not None: 56 | return self.linear_p(p) 57 | else: 58 | raise ValueError 59 | 60 | def load(self, ckpt_dir: str): 61 | if ckpt_dir is not None: 62 | _pooler_path = os.path.join(ckpt_dir, '{}.pt'.format(self.name)) 63 | if os.path.exists(_pooler_path): 64 | logger.info(f'Loading Pooler from {ckpt_dir}') 65 | state_dict = torch.load(os.path.join(ckpt_dir, '{}.pt'.format(self.name)), map_location='cpu') 66 | self.load_state_dict(state_dict) 67 | return 68 | logger.info("Training {} from scratch".format(self.name)) 69 | return 70 | 71 | def save_pooler(self, save_path): 72 | torch.save(self.state_dict(), os.path.join(save_path, '{}.pt'.format(self.name))) 73 | with open(os.path.join(save_path, '{}_config.json').format(self.name), 'w') as f: 74 | json.dump(self._config, f) 75 | 76 | 77 | class DenseModel(nn.Module): 78 | def __init__( 79 | self, 80 | lm_q: PreTrainedModel, 81 | lm_p: PreTrainedModel, 82 | pooler: nn.Module = None, 83 | model_args: ModelArguments = None, 84 | data_args: DataArguments = None, 85 | train_args: TrainingArguments = None, 86 | ): 87 | super().__init__() 88 | 89 | self.lm_q = lm_q 90 | self.lm_p = lm_p 91 | self.pooler = pooler 92 | 93 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') 94 | 95 | self.model_args = model_args 96 | self.train_args = train_args 97 | self.data_args = data_args 98 | 99 | if train_args.negatives_x_device: 100 | if not dist.is_initialized(): 101 | raise ValueError('Distributed training has not been initialized for representation all gather.') 102 | self.process_rank = dist.get_rank() 103 | self.world_size = dist.get_world_size() 104 | 105 | def forward( 106 | self, 107 | query: Dict[str, Tensor] = None, 108 | passage: Dict[str, Tensor] = None, 109 | teacher_scores: Tensor = None, 110 | ): 111 | 112 | 113 | q_hidden, q_reps = self.encode_query(query, self.model_args.pooling_method) 114 | p_hidden, p_reps = self.encode_passage(passage, self.model_args.pooling_method) 115 | 116 | if q_reps is None or p_reps is None: 117 | return DenseOutput( 118 | q_reps=q_reps, 119 | p_reps=p_reps 120 | ) 121 | 122 | if self.training: 123 | if self.train_args.negatives_x_device: 124 | q_reps = self.dist_gather_tensor(q_reps) 125 | p_reps = self.dist_gather_tensor(p_reps) 126 | 127 | effective_bsz = self.train_args.per_device_train_batch_size * self.world_size \ 128 | if self.train_args.negatives_x_device \ 129 | else self.train_args.per_device_train_batch_size 130 | 131 | scores = torch.matmul(q_reps, p_reps.transpose(0, 1)) 132 | scores = scores.view(effective_bsz, -1) 133 | 134 | target = torch.arange( 135 | scores.size(0), 136 | device=scores.device, 137 | dtype=torch.long 138 | ) 
)
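# In-batch contrastive loss: passages arrive grouped per query as
# [pos, neg_1, ..., neg_{train_n_passages-1}], so the positive for query q
# sits at column q * train_n_passages of the (queries x passages) score
# matrix. Illustrative layout for 2 queries with train_n_passages = 2:
#   scores shape (2, 4) -> target = [0, 2]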
139 | target = target * self.data_args.train_n_passages 140 | loss = self.cross_entropy(scores, target) 141 | if self.train_args.negatives_x_device: 142 | loss = loss * self.world_size # counter average weight reduction 143 | return DenseOutput( 144 | loss=loss, 145 | scores=scores, 146 | q_reps=q_reps, 147 | p_reps=p_reps 148 | ) 149 | 150 | else: 151 | loss = None 152 | if query and passage: 153 | scores = (q_reps * p_reps).sum(1) 154 | else: 155 | scores = None 156 | 157 | return DenseOutput( 158 | loss=loss, 159 | scores=scores, 160 | q_reps=q_reps, 161 | p_reps=p_reps 162 | ) 163 | 164 | def encode_passage(self, psg, pooling_method): 165 | if psg is None: 166 | return None, None 167 | 168 | psg_out = self.lm_p(**psg, return_dict=True) 169 | p_hidden = psg_out.last_hidden_state 170 | 171 | if pooling_method == 'cls': 172 | p_reps = p_hidden[:, 0] 173 | elif pooling_method == 'average': 174 | attention_mask = psg['attention_mask'] 175 | p_hidden = p_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 176 | p_reps = p_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 177 | 178 | if self.pooler is not None: 179 | p_reps = self.pooler(p=p_reps) # D * d 180 | 181 | return p_hidden, p_reps 182 | 183 | def encode_query(self, qry, pooling_method): 184 | if qry is None: 185 | return None, None 186 | qry_out = self.lm_q(**qry, return_dict=True) 187 | q_hidden = qry_out.last_hidden_state 188 | 189 | if pooling_method == 'cls': 190 | q_reps = q_hidden[:, 0] 191 | elif pooling_method == 'average': 192 | attention_mask = qry['attention_mask'] 193 | q_hidden = q_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 194 | q_reps = q_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 195 | 196 | 197 | if self.pooler is not None: 198 | q_reps = self.pooler(q=q_reps) 199 | 200 | return q_hidden, q_reps 201 | 202 | @staticmethod 203 | def build_pooler(model_args): 204 | pooler = LinearPooler( 205 | model_args.projection_in_dim, 206 | model_args.projection_out_dim, 207 | tied=not model_args.untie_encoder 208 | ) 209 | pooler.load(model_args.model_name_or_path) 210 | return pooler 211 | 212 | @classmethod 213 | def build( 214 | cls, 215 | model_args: ModelArguments, 216 | data_args: DataArguments, 217 | train_args: TrainingArguments, 218 | **hf_kwargs, 219 | ): 220 | # load local 221 | if os.path.isdir(model_args.model_name_or_path): 222 | if model_args.untie_encoder: 223 | _qry_model_path = os.path.join(model_args.model_name_or_path, 'query_model') 224 | _psg_model_path = os.path.join(model_args.model_name_or_path, 'passage_model') 225 | if not os.path.exists(_qry_model_path): 226 | _qry_model_path = model_args.model_name_or_path 227 | _psg_model_path = model_args.model_name_or_path 228 | logger.info(f'loading query model weight from {_qry_model_path}') 229 | lm_q = AutoModel.from_pretrained( 230 | _qry_model_path, 231 | **hf_kwargs 232 | ) 233 | logger.info(f'loading passage model weight from {_psg_model_path}') 234 | lm_p = AutoModel.from_pretrained( 235 | _psg_model_path, 236 | **hf_kwargs 237 | ) 238 | else: 239 | lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs) 240 | lm_p = lm_q 241 | # load pre-trained 242 | else: 243 | lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs) 244 | lm_p = copy.deepcopy(lm_q) if model_args.untie_encoder else lm_q 245 | 246 | if model_args.add_pooler: 247 | pooler = cls.build_pooler(model_args) 248 | else: 249 | pooler = None 250 | 251 | model = cls( 252 | lm_q=lm_q, 253 | lm_p=lm_p, 254 | 
pooler=pooler, 255 | model_args=model_args, 256 | data_args=data_args, 257 | train_args=train_args 258 | ) 259 | return model 260 | 261 | def save(self, output_dir: str): 262 | if self.model_args.untie_encoder: 263 | os.makedirs(os.path.join(output_dir, 'query_model')) 264 | os.makedirs(os.path.join(output_dir, 'passage_model')) 265 | self.lm_q.save_pretrained(os.path.join(output_dir, 'query_model')) 266 | self.lm_p.save_pretrained(os.path.join(output_dir, 'passage_model')) 267 | else: 268 | self.lm_q.save_pretrained(output_dir) 269 | 270 | if self.model_args.add_pooler: 271 | self.pooler.save_pooler(output_dir) 272 | 273 | def dist_gather_tensor(self, t: Optional[torch.Tensor]): 274 | if t is None: 275 | return None 276 | t = t.contiguous() 277 | 278 | all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] 279 | dist.all_gather(all_tensors, t) 280 | 281 | all_tensors[self.process_rank] = t 282 | all_tensors = torch.cat(all_tensors, dim=0) 283 | 284 | return all_tensors 285 | 286 | 287 | class DenseModelForInference(DenseModel): 288 | POOLER_CLS = LinearPooler 289 | 290 | def __init__( 291 | self, 292 | model_args, 293 | lm_q: PreTrainedModel, 294 | lm_p: PreTrainedModel, 295 | pooler: nn.Module = None, 296 | **kwargs, 297 | ): 298 | nn.Module.__init__(self) 299 | self.lm_q = lm_q 300 | self.lm_p = lm_p 301 | self.pooler = pooler 302 | self.model_args = model_args 303 | 304 | @torch.no_grad() 305 | def encode_passage(self, psg, pooling_method): 306 | return super(DenseModelForInference, self).encode_passage(psg, pooling_method) 307 | 308 | @torch.no_grad() 309 | def encode_query(self, qry, pooling_method): 310 | return super(DenseModelForInference, self).encode_query(qry, pooling_method) 311 | 312 | # def forward( 313 | # self, 314 | # query: Dict[str, Tensor] = None, 315 | # passage: Dict[str, Tensor] = None, 316 | # ): 317 | # q_hidden, q_reps = self.encode_query(query) 318 | # p_hidden, p_reps = self.encode_passage(passage) 319 | # return DenseOutput(q_reps=q_reps, p_reps=p_reps) 320 | 321 | @classmethod 322 | def build( 323 | cls, 324 | model_name_or_path: str = None, 325 | model_args: ModelArguments = None, 326 | data_args: DataArguments = None, 327 | train_args: TrainingArguments = None, 328 | **hf_kwargs, 329 | ): 330 | assert model_name_or_path is not None or model_args is not None 331 | if model_name_or_path is None: 332 | model_name_or_path = model_args.model_name_or_path 333 | 334 | # load local 335 | if os.path.isdir(model_name_or_path): 336 | _qry_model_path = os.path.join(model_name_or_path, 'query_model') 337 | _psg_model_path = os.path.join(model_name_or_path, 'passage_model') 338 | if os.path.exists(_qry_model_path): 339 | logger.info(f'found separate weight for query/passage encoders') 340 | logger.info(f'loading query model weight from {_qry_model_path}') 341 | lm_q = AutoModel.from_pretrained( 342 | _qry_model_path, 343 | **hf_kwargs 344 | ) 345 | logger.info(f'loading passage model weight from {_psg_model_path}') 346 | lm_p = AutoModel.from_pretrained( 347 | _psg_model_path, 348 | **hf_kwargs 349 | ) 350 | else: 351 | logger.info(f'try loading tied weight') 352 | logger.info(f'loading model weight from {model_name_or_path}') 353 | lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs) 354 | lm_p = lm_q 355 | else: 356 | logger.info(f'try loading tied weight') 357 | logger.info(f'loading model weight from {model_name_or_path}') 358 | lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs) 359 | lm_p = lm_q 360 | 361 | pooler_weights = 
os.path.join(model_name_or_path, 'pooler.pt') 362 | pooler_config = os.path.join(model_name_or_path, 'pooler_config.json') 363 | if os.path.exists(pooler_weights) and os.path.exists(pooler_config): 364 | logger.info(f'found pooler weight and configuration') 365 | with open(pooler_config) as f: 366 | pooler_config_dict = json.load(f) 367 | pooler = cls.POOLER_CLS(**pooler_config_dict) 368 | pooler.load(model_name_or_path) 369 | else: 370 | pooler = None 371 | 372 | model = cls( 373 | model_args=model_args, 374 | lm_q=lm_q, 375 | lm_p=lm_p, 376 | pooler=pooler 377 | 378 | ) 379 | return model -------------------------------------------------------------------------------- /tevatron/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/__init__.py -------------------------------------------------------------------------------- /tevatron/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List 4 | from transformers import TrainingArguments 5 | 6 | 7 | @dataclass 8 | class ModelArguments: 9 | model_name_or_path: str = field( 10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 11 | ) 12 | target_model_path: str = field( 13 | default=None, 14 | metadata={"help": "Path to pretrained reranker target model"} 15 | ) 16 | config_name: Optional[str] = field( 17 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 18 | ) 19 | tokenizer_name: Optional[str] = field( 20 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 21 | ) 22 | cache_dir: Optional[str] = field( 23 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 24 | ) 25 | 26 | # modeling 27 | model: str = field( 28 | default='DHR', 29 | metadata={"help": "ColBERT, DHR, AGG, Dense"} 30 | ) 31 | untie_encoder: bool = field( 32 | default=False, 33 | metadata={"help": "no weight sharing between qry passage encoders"} 34 | ) 35 | 36 | # knowledge distillation 37 | teacher_model_name_or_path: str = field( 38 | default=None, 39 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 40 | ) 41 | tct: bool = field(default=False) 42 | kd: bool = field(default=False) 43 | 44 | # out projection 45 | combine_cls: bool = field(default=False) 46 | add_pooler: bool = field(default=False) 47 | projection_in_dim: int = field(default=768) 48 | projection_out_dim: int = field(default=768) 49 | 50 | # Dense 51 | pooling_method: str = field( 52 | default='cls', 53 | metadata={"help": "cls, average"} 54 | ) 55 | 56 | # dlr option 57 | dlr_out_dim: int = field(default=768) 58 | 59 | # agg option 60 | agg_dim: int = field(default=640) 61 | semi_aggregate: bool = field(default=False) 62 | skip_mlm: bool = field(default=False) 63 | 64 | 65 | # for Jax training 66 | dtype: Optional[str] = field( 67 | default="float32", 68 | metadata={ 69 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one " 70 | "of `[float32, float16, bfloat16]`. 
" 71 | }, 72 | ) 73 | 74 | 75 | @dataclass 76 | class ColBERTModelArguments: 77 | config_name: Optional[str] = field( 78 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 79 | ) 80 | tokenizer_name: Optional[str] = field( 81 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 82 | ) 83 | cache_dir: Optional[str] = field( 84 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 85 | ) 86 | 87 | # modeling 88 | model: str = field( 89 | default='ColBERT', 90 | metadata={"help": "ColBERT"} 91 | ) 92 | untie_encoder: bool = field( 93 | default=False, 94 | metadata={"help": "no weight sharing between qry passage encoders"} 95 | ) 96 | 97 | # out projection 98 | combine_cls: bool = field(default=False) 99 | add_pooler: bool = field(default=True) 100 | projection_in_dim: int = field(default=768) 101 | projection_out_dim: int = field(default=768) 102 | 103 | # for Jax training 104 | dtype: Optional[str] = field( 105 | default="float32", 106 | metadata={ 107 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one " 108 | "of `[float32, float16, bfloat16]`. " 109 | }, 110 | ) 111 | 112 | 113 | @dataclass 114 | class DataArguments: 115 | train_dir: str = field( 116 | default=None, metadata={"help": "Path to train directory"} 117 | ) 118 | corpus_dir: str = field( 119 | default=None, metadata={"help": "Path to corpus directory"} 120 | ) 121 | query_cluster_dir: str = field( 122 | default=None, metadata={"help": "Path to query cluster direcotry"} 123 | ) 124 | dataset_name: str = field( 125 | default=None, metadata={"help": "huggingface dataset name"} 126 | ) 127 | passage_field_separator: str = field(default=' ') 128 | dataset_proc_num: int = field( 129 | default=12, metadata={"help": "number of proc used in dataset preprocess"} 130 | ) 131 | train_n_passages: int = field(default=8) 132 | positive_passage_no_shuffle: bool = field( 133 | default=False, metadata={"help": "always use the first positive passage"}) 134 | negative_passage_no_shuffle: bool = field( 135 | default=False, metadata={"help": "always use the first negative passages"}) 136 | 137 | tasb_sampling: bool = field( 138 | default=False, metadata={"help": "use topic-aware balanced sampling"}) 139 | 140 | encode_in_path: List[str] = field(default=None, metadata={"help": "Path to data to encode"}) 141 | encoded_save_path: str = field(default=None, metadata={"help": "where to save the encode"}) 142 | encode_is_qry: bool = field(default=False) 143 | encode_num_shard: int = field(default=1) 144 | encode_shard_index: int = field(default=0) 145 | 146 | q_max_len: int = field( 147 | default=32, 148 | metadata={ 149 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer " 150 | "than this will be truncated, sequences shorter will be padded." 151 | }, 152 | ) 153 | p_max_len: int = field( 154 | default=128, 155 | metadata={ 156 | "help": "The maximum total input sequence length after tokenization for passage. Sequences longer " 157 | "than this will be truncated, sequences shorter will be padded." 
158 | }, 159 | ) 160 | data_cache_dir: Optional[str] = field( 161 | default=None, metadata={"help": "Where do you want to store the data downloaded from huggingface"} 162 | ) 163 | 164 | def __post_init__(self): 165 | if self.dataset_name is not None: 166 | info = self.dataset_name.split('/') 167 | self.dataset_split = info[-1] if len(info) == 3 else 'train' 168 | self.dataset_name = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info) 169 | self.dataset_language = 'default' 170 | if ':' in self.dataset_name: 171 | self.dataset_name, self.dataset_language = self.dataset_name.split(':') 172 | else: 173 | self.dataset_name = 'json' 174 | self.dataset_split = 'train' 175 | self.dataset_language = 'default' 176 | if self.train_dir is not None: 177 | files = sorted(os.listdir(self.train_dir)) 178 | self.train_path = [ 179 | os.path.join(self.train_dir, f) 180 | for f in files 181 | if f.endswith('jsonl') or f.endswith('json') 182 | ] 183 | else: 184 | self.train_path = None 185 | if self.corpus_dir is not None: 186 | files = sorted(os.listdir(self.corpus_dir)) 187 | self.corpus_path = [ 188 | os.path.join(self.corpus_dir, f) 189 | for f in files 190 | if f.endswith('jsonl') or f.endswith('json') 191 | ] 192 | else: 193 | self.corpus_path = None 194 | 195 | if self.query_cluster_dir is not None: 196 | files = sorted(os.listdir(self.query_cluster_dir)) 197 | self.query_cluster_path = [ 198 | os.path.join(self.query_cluster_dir, f) 199 | for f in files 200 | if f.endswith('jsonl') or f.endswith('json') 201 | ] 202 | else: 203 | self.query_cluster_path = None 204 | 205 | 206 | 207 | @dataclass 208 | class DenseTrainingArguments(TrainingArguments): 209 | warmup_ratio: float = field(default=0.1) 210 | negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"}) 211 | do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"}) 212 | ddp_find_unused_parameters: bool = field(default=True, metadata={"help": "set find unused parameters"}) 213 | grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"}) 214 | gc_q_chunk_size: int = field(default=4) 215 | gc_p_chunk_size: int = field(default=32) 216 | -------------------------------------------------------------------------------- /tevatron/data.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import List, Tuple 4 | 5 | from tqdm import tqdm 6 | import glob 7 | import os 8 | import json 9 | 10 | import datasets 11 | from torch.utils.data import Dataset 12 | from transformers import PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding 13 | import torch 14 | 15 | from .arguments import DataArguments 16 | from .trainer import DenseTrainer 17 | 18 | import logging 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class TrainDataset(Dataset): 23 | def __init__( 24 | self, 25 | data_args: DataArguments, 26 | dataset: datasets.Dataset, 27 | tokenizer: PreTrainedTokenizer, 28 | trainer: DenseTrainer = None, 29 | ): 30 | self.train_data = dataset 31 | self.tok = tokenizer 32 | self.trainer = trainer 33 | 34 | self.data_args = data_args 35 | self.total_len = len(self.train_data) 36 | 37 | def create_one_example(self, text_encoding: List[int], is_query=False): 38 | item = self.tok.encode_plus( 39 | text_encoding, 40 | truncation='only_first', 41 | max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len, 42 | padding=False, 43 | 
return_attention_mask=False, 44 | return_token_type_ids=False, 45 | ) 46 | return item 47 | 48 | def __len__(self): 49 | return self.total_len 50 | 51 | def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]: 52 | group = self.train_data[item] 53 | epoch = int(self.trainer.state.epoch) 54 | 55 | _hashed_seed = hash(item + self.trainer.args.seed) 56 | 57 | qry = group['query'] 58 | encoded_query = self.create_one_example(qry, is_query=True) 59 | 60 | encoded_passages = [] 61 | group_positives = group['positives'] 62 | group_negatives = group['negatives'] 63 | 64 | if self.data_args.positive_passage_no_shuffle: 65 | pos_psg = group_positives[0] 66 | else: 67 | pos_psg = group_positives[(_hashed_seed + epoch) % len(group_positives)] 68 | encoded_passages.append(self.create_one_example(pos_psg)) 69 | 70 | negative_size = self.data_args.train_n_passages - 1 71 | if len(group_negatives) < negative_size: 72 | negs = random.choices(group_negatives, k=negative_size) 73 | elif self.data_args.train_n_passages == 1: 74 | negs = [] 75 | elif self.data_args.negative_passage_no_shuffle: 76 | negs = group_negatives[:negative_size] 77 | else: 78 | _offset = epoch * negative_size % len(group_negatives) 79 | negs = [x for x in group_negatives] 80 | random.Random(_hashed_seed).shuffle(negs) 81 | negs = negs * 2 82 | negs = negs[_offset: _offset + negative_size] 83 | 84 | for neg_psg in negs: 85 | encoded_passages.append(self.create_one_example(neg_psg)) 86 | 87 | return encoded_query, encoded_passages 88 | 89 | class TrainTASBDataset(Dataset): 90 | # This is now only for msmarco-passage; since the id starts from 0. While using other datasets, this should be revised. 91 | def __init__( 92 | self, 93 | data_args: DataArguments, 94 | kd, 95 | dataset: datasets.Dataset, 96 | corpus: datasets.Dataset, 97 | tokenizer: PreTrainedTokenizer, 98 | trainer: DenseTrainer = None, 99 | ): 100 | self.train_data, self.qidx_cluster = dataset 101 | self.corpus = corpus 102 | self.tok = tokenizer 103 | self.trainer = trainer 104 | self.data_args = data_args 105 | self.tasb_sampling = data_args.tasb_sampling 106 | self.kd = kd 107 | 108 | if self.data_args.corpus_dir is None: 109 | raise ValueError('You should input --corpus_dir with files split*.json') 110 | 111 | # if (self.data_args.train_n_passages!=2) and (self.tasb_sampling): 112 | # raise ValueError('--train_n_passages should be 2 if you use tasb sampling') 113 | 114 | if (self.qidx_cluster is None) and (self.tasb_sampling): 115 | raise ValueError('You should input --query_cluster_dir for tasb sampling') 116 | 117 | self.data_args = data_args 118 | self.total_len = len(self.train_data) 119 | if self.qidx_cluster: 120 | self.cluster_num = len(self.qidx_cluster) 121 | 122 | 123 | def create_one_example(self, text_encoding: List[int], is_query=False): 124 | item = self.tok.encode_plus( 125 | text_encoding, 126 | truncation='only_first', 127 | max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len, 128 | padding=False, 129 | return_attention_mask=False, 130 | return_token_type_ids=False, 131 | ) 132 | return item 133 | 134 | def output_qp(self, group, _hashed_seed): 135 | epoch = int(self.trainer.state.epoch) 136 | qry = group['query'] 137 | encoded_query = self.create_one_example(qry, is_query=True) 138 | 139 | encoded_passages = [] 140 | group_positives = group['positive_pids'] 141 | group_negatives = group['negative_pids'] 142 | 143 | if self.data_args.positive_passage_no_shuffle: 144 | pos_psg_id = group_positives[0] 145 | else: 
146 | pos_psg_id = group_positives[(_hashed_seed + epoch) % len(group_positives)] 147 | pos_psg = self.corpus[int(pos_psg_id)]['text'] 148 | encoded_passages.append(self.create_one_example(pos_psg)) 149 | 150 | negative_size = self.data_args.train_n_passages - 1 151 | if len(group_negatives) < negative_size: 152 | negs = random.choices(group_negatives, k=negative_size) 153 | elif self.data_args.train_n_passages == 1: 154 | negs = [] 155 | elif self.data_args.negative_passage_no_shuffle: 156 | negs = group_negatives[:negative_size] 157 | else: 158 | _offset = epoch * negative_size % len(group_negatives) 159 | negs = [x for x in group_negatives] 160 | random.Random(_hashed_seed).shuffle(negs) 161 | negs = negs * 2 162 | negs = negs[_offset: _offset + negative_size] 163 | 164 | for neg_psg_pid in negs: 165 | neg_psg = self.corpus[int(neg_psg_pid)]['text'] 166 | encoded_passages.append(self.create_one_example(neg_psg)) 167 | 168 | return encoded_query, encoded_passages, None 169 | 170 | def output_qp_with_score(self, group, _hashed_seed): 171 | qry = group['query'] 172 | encoded_query = self.create_one_example(qry, is_query=True) 173 | 174 | encoded_passages = [] 175 | scores = [] 176 | qids_bin_pairs = group['bin_pairs'] 177 | bins_pairs = random.choices(qids_bin_pairs, k=1)[0] 178 | 179 | pairs = [] 180 | negative_size = self.data_args.train_n_passages - 1 181 | 182 | for i in range(negative_size): 183 | bin_pairs = random.choices(bins_pairs, k=1)[0] 184 | pairs.append(random.choices(bin_pairs, k=1)[0]) 185 | 186 | pos_psg_idx = int(pairs[0][0]) 187 | pos_psg_id = group['positive_pids'][pos_psg_idx] 188 | pos_psg = self.corpus[int(pos_psg_id)]['text'] 189 | encoded_passages.append(self.create_one_example(pos_psg)) 190 | 191 | for pair in pairs: 192 | neg_psg_idx = int(pair[1]) 193 | neg_psg_id = group['negative_pids'][neg_psg_idx] 194 | neg_psg = self.corpus[int(neg_psg_id)]['text'] 195 | encoded_passages.append(self.create_one_example(neg_psg)) 196 | scores.append(-pair[2]) 197 | 198 | return encoded_query, encoded_passages, scores 199 | 200 | def __len__(self): 201 | return self.total_len 202 | 203 | def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]: 204 | _hashed_seed = hash(item + self.trainer.args.seed) 205 | if self.tasb_sampling: 206 | # make sure the same query cluster gathered in the same batch 207 | random.seed(self.trainer.state.global_step) 208 | cluster_list = random.choices(self.qidx_cluster, k=24) 209 | 210 | #sampling different queries in a batch 211 | random.seed(_hashed_seed) 212 | cluster = random.choices(cluster_list, k=1)[0] 213 | item = random.choices(cluster['qidx'])[0] 214 | 215 | group = self.train_data[item] 216 | else: 217 | group = self.train_data[item] 218 | 219 | if self.kd: 220 | return self.output_qp_with_score(group, _hashed_seed) 221 | else: 222 | return self.output_qp(group, _hashed_seed) 223 | 224 | 225 | 226 | 227 | class EncodeDataset(Dataset): 228 | input_keys = ['text_id', 'text'] 229 | 230 | def __init__(self, dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer, max_len=128): 231 | self.encode_data = dataset 232 | self.tok = tokenizer 233 | self.max_len = max_len 234 | 235 | def __len__(self): 236 | return len(self.encode_data) 237 | 238 | def __getitem__(self, item) -> Tuple[str, BatchEncoding]: 239 | text_id, text = (self.encode_data[item][f] for f in self.input_keys) 240 | if len(text)==0: 241 | text = [0] 242 | encoded_text = self.tok.encode_plus( 243 | text, 244 | max_length=self.max_len, 245 | 
truncation='only_first', 246 | padding=False, 247 | return_token_type_ids=False, 248 | ) 249 | return text_id, encoded_text 250 | 251 | class EvalDataset(Dataset): 252 | input_keys = ['qry_text_id', 'qry_text', 'psg_text_id', 'psg_text', 'rel'] 253 | 254 | def __init__(self, 255 | data_args: DataArguments, 256 | dataset: datasets.Dataset, 257 | tokenizer: PreTrainedTokenizer): 258 | self.encode_data = dataset 259 | self.tok = tokenizer 260 | self.data_args = data_args 261 | 262 | def __len__(self): 263 | return len(self.encode_data) 264 | 265 | def __getitem__(self, item) -> Tuple[str, BatchEncoding]: 266 | qry_text_id, qry_text, psg_text_id, psg_text, rel = (self.encode_data[item][f] for f in self.input_keys) 267 | encoded_qry_text = self.tok.encode_plus( 268 | qry_text, 269 | max_length=self.data_args.q_max_len, 270 | truncation='only_first', 271 | padding=False, 272 | return_token_type_ids=False, 273 | ) 274 | if len(psg_text)==0: 275 | psg_text = [0] 276 | encoded_psg_text = self.tok.encode_plus( 277 | psg_text, 278 | max_length=self.data_args.p_max_len, 279 | truncation='only_first', 280 | padding=False, 281 | return_token_type_ids=False, 282 | ) 283 | return qry_text_id, encoded_qry_text, psg_text_id, encoded_psg_text, rel 284 | 285 | 286 | @dataclass 287 | class QPCollator(DataCollatorWithPadding): 288 | """ 289 | Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg] 290 | and pass batch separately to the actual collator. 291 | Abstract out data detail for the model. 292 | """ 293 | max_q_len: int = 32 294 | max_p_len: int = 128 295 | 296 | def __call__(self, features): 297 | qq = [f[0] for f in features] 298 | dd = [f[1] for f in features] 299 | 300 | if isinstance(qq[0], list): 301 | qq = sum(qq, []) 302 | if isinstance(dd[0], list): 303 | dd = sum(dd, []) 304 | 305 | q_collated = self.tokenizer.pad( 306 | qq, 307 | padding='max_length', 308 | max_length=self.max_q_len, 309 | return_tensors="pt", 310 | ) 311 | d_collated = self.tokenizer.pad( 312 | dd, 313 | padding='max_length', 314 | max_length=self.max_p_len, 315 | return_tensors="pt", 316 | ) 317 | 318 | if features[0][2] is not None: 319 | scores = [[0]+f[2] for f in features] 320 | scores_collated = torch.tensor(scores) 321 | else: 322 | scores_collated = None 323 | 324 | return q_collated, d_collated, scores_collated 325 | 326 | 327 | @dataclass 328 | class EncodeCollator(DataCollatorWithPadding): 329 | def __call__(self, features): 330 | text_ids = [x[0] for x in features] 331 | text_features = [x[1] for x in features] 332 | collated_features = super().__call__(text_features) 333 | return text_ids, collated_features 334 | 335 | @dataclass 336 | class EvalCollator(DataCollatorWithPadding): 337 | max_q_len: int = 32 338 | max_p_len: int = 128 339 | def __call__(self, features): 340 | qry_text_ids = [x[0] for x in features] 341 | qry_text_features = [x[1] for x in features] 342 | psg_text_ids = [x[2] for x in features] 343 | psg_text_features = [x[3] for x in features] 344 | rels = [x[4] for x in features] 345 | if isinstance(qry_text_features[0], list): 346 | qry_text_features = sum(qry_text_features, []) 347 | if isinstance(psg_text_features[0], list): 348 | psg_text_features = sum(psg_text_features, []) 349 | 350 | qry_collated_features = self.tokenizer.pad( 351 | qry_text_features, 352 | padding='max_length', 353 | max_length=self.max_q_len, 354 | return_tensors="pt", 355 | ) 356 | psg_collated_features = self.tokenizer.pad( 357 | psg_text_features, 358 | padding='max_length', 359 | 
max_length=self.max_p_len, 360 | return_tensors="pt", 361 | ) 362 | return qry_text_ids, qry_collated_features, psg_text_ids, psg_collated_features, rels -------------------------------------------------------------------------------- /tevatron/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import HFTrainDataset, HFQueryDataset, HFCorpusDataset, HFEvalDataset 2 | from .preprocessor import TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor 3 | -------------------------------------------------------------------------------- /tevatron/datasets/beir/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/datasets/beir/__init__.py -------------------------------------------------------------------------------- /tevatron/datasets/beir/encode_and_retrieval.py: -------------------------------------------------------------------------------- 1 | ################################################################################################################ 2 | # The evaluation code is revised from SPLADE repo: https://github.com/naver/splade/blob/main/src/beir_eval.py 3 | 4 | 5 | import argparse 6 | from .sentence_bert import Retriever, SentenceTransformerModel 7 | from transformers import AutoModelForMaskedLM, AutoTokenizer 8 | 9 | 10 | from ...arguments import ModelArguments 11 | 12 | 13 | from beir.datasets.data_loader import GenericDataLoader 14 | from beir.retrieval.evaluation import EvaluateRetrieval 15 | from beir import util, LoggingHandler 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--dataset", type=str, required=True) 20 | parser.add_argument("--model_name_or_path", type=str, required=True) 21 | parser.add_argument("--max_length", type=int, default=512) 22 | parser.add_argument("--model", type=str, default='dhr', help='dhr, agg, dense') 23 | parser.add_argument("--agg_dim", type=int, default=640, help='for agg model') 24 | parser.add_argument("--semi_aggregate", action='store_true', help='for agg model') 25 | parser.add_argument("--skip_mlm", action='store_true', help='for agg model') 26 | parser.add_argument("--pooling_method", type=str, default='cls', help='for dense model') 27 | args = parser.parse_args() 28 | 29 | 30 | model_type_or_dir = args.model_name_or_path 31 | model_args = ModelArguments 32 | model_args.model = args.model.lower() 33 | # agg method 34 | model_args.agg_dim = args.agg_dim 35 | model_args.semi_aggregate = args.semi_aggregate 36 | model_args.skip_mlm = args.skip_mlm 37 | model_args.pooling_method = args.pooling_method 38 | # loading model and tokenizer 39 | model = Retriever(model_type_or_dir, model_args) 40 | 41 | model.eval() 42 | tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir, use_fast=False) 43 | sentence_transformer = SentenceTransformerModel(model, tokenizer, args.max_length) 44 | 45 | 46 | dataset = args.dataset 47 | 48 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset) 49 | out_dir = "dataset/{}".format(dataset) 50 | data_path = util.download_and_unzip(url, out_dir) 51 | 52 | #### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader 53 | # data folder would contain these files: 54 | # (1) nfcorpus/corpus.jsonl (format: jsonlines) 55 | # (2) nfcorpus/queries.jsonl (format: jsonlines) 56 | # (3) 
nfcorpus/qrels/test.tsv (format: tsv ("\t")) 57 | 58 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test") 59 | 60 | from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES 61 | from beir.retrieval.evaluation import EvaluateRetrieval 62 | 63 | dres = DRES(sentence_transformer) 64 | retriever = EvaluateRetrieval(dres, score_function="dot") 65 | results = retriever.retrieve(corpus, queries) 66 | ndcg, map_, recall, p = EvaluateRetrieval.evaluate(qrels, results, [1, 10, 100, 1000]) 67 | results2 = EvaluateRetrieval.evaluate_custom(qrels, results, [1, 10, 100, 1000], metric="r_cap") 68 | res = {"NDCG@10": ndcg["NDCG@10"], 69 | "Recall@100": recall["Recall@100"], 70 | "R_cap@100": results2["R_cap@100"]} 71 | print("res for {}:".format(dataset), res, flush=True) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() -------------------------------------------------------------------------------- /tevatron/datasets/beir/preprocess.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) 3 | import argparse 4 | import pathlib, os 5 | from beir import util, LoggingHandler 6 | from beir.datasets.data_loader import GenericDataLoader 7 | logger = logging.getLogger(__name__) 8 | from ...utils.data_reader import create_dir 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--output_dir", required=False, default='./dataset', type=str) 14 | parser.add_argument("--dataset", required=True, type=str, help="beir dataset name") 15 | parser.add_argument("--split", default='test', type=str, help="beir dataset name") 16 | args = parser.parse_args() 17 | 18 | #### Download scifact.zip dataset and unzip the dataset 19 | create_dir(os.path.join('./download')) 20 | create_dir(os.path.join(args.output_dir)) 21 | dataset = args.dataset 22 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset) 23 | data_path = util.download_and_unzip(url, './download') 24 | 25 | #### Provide the data_path where scifact has been downloaded and unzipped 26 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=args.split) 27 | 28 | create_dir(os.path.join(args.output_dir, args.dataset, 'corpus')) 29 | os.rename(os.path.join('./download', args.dataset, 'corpus.jsonl'), os.path.join(args.output_dir, args.dataset, 'corpus', 'collection.json')) 30 | 31 | create_dir(os.path.join(args.output_dir, args.dataset,'qrels')) 32 | qrel_fout = open(os.path.join(args.output_dir, args.dataset,'qrels', 'qrels.' + args.split + '.tsv'), 'w') 33 | 34 | create_dir(os.path.join(args.output_dir, args.dataset,'queries')) 35 | query_fout = open(os.path.join(args.output_dir, args.dataset, 'queries', 'queries.' 
+ args.split + '.tsv'), 'w') 36 | 37 | for qid, answer in qrels.items(): 38 | for docid, rel in answer.items(): 39 | qrel_fout.write('{}\tQ0\t{}\t{}\n'.format(qid, docid, rel)) 40 | query_fout.write('{}\t{}\n'.format(qid, queries[qid])) 41 | 42 | qrel_fout.close() 43 | query_fout.close() 44 | 45 | if __name__ == "__main__": 46 | main() -------------------------------------------------------------------------------- /tevatron/datasets/beir/sentence_bert.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict, Union 3 | 4 | import numpy as np 5 | import torch 6 | from numpy import ndarray 7 | from torch import Tensor 8 | from tqdm.autonotebook import trange 9 | from transformers import AutoModelForMaskedLM 10 | 11 | 12 | try: 13 | import sentence_transformers 14 | from sentence_transformers.util import batch_to_device 15 | except ImportError: 16 | print("Import Error: could not load sentence_transformers... proceeding") 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class SentenceTransformerModel: 21 | def __init__(self, model, tokenizer, max_length=512): 22 | self.max_length = max_length 23 | self.tokenizer = tokenizer 24 | self.model = model 25 | self.sep = ' ' 26 | 27 | # Write your own encoding query function (Returns: Query embeddings as numpy array) 28 | def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray: 29 | X = self.model.encode_sentence_bert(self.tokenizer, queries, is_q=True, maxlen=self.max_length) 30 | return X 31 | 32 | # Write your own encoding corpus function (Returns: Document embeddings as numpy array) 33 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs) -> np.ndarray: 34 | sentences = [(doc["title"] + self.sep + doc["text"]).strip() for doc in corpus] 35 | return self.model.encode_sentence_bert(self.tokenizer, sentences, maxlen=self.max_length) 36 | 37 | 38 | 39 | class Retriever(torch.nn.Module): 40 | 41 | def __init__(self, model_type_or_dir, model_args): 42 | super().__init__() 43 | self.model_args = model_args 44 | if self.model_args.model.lower() == 'dhr': 45 | from ...DHR.modeling import DHRModelForInference 46 | from ...DHR.modeling import DHROutput as output 47 | self.transformer = DHRModelForInference.build(model_name_or_path=model_type_or_dir, model_args=model_args) 48 | elif self.model_args.model.lower() == 'agg': 49 | from ...Aggretriever.modeling import DenseModelForInference 50 | from ...Aggretriever.modeling import DenseOutput as output 51 | self.transformer = DenseModelForInference.build(model_name_or_path=model_type_or_dir, model_args=model_args) 52 | elif self.model_args.model.lower() == 'dense': 53 | from ...Dense.modeling import DenseModelForInference 54 | from ...Dense.modeling import DenseOutput as Output 55 | self.transformer = DenseModelForInference.build(model_name_or_path=model_type_or_dir, model_args=model_args) 56 | else: 57 | raise ValueError('--rep_type can only be dhr or dense (CLS) or agg.') 58 | def forward(self, features, is_q): 59 | if is_q: 60 | if self.model_args.model== 'dhr': 61 | out = self.transformer(query=features) 62 | return [out.q_lexical_reps, out.q_semantic_reps] 63 | if self.model_args.model == 'agg': 64 | out = self.transformer(query=features) 65 | return out.q_reps 66 | elif self.model_args.model == 'dense': 67 | out = self.transformer(query=features) 68 | return out.q_reps 69 | else: 70 | if self.model_args.model == 'dhr': 71 | out = 
self.transformer(passage=features) 72 | return [out.p_lexical_reps, out.p_semantic_reps] 73 | if self.model_args.model == 'agg': 74 | out = self.transformer(passage=features) 75 | return out.p_reps 76 | elif self.model_args.model == 'dense': 77 | out = self.transformer(passage=features) 78 | return out.p_reps 79 | 80 | def _text_length(self, text: Union[List[int], List[List[int]]]): 81 | """helper function to get the length for the input text. Text can be either 82 | a list of ints (which means a single text as input), or a tuple of list of ints 83 | (representing several text inputs to the model). 84 | """ 85 | 86 | if isinstance(text, dict): # {key: value} case 87 | return len(next(iter(text.values()))) 88 | elif not hasattr(text, '__len__'): # Object has no len() method 89 | return 1 90 | elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints 91 | return len(text) 92 | else: 93 | return sum([len(t) for t in text]) # Sum of length of individual strings 94 | 95 | def encode_sentence_bert(self, tokenizer, sentences: Union[str, List[str], List[int]], 96 | batch_size: int = 32, 97 | show_progress_bar: bool = None, 98 | output_value: str = 'dhr_embeddings', 99 | convert_to_numpy: bool = True, 100 | convert_to_tensor: bool = False, 101 | device: str = None, 102 | normalize_embeddings: bool = False, 103 | maxlen: int = 512, 104 | is_q: bool = False) -> Union[List[Tensor], ndarray, Tensor]: 105 | """ 106 | Computes sentence embeddings 107 | :param sentences: the sentences to embed 108 | :param batch_size: the batch size used for the computation 109 | :param show_progress_bar: Output a progress bar when encode sentences 110 | :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. 111 | :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors. 112 | :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy 113 | :param device: Which torch.device to use for the computation 114 | :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used. 115 | :return: 116 | By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned. 
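Example (illustrative only; the query text and output dimensionality are assumptions that depend on the loaded checkpoint): for a DHR model, embs = model.encode_sentence_bert(tokenizer, ['what is a cat?'], is_q=True) returns a numpy array of shape (1, lexical_dim + semantic_dim), since convert_to_numpy defaults to True.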
117 | """ 118 | if self.model_args.model == 'dense': 119 | output_value = 'sentence_embeddings' 120 | elif self.model_args.model == 'agg': 121 | output_value = 'sentence_embeddings' 122 | else: 123 | output_value = 'dhr_embeddings' 124 | 125 | 126 | 127 | self.eval() 128 | if show_progress_bar is None: 129 | show_progress_bar = True 130 | 131 | if convert_to_tensor: 132 | convert_to_numpy = False 133 | 134 | if output_value == 'token_embeddings': 135 | convert_to_tensor = False 136 | convert_to_numpy = False 137 | 138 | input_was_string = False 139 | if isinstance(sentences, str) or not hasattr(sentences, '__len__'): 140 | # Cast an individual sentence to a list with length 1 141 | sentences = [sentences] 142 | input_was_string = True 143 | 144 | if device is None: 145 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 146 | 147 | self.to(device) 148 | 149 | all_embeddings = [] 150 | all_semantic_embeddings = [] 151 | all_lexical_embeddings = [] 152 | length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) 153 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 154 | 155 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): 156 | sentences_batch = sentences_sorted[start_index:start_index + batch_size] 157 | # features = tokenizer(sentences_batch) 158 | # print(sentences_batch) 159 | features = tokenizer(sentences_batch, 160 | add_special_tokens=True, 161 | padding="longest", # pad to max sequence length in batch 162 | truncation="only_first", # truncates to self.max_length 163 | max_length=maxlen, 164 | return_attention_mask=True, 165 | return_tensors="pt") 166 | # print(features) 167 | features = batch_to_device(features, device) 168 | 169 | with torch.no_grad(): 170 | out_features = self.forward(features, is_q) 171 | if output_value == 'dhr_embeddings': 172 | lexical_embeddings = out_features[0].detach() 173 | try: 174 | semantic_embeddings = out_features[1].detach() 175 | semantic_dim = semantic_embeddings.shape[1] 176 | except: 177 | semantic_dim = 0 178 | if convert_to_numpy: 179 | lexical_embeddings = lexical_embeddings.cpu() 180 | try: 181 | semantic_embeddings = semantic_embeddings.cpu() 182 | except: 183 | semantic_dim = 0 184 | 185 | embeddings = torch.zeros((lexical_embeddings.shape[0], lexical_embeddings.shape[1] + semantic_dim)) 186 | embeddings[:,:lexical_embeddings.shape[1]] = lexical_embeddings 187 | if semantic_dim != 0: 188 | embeddings[:,lexical_embeddings.shape[1]:] = semantic_embeddings 189 | 190 | else: 191 | if output_value == 'token_embeddings': 192 | embeddings = [] 193 | for token_emb, attention in zip(out_features[output_value], out_features['attention_mask']): 194 | last_mask_id = len(attention) - 1 195 | while last_mask_id > 0 and attention[last_mask_id].item() == 0: 196 | last_mask_id -= 1 197 | embeddings.append(token_emb[0:last_mask_id + 1]) 198 | elif output_value == 'sentence_embeddings': 199 | # embeddings = out_features[output_value] 200 | embeddings = out_features 201 | embeddings = embeddings.detach() 202 | if normalize_embeddings: 203 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) 204 | # fixes for #522 and #487 to avoid oom problems on gpu with large datasets 205 | if convert_to_numpy: 206 | embeddings = embeddings.cpu() 207 | 208 | all_embeddings.extend(embeddings) 209 | 210 | 211 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] 212 | if convert_to_tensor: 213 | 
all_embeddings = torch.stack(all_embeddings) 214 | elif convert_to_numpy: 215 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 216 | if input_was_string: 217 | all_embeddings = all_embeddings[0] 218 | return all_embeddings 219 | 220 | -------------------------------------------------------------------------------- /tevatron/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import PreTrainedTokenizer 3 | from .preprocessor import TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor, EvalPreProcessor 4 | from ..arguments import DataArguments 5 | 6 | DEFAULT_PROCESSORS = [TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor, EvalPreProcessor] 7 | PROCESSOR_INFO = { 8 | 'Tevatron/wikipedia-nq': DEFAULT_PROCESSORS, 9 | 'Tevatron/wikipedia-trivia': DEFAULT_PROCESSORS, 10 | 'Tevatron/wikipedia-curated': DEFAULT_PROCESSORS, 11 | 'Tevatron/wikipedia-wq': DEFAULT_PROCESSORS, 12 | 'Tevatron/wikipedia-squad': DEFAULT_PROCESSORS, 13 | 'Tevatron/scifact': DEFAULT_PROCESSORS, 14 | 'Tevatron/msmarco-passage': DEFAULT_PROCESSORS, 15 | 'json': [None, None, None, None] 16 | } 17 | 18 | 19 | class HFTrainDataset: 20 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str): 21 | data_files = data_args.train_path 22 | if data_files: 23 | data_files = {data_args.dataset_split: data_files} 24 | 25 | self.dataset = load_dataset(data_args.dataset_name, 26 | data_args.dataset_language, 27 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split] 28 | 29 | if data_args.query_cluster_path is not None: 30 | data_files = {data_args.dataset_split: data_args.query_cluster_path} 31 | self.qidx_cluster = load_dataset(data_args.dataset_name, 32 | data_args.dataset_language, 33 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split] 34 | else: 35 | self.qidx_cluster = None 36 | 37 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][0] if data_args.dataset_name in PROCESSOR_INFO\ 38 | else DEFAULT_PROCESSORS[0] 39 | self.tokenizer = tokenizer 40 | self.q_max_len = data_args.q_max_len 41 | self.p_max_len = data_args.p_max_len 42 | self.proc_num = data_args.dataset_proc_num 43 | self.neg_num = data_args.train_n_passages - 1 44 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator) 45 | 46 | def process(self, shard_num=1, shard_idx=0): 47 | self.dataset = self.dataset.shard(shard_num, shard_idx) 48 | if self.preprocessor is not None: 49 | self.dataset = self.dataset.map( 50 | self.preprocessor(self.tokenizer, self.q_max_len, self.p_max_len, self.separator), 51 | batched=False, 52 | num_proc=self.proc_num, 53 | remove_columns=self.dataset.column_names, 54 | desc="Running tokenizer on train dataset", 55 | ) 56 | return self.dataset, self.qidx_cluster 57 | 58 | 59 | class HFQueryDataset: 60 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str): 61 | data_files = data_args.encode_in_path 62 | if data_files: 63 | data_files = {data_args.dataset_split: data_files} 64 | self.dataset = load_dataset(data_args.dataset_name, 65 | data_args.dataset_language, 66 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split] 67 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][1] if data_args.dataset_name in PROCESSOR_INFO \ 68 | else DEFAULT_PROCESSORS[1] 69 | self.tokenizer = tokenizer 70 | self.q_max_len = 
data_args.q_max_len 71 | self.proc_num = data_args.dataset_proc_num 72 | 73 | def process(self, shard_num=1, shard_idx=0): 74 | self.dataset = self.dataset.shard(shard_num, shard_idx) 75 | if self.preprocessor is not None: 76 | self.dataset = self.dataset.map( 77 | self.preprocessor(self.tokenizer, self.q_max_len), 78 | batched=False, 79 | num_proc=self.proc_num, 80 | remove_columns=self.dataset.column_names, 81 | desc="Running tokenization", 82 | ) 83 | return self.dataset 84 | 85 | 86 | class HFCorpusDataset: 87 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str): 88 | if data_args.encode_in_path is not None: 89 | data_files = data_args.encode_in_path 90 | if data_args.corpus_path is not None: 91 | data_files = data_args.corpus_path 92 | if data_files: 93 | data_files = {data_args.dataset_split: data_files} 94 | self.dataset = load_dataset(data_args.dataset_name, 95 | data_args.dataset_language, 96 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split] 97 | script_prefix = data_args.dataset_name 98 | if script_prefix.endswith('-corpus'): 99 | script_prefix = script_prefix[:-7] 100 | self.preprocessor = PROCESSOR_INFO[script_prefix][2] \ 101 | if script_prefix in PROCESSOR_INFO else DEFAULT_PROCESSORS[2] 102 | self.tokenizer = tokenizer 103 | self.p_max_len = data_args.p_max_len 104 | self.proc_num = data_args.dataset_proc_num 105 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator) 106 | 107 | def process(self, shard_num=1, shard_idx=0): 108 | self.dataset = self.dataset.shard(shard_num, shard_idx) 109 | if self.preprocessor is not None: 110 | self.dataset = self.dataset.map( 111 | self.preprocessor(self.tokenizer, self.p_max_len, self.separator), 112 | batched=False, 113 | num_proc=self.proc_num, 114 | remove_columns=self.dataset.column_names, 115 | desc="Running tokenization", 116 | ) 117 | return self.dataset 118 | 119 | class HFEvalDataset: 120 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str): 121 | data_files = data_args.encode_in_path 122 | if data_files: 123 | data_files = {data_args.dataset_split: data_files} 124 | self.dataset = load_dataset(data_args.dataset_name, 125 | data_args.dataset_language, 126 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split] 127 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][3] if data_args.dataset_name in PROCESSOR_INFO \ 128 | else DEFAULT_PROCESSORS[3] 129 | self.tokenizer = tokenizer 130 | self.q_max_len = data_args.q_max_len 131 | self.p_max_len = data_args.p_max_len 132 | self.proc_num = data_args.dataset_proc_num 133 | 134 | def process(self, shard_num=1, shard_idx=0): 135 | self.dataset = self.dataset.shard(shard_num, shard_idx) 136 | if self.preprocessor is not None: 137 | self.dataset = self.dataset.map( 138 | self.preprocessor(self.tokenizer, self.q_max_len, self.p_max_len), 139 | batched=False, 140 | num_proc=self.proc_num, 141 | remove_columns=self.dataset.column_names, 142 | desc="Running tokenization", 143 | ) 144 | return self.dataset -------------------------------------------------------------------------------- /tevatron/datasets/preprocessor.py: -------------------------------------------------------------------------------- 1 | class TrainPreProcessor: 2 | def __init__(self, tokenizer, query_max_length=32, text_max_length=256, separator=' '): 3 | self.tokenizer = tokenizer 4 | self.query_max_length = query_max_length 5 | 
self.text_max_length = text_max_length 6 | self.separator = separator 7 | 8 | def __call__(self, example): 9 | query = self.tokenizer.encode(example['query'], 10 | add_special_tokens=False, 11 | max_length=self.query_max_length, 12 | truncation=True) 13 | positives = [] 14 | for pos in example['positive_passages']: 15 | text = pos['title'] + self.separator + pos['text'] if 'title' in pos else pos['text'] 16 | positives.append(self.tokenizer.encode(text, 17 | add_special_tokens=False, 18 | max_length=self.text_max_length, 19 | truncation=True)) 20 | negatives = [] 21 | for neg in example['negative_passages']: 22 | text = neg['title'] + self.separator + neg['text'] if 'title' in neg else neg['text'] 23 | negatives.append(self.tokenizer.encode(text, 24 | add_special_tokens=False, 25 | max_length=self.text_max_length, 26 | truncation=True)) 27 | return {'query': query, 'positives': positives, 'negatives': negatives} 28 | 29 | 30 | class QueryPreProcessor: 31 | def __init__(self, tokenizer, query_max_length=32): 32 | self.tokenizer = tokenizer 33 | self.query_max_length = query_max_length 34 | 35 | def __call__(self, example): 36 | query_id = example['query_id'] 37 | query = self.tokenizer.encode(example['query'], 38 | add_special_tokens=False, 39 | max_length=self.query_max_length, 40 | truncation=True) 41 | return {'text_id': query_id, 'text': query} 42 | 43 | 44 | class CorpusPreProcessor: 45 | def __init__(self, tokenizer, text_max_length=256, separator=' '): 46 | self.tokenizer = tokenizer 47 | self.text_max_length = text_max_length 48 | self.separator = separator 49 | 50 | def __call__(self, example): 51 | docid = example['docid'] 52 | text = example['title'] + self.separator + example['text'] if 'title' in example else example['text'] 53 | text = self.tokenizer.encode(text, 54 | add_special_tokens=False, 55 | max_length=self.text_max_length, 56 | truncation=True) 57 | return {'text_id': docid, 'text': text} 58 | 59 | class EvalPreProcessor: 60 | def __init__(self, tokenizer, qry_max_length=32, psg_max_length=256, separator=' '): 61 | self.tokenizer = tokenizer 62 | self.qry_max_length = qry_max_length 63 | self.psg_max_length = psg_max_length 64 | self.separator = separator 65 | 66 | def __call__(self, example): 67 | docid = example['docid'] 68 | qry_text = example['qry_text'] 69 | qry_text = self.tokenizer.encode(qry_text, 70 | add_special_tokens=False, 71 | max_length=self.qry_max_length, 72 | truncation=True) 73 | psg_text = example['title'] + self.separator + example['psg_text'] if 'title' in example else example['psg_text'] 74 | psg_text = self.tokenizer.encode(psg_text, 75 | add_special_tokens=False, 76 | max_length=self.psg_max_length, 77 | truncation=True) 78 | return {'text_id': docid, 'qry_text': qry_text, 'psg_text': psg_text} 79 | -------------------------------------------------------------------------------- /tevatron/driver/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/driver/__init__.py -------------------------------------------------------------------------------- /tevatron/driver/encode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import sys 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import torch 11 | 12 | from torch.utils.data import DataLoader 13 | from transformers import 
AutoConfig, AutoTokenizer 14 | from transformers import ( 15 | HfArgumentParser, 16 | ) 17 | 18 | from tevatron.arguments import ModelArguments, DataArguments, \ 19 | DenseTrainingArguments as TrainingArguments 20 | from tevatron.data import EncodeDataset, EncodeCollator 21 | from tevatron.datasets import HFQueryDataset, HFCorpusDataset 22 | from tevatron.DHR.utils import densify 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def main(): 28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 31 | else: 32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 33 | model_args: ModelArguments 34 | data_args: DataArguments 35 | training_args: TrainingArguments 36 | 37 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 38 | raise NotImplementedError('Multi-GPU encoding is not supported.') 39 | 40 | # Setup logging 41 | logging.basicConfig( 42 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 43 | datefmt="%m/%d/%Y %H:%M:%S", 44 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 45 | ) 46 | 47 | num_labels = 1 48 | config = AutoConfig.from_pretrained( 49 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 50 | num_labels=num_labels, 51 | output_hidden_states=True, 52 | cache_dir=model_args.cache_dir, 53 | ) 54 | tokenizer = AutoTokenizer.from_pretrained( 55 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 56 | cache_dir=model_args.cache_dir, 57 | use_fast=False, 58 | ) 59 | 60 | if (model_args.model).lower() == 'dhr': 61 | from tevatron.DHR.modeling import DHRModelForInference 62 | from tevatron.DHR.modeling import DHROutput as Output 63 | logger.info("Encoding model DHR") 64 | model = DHRModelForInference.build( 65 | model_args=model_args, 66 | config=config, 67 | cache_dir=model_args.cache_dir, 68 | ) 69 | elif (model_args.model).lower() == 'dlr': 70 | from tevatron.DHR.modeling import DHRModelForInference 71 | from tevatron.DHR.modeling import DHROutput as Output 72 | logger.info("Encoding model DLR") 73 | model_args.combine_cls = False 74 | model = DHRModelForInference.build( 75 | model_args=model_args, 76 | config=config, 77 | cache_dir=model_args.cache_dir, 78 | ) 79 | elif (model_args.model).lower() == 'agg': 80 | from tevatron.Aggretriever.modeling import DenseModelForInference 81 | from tevatron.Aggretriever.modeling import DenseOutput as Output 82 | logger.info("Encoding model Dense (AGG)") 83 | model = DenseModelForInference.build( 84 | model_args=model_args, 85 | config=config, 86 | cache_dir=model_args.cache_dir, 87 | ) 88 | elif (model_args.model).lower() == 'dense': 89 | from tevatron.Dense.modeling import DenseModelForInference 90 | from tevatron.Dense.modeling import DenseOutput as Output 91 | logger.info("Encoding model Dense (CLS)") 92 | model = DenseModelForInference.build( 93 | model_args=model_args, 94 | config=config, 95 | cache_dir=model_args.cache_dir, 96 | ) 97 | else: 98 | raise ValueError('--model must be one of: dhr, dlr, agg, dense') 99 | 100 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len 101 | if data_args.encode_is_qry: 102 | encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args, 103 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 104 | else:
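# corpus branch: passages come from HFCorpusDataset and are truncated to p_max_len
# (the query branch above uses q_max_len, via text_max_length computed earlier)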
105 | encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args, 106 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 107 | encode_dataset = EncodeDataset(encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index), 108 | tokenizer, max_len=text_max_length) 109 | 110 | encode_loader = DataLoader( 111 | encode_dataset, 112 | batch_size=training_args.per_device_eval_batch_size, 113 | collate_fn=EncodeCollator( 114 | tokenizer, 115 | max_length=text_max_length, 116 | padding='max_length' 117 | ), 118 | shuffle=False, 119 | drop_last=False, 120 | num_workers=training_args.dataloader_num_workers, 121 | ) 122 | 123 | 124 | 125 | def initialize_reps(data_num, dim, dtype): 126 | return np.zeros((data_num, dim), dtype=dtype) 127 | 128 | 129 | offset = 0 130 | lookup_indices = [] 131 | model = model.to(training_args.device) 132 | model.eval() 133 | 134 | data_num = len(encode_dataset) 135 | value_encoded, index_encoded = None, None 136 | 137 | for (batch_ids, batch) in tqdm(encode_loader): 138 | batch_size = len(batch_ids) 139 | lookup_indices.extend(batch_ids) 140 | with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext(): 141 | with torch.no_grad(): 142 | for k, v in batch.items(): 143 | batch[k] = v.to(training_args.device) 144 | 145 | if data_args.encode_is_qry: 146 | 147 | model_output: Output = model(query=batch) 148 | 149 | if (model_args.model).lower() == 'dense' or (model_args.model).lower() == 'agg': 150 | reps = model_output.q_reps.cpu().detach().numpy() 151 | if value_encoded is None: 152 | value_encoded = initialize_reps(data_num, reps.shape[1], np.float16) 153 | value_encoded[offset: (offset + batch_size), :] = reps 154 | else: 155 | dlr_value_reps, dlr_index_reps = densify(model_output.q_lexical_reps, model_args.dlr_out_dim) 156 | dlr_value_reps = dlr_value_reps.cpu().detach().numpy() 157 | dlr_index_reps = dlr_index_reps.cpu().detach().numpy().astype(np.uint8) 158 | cls_reps = model_output.q_semantic_reps.cpu().detach().numpy() 159 | 160 | if value_encoded is None: 161 | if cls_reps is None: 162 | cls_dim = 0 163 | else: 164 | cls_dim = cls_reps.shape[1] 165 | value_encoded = initialize_reps(data_num, dlr_value_reps.shape[1] + cls_dim, np.float16) 166 | index_encoded = initialize_reps(data_num, dlr_index_reps.shape[1], np.uint8) 167 | value_encoded[offset: (offset + batch_size), :model_args.dlr_out_dim] = dlr_value_reps 168 | index_encoded[offset: (offset + batch_size), :model_args.dlr_out_dim] = dlr_index_reps 169 | if cls_reps is not None: 170 | value_encoded[offset: (offset + batch_size), model_args.dlr_out_dim:] = cls_reps 171 | 172 | else: 173 | model_output: Output = model(passage=batch) 174 | if (model_args.model).lower() == 'dense' or (model_args.model).lower() == 'agg': 175 | reps = model_output.p_reps.cpu().detach().numpy() 176 | if value_encoded is None: 177 | value_encoded = initialize_reps(data_num, reps.shape[1], np.float16) 178 | value_encoded[offset: (offset + batch_size), :] = reps 179 | else: 180 | dlr_value_reps, dlr_index_reps = densify(model_output.p_lexical_reps, model_args.dlr_out_dim) 181 | dlr_value_reps = dlr_value_reps.cpu().detach().numpy() 182 | dlr_index_reps = dlr_index_reps.cpu().detach().numpy().astype(np.uint8) 183 | cls_reps = model_output.p_semantic_reps.cpu().detach().numpy() 184 | 185 | if value_encoded is None: 186 | if cls_reps is None: 187 | cls_dim = 0 188 | else: 189 | cls_dim = cls_reps.shape[1] 190 | value_encoded = initialize_reps(data_num, dlr_value_reps.shape[1] + cls_dim, 
np.float16) 191 | index_encoded = initialize_reps(data_num, dlr_index_reps.shape[1], np.uint8) 192 | value_encoded[offset: (offset + batch_size), :model_args.dlr_out_dim] = dlr_value_reps 193 | index_encoded[offset: (offset + batch_size), :model_args.dlr_out_dim] = dlr_index_reps 194 | if cls_reps is not None: 195 | value_encoded[offset: (offset + batch_size), model_args.dlr_out_dim:] = cls_reps 196 | 197 | offset += batch_size 198 | 199 | output_dir = '/'.join( (data_args.encoded_save_path).split('/')[:-1] ) 200 | if not os.path.exists(output_dir): 201 | logger.info(f'{output_dir} does not exist, creating it') 202 | os.makedirs(output_dir, exist_ok=True) 203 | with open(data_args.encoded_save_path, 'wb') as f: 204 | pickle.dump([value_encoded, index_encoded, lookup_indices], f, protocol=4) 205 | 206 | 207 | if __name__ == "__main__": 208 | main() 209 | -------------------------------------------------------------------------------- /tevatron/driver/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import sys 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import torch 11 | 12 | from torch.utils.data import DataLoader 13 | from transformers import AutoConfig, AutoTokenizer 14 | from transformers import ( 15 | HfArgumentParser, 16 | ) 17 | 18 | from tevatron.arguments import ModelArguments, DataArguments, \ 19 | DenseTrainingArguments as TrainingArguments 20 | from tevatron.data import EvalDataset, EvalCollator 21 | from tevatron.datasets import HFEvalDataset 22 | from tevatron.utils import metrics 23 | METRICS_MAP = ['MAP', 'RPrec', 'NDCG', 'MRR', 'MRR@10'] 24 | # from tevatron.densification.utils import densify 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | def main(): 30 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 31 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 32 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 33 | else: 34 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 35 | model_args: ModelArguments 36 | data_args: DataArguments 37 | training_args: TrainingArguments 38 | 39 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 40 | raise NotImplementedError('Multi-GPU encoding is not supported.') 41 | 42 | # Setup logging 43 | logging.basicConfig( 44 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 45 | datefmt="%m/%d/%Y %H:%M:%S", 46 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 47 | ) 48 | 49 | num_labels = 1 50 | config = AutoConfig.from_pretrained( 51 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 52 | num_labels=num_labels, 53 | output_hidden_states=True, 54 | cache_dir=model_args.cache_dir, 55 | ) 56 | tokenizer = AutoTokenizer.from_pretrained( 57 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 58 | cache_dir=model_args.cache_dir, 59 | use_fast=False, 60 | ) 61 | 62 | if (model_args.model).lower() == 'colbert': 63 | from tevatron.ColBERT.modeling import ColBERTForInference 64 | from tevatron.ColBERT.modeling import ColBERTOutput as Output 65 | logger.info("Evaluating model ColBERT") 66 | model = ColBERTForInference.build( 67 | model_args=model_args, 68 | config=config, 69 | cache_dir=model_args.cache_dir, 70 | ) 71 | elif (model_args.model).lower() == 'dhr': 72 | from tevatron.DHR.modeling import DHRModelForInference 73 | from tevatron.DHR.modeling import DHROutput as Output 74 | logger.info("Evaluating model DHR") 75 | model = DHRModelForInference.build( 76 | model_args=model_args, 77 | config=config, 78 | cache_dir=model_args.cache_dir, 79 | ) 80 | elif (model_args.model).lower() == 'dlr': 81 | from tevatron.DHR.modeling import DHRModelForInference 82 | from tevatron.DHR.modeling import DHROutput as Output 83 | logger.info("Evaluating model DLR") 84 | model_args.combine_cls = False 85 | model = DHRModelForInference.build( 86 | model_args=model_args, 87 | config=config, 88 | cache_dir=model_args.cache_dir, 89 | ) 90 | elif (model_args.model).lower() == 'agg': 91 | from tevatron.Aggretriever.modeling import DenseModelForInference 92 | from tevatron.Aggretriever.modeling import DenseOutput as Output 93 | logger.info("Evaluating model Dense (AGG)") 94 | model = DenseModelForInference.build( 95 | model_args=model_args, 96 | config=config, 97 | cache_dir=model_args.cache_dir, 98 | ) 99 | elif (model_args.model).lower() == 'dense': 100 | from tevatron.Dense.modeling import DenseModelForInference 101 | from tevatron.Dense.modeling import DenseOutput as Output 102 | logger.info("Evaluating model Dense (CLS)") 103 | model = DenseModelForInference.build( 104 | model_args=model_args, 105 | config=config, 106 | cache_dir=model_args.cache_dir, 107 | ) 108 | else: 109 | raise ValueError('--model must be one of: colbert, dhr, dlr, agg, dense') 110 | 111 | eval_dataset = HFEvalDataset(tokenizer=tokenizer, data_args=data_args, 112 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 113 | eval_dataset = EvalDataset(data_args, eval_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index), 114 | tokenizer) 115 | 116 | eval_loader = DataLoader( 117 | eval_dataset, 118 | batch_size=training_args.per_device_eval_batch_size, 119 | collate_fn=EvalCollator( 120 | tokenizer, 121 | max_p_len=data_args.p_max_len, 122 | max_q_len=data_args.q_max_len, 123 | padding='max_length' 124 | ), 125 | shuffle=False, 126 | drop_last=False, 127 | num_workers=training_args.dataloader_num_workers, 128 | ) 129 | 130 | model = model.to(training_args.device) 131 | model.eval() 132 | 133 | num_candidates_per_qry = 1000 134 | if num_candidates_per_qry % training_args.per_device_eval_batch_size != 0: 135 | raise ValueError('Batch size should be a factor of {}'.format(num_candidates_per_qry)) 136 | all_metrics = np.zeros(len(METRICS_MAP)) 137 | num_examples = 0 138 | qids = [] 139 | candidate_psg_ids = [] 140 | scores = [] 141 | labels = [] 142 | for (batch_qry_ids, batch_qry_features, batch_psg_ids, batch_psg_features, rels) in tqdm(eval_loader): 143 | if len(set(batch_qry_ids)) != 1: 144 | raise ValueError('There is more than one query in the eval batch!') 145 | with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext(): 146 | with torch.no_grad(): 147 | for k, v in batch_qry_features.items(): 148 | batch_qry_features[k] = v.to(training_args.device) 149 | for k, v in batch_psg_features.items(): 150 | batch_psg_features[k] = v.to(training_args.device) 151 | model_output: Output = model(query=batch_qry_features, passage=batch_psg_features) 152 | 153 | qids += batch_qry_ids 154 | candidate_psg_ids += batch_psg_ids 155 | scores += model_output.scores.cpu().numpy().tolist() 156 | labels += rels 157 | if len(candidate_psg_ids) == num_candidates_per_qry: 158 | if len(set(qids)) != 1: 159 | raise ValueError('There is more than one query in the candidate set!') 160 | gt = set(list(np.where(np.array(labels) > 0)[0])) 161 | 162 | predict_doc = np.array(scores).argsort()[::-1] 163 | all_metrics += metrics.metrics(gt=gt, pred=predict_doc, metrics_map=METRICS_MAP) 164 | num_examples += 1 165 | qids = [] 166 | candidate_psg_ids = [] 167 | scores = [] 168 | labels = [] 169 | if num_examples % 10 == 0: 170 | logging.warning("Read {} examples, Metrics so far:".format(num_examples)) 171 | logging.warning(" ".join(METRICS_MAP)) 172 | logging.warning(all_metrics / num_examples) 173 | if num_examples == 200: 174 | break 175 | # TODO: write ranking results; output_dir below is computed but not yet used 176 | 177 | 178 | output_dir = '/'.join( (data_args.encoded_save_path).split('/')[:-1] ) 179 | 180 | 181 | 182 | if __name__ == "__main__": 183 | main() 184 | -------------------------------------------------------------------------------- /tevatron/driver/jax_encode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import sys 5 | 6 | import datasets 7 | import jax 8 | import numpy as np 9 | from flax.training.common_utils import shard 10 | from jax import pmap 11 | from tevatron.arguments import DataArguments 12 | from tevatron.arguments import DenseTrainingArguments as TrainingArguments 13 | from tevatron.arguments import ModelArguments 14 | from tevatron.data import EncodeCollator, EncodeDataset 15 | from tevatron.datasets import HFQueryDataset, HFCorpusDataset 16 | from torch.utils.data import DataLoader 17 | from tqdm import tqdm 18 | from flax.training.train_state import TrainState 19 | from flax import jax_utils 20 | import optax 21 | from transformers import (AutoConfig, AutoTokenizer, FlaxAutoModel, 22 | HfArgumentParser, TensorType) 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def main(): 28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 31 | else: 32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 33 | model_args: ModelArguments 34 | data_args: DataArguments 35 | training_args: TrainingArguments 36 | 37 | # Setup logging 38 | logging.basicConfig( 39 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 40 | datefmt="%m/%d/%Y %H:%M:%S", 41 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 42 | ) 43 | 44 | num_labels = 1 45 | config = AutoConfig.from_pretrained( 46 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 47 | num_labels=num_labels, 48 | cache_dir=model_args.cache_dir, 49 | ) 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 52 | cache_dir=model_args.cache_dir, 53 | use_fast=False, 54 | ) 55 | 56 | model = FlaxAutoModel.from_pretrained(model_args.model_name_or_path, config=config, from_pt=False) 57 | 58 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len 59 | if data_args.encode_is_qry: 60 | encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args, 61 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 62 | else: 63 | encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args, 64 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 65 | encode_dataset = EncodeDataset(encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index), 66 | tokenizer, max_len=text_max_length)
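# Note on the padding logic below: jax.pmap feeds a fixed-size slice to every device,
# so the shard must be padded to a multiple of len(jax.devices()) * per_device_eval_batch_size.
# Illustrative numbers (assumed, not from the source): with 8 devices and a per-device
# batch of 4, total_batch_size is 32, so a 103-example shard gains 25 dummy rows; these
# are dropped again when the pickle is written via encoded[:dataset_size].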
67 | 68 | # prepare padding batch (for last nonfull batch) 69 | dataset_size = len(encode_dataset) 70 | padding_prefix = "padding_" 71 | total_batch_size = len(jax.devices()) * training_args.per_device_eval_batch_size 72 | features = list(encode_dataset.encode_data.features.keys()) 73 | padding_batch = {features[0]: [], features[1]: []} 74 | for i in range(total_batch_size - (dataset_size % total_batch_size)): 75 | padding_batch["text_id"].append(f"{padding_prefix}{i}") 76 | padding_batch["text"].append([0]) 77 | padding_batch = datasets.Dataset.from_dict(padding_batch) 78 | encode_dataset.encode_data = datasets.concatenate_datasets([encode_dataset.encode_data, padding_batch]) 79 | 80 | encode_loader = DataLoader( 81 | encode_dataset, 82 | batch_size=training_args.per_device_eval_batch_size * len(jax.devices()), 83 | collate_fn=EncodeCollator( 84 | tokenizer, 85 | max_length=text_max_length, 86 | padding='max_length', 87 | pad_to_multiple_of=16, 88 | return_tensors=TensorType.NUMPY, 89 | ), 90 | shuffle=False, 91 | drop_last=False, 92 | num_workers=training_args.dataloader_num_workers, 93 | ) 94 | 95 | # craft a fake state for now to replicate on devices 96 | adamw = optax.adamw(0.0001) 97 | state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) 98 | 99 | def encode_step(batch, state): 100 | embedding = state.apply_fn(**batch, params=state.params, train=False)[0] 101 | return embedding[:, 0] 102 | 103 | p_encode_step = pmap(encode_step) 104 | state = jax_utils.replicate(state) 105 | 106 | encoded = [] 107 | lookup_indices = [] 108 | 109 | for (batch_ids, batch) in tqdm(encode_loader): 110 | lookup_indices.extend(batch_ids) 111 | batch_embeddings = p_encode_step(shard(batch.data), state) 112 | encoded.extend(np.concatenate(batch_embeddings, axis=0)) 113 | with open(data_args.encoded_save_path, 'wb') as f: 114 | pickle.dump((encoded[:dataset_size], lookup_indices[:dataset_size]), f) 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /tevatron/driver/jax_train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from functools import partial 5 | 6 | import datasets 7 | import jax 8 | import jax.numpy as jnp 9 | import optax 10 | from flax import jax_utils, traverse_util 11 | from flax.jax_utils import prefetch_to_device 12 | from flax.training.common_utils import get_metrics, shard 13 | from torch.utils.data import DataLoader, IterableDataset 14 | from tqdm import tqdm 15 | from transformers import AutoConfig, AutoTokenizer, FlaxAutoModel 16 | from transformers import ( 17 | HfArgumentParser, 18 | set_seed, 19 | ) 20 | 21 | from tevatron.arguments import ModelArguments, DataArguments, DenseTrainingArguments 22 | from tevatron.tevax.training import TiedParams, RetrieverTrainState, retriever_train_step, grad_cache_train_step, \ 23 | DualParams 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def main(): 29 | parser = HfArgumentParser((ModelArguments, DataArguments, DenseTrainingArguments)) 30 | 31 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 32 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 33 | else: 34 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 35 | model_args: ModelArguments 36 | data_args: DataArguments 37 | training_args: DenseTrainingArguments 38 | 39 | if ( 40 | 
os.path.exists(training_args.output_dir) 41 | and os.listdir(training_args.output_dir) 42 | and training_args.do_train 43 | and not training_args.overwrite_output_dir 44 | ): 45 | raise ValueError( 46 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 47 | ) 48 | 49 | # Setup logging 50 | logging.basicConfig( 51 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 52 | datefmt="%m/%d/%Y %H:%M:%S", 53 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 54 | ) 55 | logger.warning( 56 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 57 | training_args.local_rank, 58 | training_args.device, 59 | training_args.n_gpu, 60 | bool(training_args.local_rank != -1), 61 | training_args.fp16, 62 | ) 63 | logger.info("Training/evaluation parameters %s", training_args) 64 | logger.info("MODEL parameters %s", model_args) 65 | 66 | set_seed(training_args.seed) 67 | 68 | config = AutoConfig.from_pretrained( 69 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 70 | cache_dir=model_args.cache_dir, 71 | ) 72 | tokenizer = AutoTokenizer.from_pretrained( 73 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 74 | cache_dir=model_args.cache_dir, 75 | ) 76 | try: 77 | model = FlaxAutoModel.from_pretrained( 78 | model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) 79 | ) 80 | except: 81 | model = FlaxAutoModel.from_pretrained( 82 | model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), 83 | from_pt=True 84 | ) 85 | 86 | if data_args.train_dir: 87 | data_files = { 88 | 'train': data_args.train_path 89 | } 90 | else: 91 | data_files = None 92 | 93 | train_dataset = \ 94 | datasets.load_dataset(data_args.dataset_name, data_args.dataset_language, cache_dir=model_args.cache_dir, 95 | data_files=data_files)[data_args.dataset_split] 96 | 97 | def tokenize_train(example): 98 | tokenize = partial(tokenizer, return_attention_mask=False, return_token_type_ids=False, padding=False, 99 | truncation=True) 100 | query = example['query'] 101 | pos_psgs = [p['title'] + " " + p['text'] for p in example['positive_passages']] 102 | neg_psgs = [p['title'] + " " + p['text'] for p in example['negative_passages']] 103 | 104 | example['query_input_ids'] = dict(tokenize(query, max_length=32)) 105 | example['pos_psgs_input_ids'] = [dict(tokenize(x, max_length=data_args.p_max_len)) for x in pos_psgs] 106 | example['neg_psgs_input_ids'] = [dict(tokenize(x, max_length=data_args.p_max_len)) for x in neg_psgs] 107 | 108 | return example 109 | 110 | train_data = train_dataset.map( 111 | tokenize_train, 112 | batched=False, 113 | num_proc=data_args.dataset_proc_num, 114 | desc="Running tokenizer on train dataset", 115 | ) 116 | train_data = train_data.filter( 117 | function=lambda data: len(data["pos_psgs_input_ids"]) >= 1 and \ 118 | len(data["neg_psgs_input_ids"]) >= data_args.train_n_passages-1, num_proc=64 119 | ) 120 | 121 | class TrainDataset: 122 | def __init__(self, train_data, group_size, tokenizer): 123 | self.group_size = group_size 124 | self.data = train_data 125 | self.tokenizer = tokenizer 126 | 127 | def __len__(self): 128 | return len(self.data) 129 | 130 | def get_example(self, i, epoch): 131 | example = self.data[i] 132 | q = example['query_input_ids'] 133 | 134 | pp = 
example['pos_psgs_input_ids'] 135 | p = pp[0] 136 | 137 | nn = example['neg_psgs_input_ids'] 138 | off = epoch * (self.group_size - 1) % len(nn) 139 | nn = nn * 2 140 | nn = nn[off: off + self.group_size - 1] 141 | 142 | return q, [p] + nn 143 | 144 | def get_batch(self, indices, epoch): 145 | qq, dd = zip(*[self.get_example(i, epoch) for i in map(int, indices)]) 146 | dd = sum(dd, []) 147 | return dict(tokenizer.pad(qq, max_length=32, padding='max_length', return_tensors='np')), dict( 148 | tokenizer.pad(dd, max_length=data_args.p_max_len, padding='max_length', return_tensors='np')) 149 | 150 | train_dataset = TrainDataset(train_data, data_args.train_n_passages, tokenizer) 151 | 152 | def create_learning_rate_fn( 153 | train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, 154 | learning_rate: float 155 | ): 156 | """Returns a linear warmup, linear_decay learning rate function.""" 157 | steps_per_epoch = train_ds_size // train_batch_size 158 | num_train_steps = steps_per_epoch * num_train_epochs 159 | warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) 160 | decay_fn = optax.linear_schedule( 161 | init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps 162 | ) 163 | schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) 164 | return schedule_fn 165 | 166 | def _decay_mask_fn(params): 167 | flat_params = traverse_util.flatten_dict(params) 168 | layer_norm_params = [ 169 | (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] 170 | ] 171 | flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} 172 | return traverse_util.unflatten_dict(flat_mask) 173 | 174 | def decay_mask_fn(params): 175 | param_nodes, treedef = jax.tree_flatten(params, lambda v: isinstance(v, dict)) 176 | masks = [_decay_mask_fn(param_node) for param_node in param_nodes] 177 | return jax.tree_unflatten(treedef, masks) 178 | 179 | num_epochs = int(training_args.num_train_epochs) 180 | train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() 181 | steps_per_epoch = len(train_dataset) // train_batch_size 182 | total_train_steps = steps_per_epoch * num_epochs 183 | 184 | linear_decay_lr_schedule_fn = create_learning_rate_fn( 185 | len(train_dataset), 186 | train_batch_size, 187 | int(training_args.num_train_epochs), 188 | int(total_train_steps * 0.1), 189 | training_args.learning_rate, 190 | ) 191 | 192 | adamw = optax.adamw( 193 | learning_rate=linear_decay_lr_schedule_fn, 194 | b1=training_args.adam_beta1, 195 | b2=training_args.adam_beta2, 196 | eps=training_args.adam_epsilon, 197 | weight_decay=training_args.weight_decay, 198 | mask=decay_mask_fn, 199 | ) 200 | 201 | if model_args.untie_encoder: 202 | params = DualParams.create(model.params) 203 | else: 204 | params = TiedParams.create(model.params) 205 | state = RetrieverTrainState.create(apply_fn=model.__call__, params=params, tx=adamw) 206 | 207 | if training_args.grad_cache: 208 | q_n_subbatch = train_batch_size // training_args.gc_q_chunk_size 209 | p_n_subbatch = train_batch_size * data_args.train_n_passages // training_args.gc_p_chunk_size 210 | p_train_step = jax.pmap( 211 | partial(grad_cache_train_step, q_n_subbatch=q_n_subbatch, p_n_subbatch=p_n_subbatch), 212 | "device" 213 | ) 214 | else: 215 | p_train_step = jax.pmap( 216 | retriever_train_step, 217 | "device" 218 | ) 219 
| 220 | state = jax_utils.replicate(state) 221 | rng = jax.random.PRNGKey(training_args.seed) 222 | dropout_rngs = jax.random.split(rng, jax.local_device_count()) 223 | 224 | class IterableTrain(IterableDataset): 225 | def __init__(self, dataset, batch_idx, epoch): 226 | super(IterableTrain).__init__() 227 | self.dataset = dataset 228 | self.batch_idx = batch_idx 229 | self.epoch = epoch 230 | 231 | def __iter__(self): 232 | for idx in self.batch_idx: 233 | batch = self.dataset.get_batch(idx, self.epoch) 234 | batch = shard(batch) 235 | yield batch 236 | 237 | logger.info("***** Running training *****") 238 | logger.info(f" Num examples = {len(train_dataset)}") 239 | logger.info(f" Num Epochs = {num_epochs}") 240 | logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") 241 | logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") 242 | logger.info(f" Total optimization steps = {total_train_steps}") 243 | 244 | train_metrics = [] 245 | for epoch in tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0): 246 | # ======================== Training ================================ 247 | # Create sampling rng 248 | rng, input_rng = jax.random.split(rng) 249 | 250 | steps_per_epoch = len(train_dataset) // train_batch_size 251 | 252 | batch_idx = jax.random.permutation(input_rng, len(train_dataset)) 253 | batch_idx = batch_idx[: steps_per_epoch * train_batch_size] 254 | batch_idx = batch_idx.reshape((steps_per_epoch, train_batch_size)).tolist() 255 | 256 | train_loader = prefetch_to_device( 257 | iter(DataLoader( 258 | IterableTrain(train_dataset, batch_idx, epoch), 259 | num_workers=16, prefetch_factor=256, batch_size=None, collate_fn=lambda v: v) 260 | ), 2) 261 | 262 | # train 263 | epochs = tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False) 264 | for step in epochs: 265 | cur_step = epoch * (len(train_dataset) // train_batch_size) + step 266 | batch = next(train_loader) 267 | 268 | loss, state, dropout_rngs = p_train_step(state, *batch, dropout_rngs) 269 | train_metrics.append({'loss': loss}) 270 | 271 | if cur_step % training_args.logging_steps == 0 and cur_step > 0: 272 | train_metrics = get_metrics(train_metrics) 273 | print( 274 | f"Step... ({cur_step} | Loss: {train_metrics['loss'].mean()}," 275 | f" Learning Rate: {linear_decay_lr_schedule_fn(cur_step)})", 276 | flush=True, 277 | ) 278 | train_metrics = [] 279 | 280 | epochs.write( 281 | f"Epoch... 
({epoch + 1}/{num_epochs})" 282 | ) 283 | 284 | params = jax_utils.unreplicate(state.params) 285 | 286 | if model_args.untie_encoder: 287 | os.makedirs(training_args.output_dir, exist_ok=True) 288 | model.save_pretrained(os.path.join(training_args.output_dir, 'query_encoder'), params=params.q_params) 289 | model.save_pretrained(os.path.join(training_args.output_dir, 'passage_encoder'), params=params.p_params) 290 | else: 291 | model.save_pretrained(training_args.output_dir, params=params.p_params) 292 | tokenizer.save_pretrained(training_args.output_dir) 293 | 294 | 295 | if __name__ == "__main__": 296 | main() 297 | -------------------------------------------------------------------------------- /tevatron/driver/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | from transformers import AutoConfig, AutoTokenizer 6 | from transformers import ( 7 | HfArgumentParser, 8 | set_seed, 9 | ) 10 | 11 | from tevatron.arguments import ModelArguments, DataArguments, ColBERTModelArguments, \ 12 | DenseTrainingArguments as TrainingArguments 13 | from tevatron.data import TrainDataset, TrainTASBDataset, QPCollator 14 | from tevatron.trainer import DenseTrainer as Trainer, GCTrainer 15 | from tevatron.datasets import HFTrainDataset, HFCorpusDataset 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def main(): 21 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 22 | 23 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 24 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 25 | else: 26 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 27 | 28 | model_args: ModelArguments 29 | data_args: DataArguments 30 | training_args: TrainingArguments 31 | 32 | 33 | 34 | if ( 35 | os.path.exists(training_args.output_dir) 36 | and os.listdir(training_args.output_dir) 37 | and training_args.do_train 38 | and not training_args.overwrite_output_dir 39 | ): 40 | raise ValueError( 41 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
42 | ) 43 | 44 | # Setup logging 45 | logging.basicConfig( 46 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 47 | datefmt="%m/%d/%Y %H:%M:%S", 48 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 49 | ) 50 | logger.warning( 51 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 52 | training_args.local_rank, 53 | training_args.device, 54 | training_args.n_gpu, 55 | bool(training_args.local_rank != -1), 56 | training_args.fp16, 57 | ) 58 | logger.info("Training/evaluation parameters %s", training_args) 59 | logger.info("MODEL parameters %s", model_args) 60 | 61 | set_seed(training_args.seed) 62 | 63 | num_labels = 1 64 | config = AutoConfig.from_pretrained( 65 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 66 | num_labels=num_labels, 67 | output_hidden_states=True, 68 | cache_dir=model_args.cache_dir, 69 | ) 70 | tokenizer = AutoTokenizer.from_pretrained( 71 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 72 | cache_dir=model_args.cache_dir, 73 | use_fast=False, 74 | ) 75 | 76 | teacher_model = None 77 | if model_args.tct: 78 | if model_args.teacher_model_name_or_path is None: 79 | raise ValueError( 80 | f"when use --tct option, you should input --teacher_model_name_or_path" 81 | ) 82 | # use default setting 83 | teacher_model_args = ColBERTModelArguments() 84 | teacher_model_args.model_name_or_path = model_args.teacher_model_name_or_path 85 | colbert_config = AutoConfig.from_pretrained( 86 | teacher_model_args.config_name if teacher_model_args.config_name else teacher_model_args.model_name_or_path, 87 | num_labels=num_labels, 88 | output_hidden_states=True, 89 | cache_dir=teacher_model_args.cache_dir, 90 | ) 91 | 92 | from tevatron.ColBERT.modeling import ColBERTForInference, ColBERTOutput 93 | from tevatron.ColBERT.modeling import ColBERTOutput as Output 94 | logger.info("Call model ColBERT as listwise teacher") 95 | teacher_model = ColBERTForInference.build( 96 | model_args=teacher_model_args, 97 | data_args=data_args, 98 | train_args=training_args, 99 | config=colbert_config, 100 | cache_dir=teacher_model_args.cache_dir, 101 | ) 102 | 103 | if (model_args.model).lower() == 'colbert': 104 | from tevatron.ColBERT.modeling import ColBERT 105 | logger.info("Training model ColBERT") 106 | model = ColBERT.build( 107 | model_args, 108 | data_args, 109 | training_args, 110 | config=config, 111 | cache_dir=model_args.cache_dir, 112 | ) 113 | elif (model_args.model).lower() == 'dhr': 114 | from tevatron.DHR.modeling import DHRModel 115 | logger.info("Training model DHR") 116 | model = DHRModel.build( 117 | model_args, 118 | data_args, 119 | training_args, 120 | teacher_model, 121 | config=config, 122 | cache_dir=model_args.cache_dir, 123 | ) 124 | elif (model_args.model).lower() == 'dlr': 125 | from tevatron.DHR.modeling import DHRModel 126 | logger.info("Training model DLR") 127 | model_args.combine_cls = False 128 | model = DHRModel.build( 129 | model_args, 130 | data_args, 131 | training_args, 132 | teacher_model, 133 | config=config, 134 | cache_dir=model_args.cache_dir, 135 | ) 136 | elif (model_args.model).lower() == 'agg': 137 | from tevatron.Aggretriever.modeling import DenseModel 138 | logger.info("Training model Dense (AGG)") 139 | model = DenseModel.build( 140 | model_args, 141 | data_args, 142 | training_args, 143 | config=config, 144 | cache_dir=model_args.cache_dir, 145 | ) 146 | elif (model_args.model).lower() 
== 'dense': 147 | from tevatron.Dense.modeling import DenseModel 148 | logger.info("Training model Dense (CLS)") 149 | model = DenseModel.build( 150 | model_args, 151 | data_args, 152 | training_args, 153 | config=config, 154 | cache_dir=model_args.cache_dir, 155 | ) 156 | else: 157 | raise ValueError(f'model type {model_args.model} is not supported') 158 | 159 | 160 | train_dataset = HFTrainDataset(tokenizer=tokenizer, data_args=data_args, 161 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 162 | 163 | corpus_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args, 164 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 165 | ### TODO: add an argument for choosing the TASB training dataset 166 | # train_dataset = TrainDataset(data_args, train_dataset.process(), tokenizer) 167 | train_dataset = TrainTASBDataset(data_args, model_args.kd, train_dataset.process(), corpus_dataset.process(), tokenizer) 168 | 169 | trainer_cls = GCTrainer if training_args.grad_cache else Trainer 170 | trainer = trainer_cls( 171 | model=model, 172 | args=training_args, 173 | train_dataset=train_dataset, 174 | data_collator=QPCollator( 175 | tokenizer, 176 | max_p_len=data_args.p_max_len, 177 | max_q_len=data_args.q_max_len 178 | ), 179 | ) 180 | train_dataset.trainer = trainer 181 | 182 | trainer.train() # TODO: resume training 183 | trainer.save_model() 184 | if trainer.is_world_process_zero(): 185 | tokenizer.save_pretrained(training_args.output_dir) 186 | 187 | 188 | if __name__ == "__main__": 189 | main() 190 | -------------------------------------------------------------------------------- /tevatron/faiss_retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from .retriever import BaseFaissIPRetriever 2 | -------------------------------------------------------------------------------- /tevatron/faiss_retriever/__main__.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | import glob 5 | from argparse import ArgumentParser 6 | from itertools import chain 7 | from tqdm import tqdm 8 | 9 | from .retriever import BaseFaissIPRetriever 10 | 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | logging.basicConfig( 14 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 15 | datefmt="%m/%d/%Y %H:%M:%S", 16 | level=logging.INFO, 17 | ) 18 | 19 | 20 | def search_queries(retriever, q_reps, p_lookup, args): 21 | if args.batch_size > 0: 22 | all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size) 23 | else: 24 | all_scores, all_indices = retriever.search(q_reps, args.depth) 25 | 26 | psg_indices = [[str(p_lookup[x]) for x in q_dd] for q_dd in all_indices] 27 | psg_indices = np.array(psg_indices) 28 | return all_scores, psg_indices 29 | 30 | 31 | def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file): 32 | with open(ranking_save_file, 'w') as f: 33 | for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices): 34 | score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)] 35 | score_list = sorted(score_list, key=lambda x: x[0], reverse=True) 36 | for s, idx in score_list: 37 | f.write(f'{qid}\t{idx}\t{s}\n') 38 | 39 | 40 | def pickle_load(path): 41 | with open(path, 'rb') as f: 42 | obj = pickle.load(f) 43 | return obj 44 | 45 | 46 | def pickle_save(obj, path): 47 | with open(path, 'wb') as f: 48 | pickle.dump(obj, f) 49 | 50 | 51 | def main(): 52 | parser = 
ArgumentParser() 53 | parser.add_argument('--query_reps', required=True) 54 | parser.add_argument('--passage_reps', required=True) 55 | parser.add_argument('--batch_size', type=int, default=128) 56 | parser.add_argument('--depth', type=int, default=1000) 57 | parser.add_argument('--save_ranking_to', required=True) 58 | parser.add_argument('--save_text', action='store_true') 59 | 60 | args = parser.parse_args() 61 | 62 | index_files = glob.glob(args.passage_reps) 63 | logger.info(f'Pattern match found {len(index_files)} files; loading them into index.') 64 | 65 | p_reps_0, p_lookup_0 = pickle_load(index_files[0]) 66 | retriever = BaseFaissIPRetriever(p_reps_0) 67 | 68 | shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:])) 69 | if len(index_files) > 1: 70 | shards = tqdm(shards, desc='Loading shards into index', total=len(index_files)) 71 | look_up = [] 72 | for p_reps, p_lookup in shards: 73 | retriever.add(p_reps) 74 | look_up += p_lookup 75 | 76 | q_reps, q_lookup = pickle_load(args.query_reps) 77 | 78 | logger.info('Index Search Start') 79 | all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args) 80 | logger.info('Index Search Finished') 81 | 82 | if args.save_text: 83 | write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to) 84 | else: 85 | pickle_save((all_scores, psg_indices), args.save_ranking_to) 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /tevatron/faiss_retriever/reducer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import faiss 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | from typing import Iterable, Tuple 6 | from numpy import ndarray 7 | from .__main__ import pickle_load, write_ranking 8 | 9 | 10 | def combine_faiss_results(results: Iterable[Tuple[ndarray, ndarray]]): 11 | rh = None 12 | for scores, indices in results: 13 | if rh is None: 14 | print(f'Initializing Heap.
Assuming {scores.shape[0]} queries.') 15 | rh = faiss.ResultHeap(scores.shape[0], scores.shape[1]) 16 | rh.add_result(-scores, indices) 17 | rh.finalize() 18 | corpus_scores, corpus_indices = -rh.D, rh.I 19 | 20 | return corpus_scores, corpus_indices 21 | 22 | 23 | def main(): 24 | parser = ArgumentParser() 25 | parser.add_argument('--score_dir', required=True) 26 | parser.add_argument('--query', required=True) 27 | parser.add_argument('--save_ranking_to', required=True) 28 | args = parser.parse_args() 29 | 30 | partitions = glob.glob(f'{args.score_dir}/*') 31 | 32 | corpus_scores, corpus_indices = combine_faiss_results(map(pickle_load, tqdm(partitions))) 33 | 34 | _, q_lookup = pickle_load(args.query) 35 | write_ranking(corpus_indices, corpus_scores, q_lookup, args.save_ranking_to) 36 | 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /tevatron/faiss_retriever/retriever.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import faiss 3 | 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class BaseFaissIPRetriever: 10 | def __init__(self, init_reps: np.ndarray): 11 | index = faiss.IndexFlatIP(init_reps.shape[1]) 12 | self.index = index 13 | 14 | def add(self, p_reps: np.ndarray): 15 | self.index.add(p_reps) 16 | 17 | def search(self, q_reps: np.ndarray, k: int): 18 | return self.index.search(q_reps, k) 19 | 20 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int): 21 | num_query = q_reps.shape[0] 22 | all_scores = [] 23 | all_indices = [] 24 | for start_idx in range(0, num_query, batch_size): 25 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k) 26 | all_scores.append(nn_scores) 27 | all_indices.append(nn_indices) 28 | all_scores = np.concatenate(all_scores, axis=0) 29 | all_indices = np.concatenate(all_indices, axis=0) 30 | 31 | return all_scores, all_indices 32 | 33 | 34 | class FaissRetriever(BaseFaissIPRetriever): 35 | 36 | def __init__(self, init_reps: np.ndarray, factory_str: str): 37 | index = faiss.index_factory(init_reps.shape[1], factory_str) 38 | self.index = index 39 | self.index.verbose = True 40 | if not self.index.is_trained: 41 | self.index.train(init_reps) 42 | -------------------------------------------------------------------------------- /tevatron/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from torch import distributed as dist 5 | 6 | 7 | class SimpleContrastiveLoss: 8 | def __init__(self, n_target: int = 1): 9 | self.target_per_qry = n_target 10 | 11 | def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'): 12 | if target is None: 13 | assert x.size(0) * self.target_per_qry == y.size(0) 14 | target = torch.arange( 15 | 0, x.size(0) * self.target_per_qry, self.target_per_qry, device=x.device, dtype=torch.long) 16 | logits = torch.matmul(x, y.transpose(0, 1)) 17 | return F.cross_entropy(logits, target, reduction=reduction) 18 | 19 | 20 | class DistributedContrastiveLoss(SimpleContrastiveLoss): 21 | def __init__(self, n_target: int = 0, scale_loss: bool = True): 22 | assert dist.is_initialized(), "Distributed training has not been properly initialized." 
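# DistributedContrastiveLoss all-gathers the query and passage representations
# from every rank before computing the contrastive loss, so in-batch negatives
# span the global batch rather than a single GPU's shard; scale_loss multiplies
# by the world size to compensate for the cross-rank gradient averaging done by
# the distributed backend.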
23 | super().__init__(n_target=n_target) 24 | self.world_size = dist.get_world_size() 25 | self.rank = dist.get_rank() 26 | self.scale_loss = scale_loss 27 | 28 | def __call__(self, x: Tensor, y: Tensor, **kwargs): 29 | dist_x = self.gather_tensor(x) 30 | dist_y = self.gather_tensor(y) 31 | loss = super().__call__(dist_x, dist_y, **kwargs) 32 | if self.scale_loss: 33 | loss = loss * self.world_size 34 | return loss 35 | 36 | def gather_tensor(self, t): 37 | gathered = [torch.empty_like(t) for _ in range(self.world_size)] 38 | dist.all_gather(gathered, t) 39 | gathered[self.rank] = t 40 | return torch.cat(gathered, dim=0) -------------------------------------------------------------------------------- /tevatron/preprocessor/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessor_tsv import SimpleTrainPreProcessor as MarcoPassageTrainPreProcessor, \ 2 | SimpleCollectionPreProcessor as MarcoPassageCollectionPreProcessor 3 | -------------------------------------------------------------------------------- /tevatron/preprocessor/preprocessor_tsv.py: -------------------------------------------------------------------------------- 1 | import json 2 | import csv 3 | import datasets 4 | from transformers import PreTrainedTokenizer 5 | from dataclasses import dataclass 6 | 7 | 8 | @dataclass 9 | class SimpleTrainPreProcessor: 10 | query_file: str 11 | collection_file: str 12 | tokenizer: PreTrainedTokenizer 13 | 14 | max_length: int = 128 15 | columns = ['text_id', 'title', 'text'] 16 | title_field = 'title' 17 | text_field = 'text' 18 | 19 | def __post_init__(self): 20 | self.queries = self.read_queries(self.query_file) 21 | self.collection = datasets.load_dataset( 22 | 'csv', 23 | data_files=self.collection_file, 24 | column_names=self.columns, 25 | delimiter='\t', 26 | )['train'] 27 | 28 | @staticmethod 29 | def read_queries(queries): 30 | qmap = {} 31 | with open(queries) as f: 32 | for l in f: 33 | qid, qry = l.strip().split('\t') 34 | qmap[qid] = qry 35 | return qmap 36 | 37 | @staticmethod 38 | def read_qrel(relevance_file): 39 | qrel = {} 40 | with open(relevance_file, encoding='utf8') as f: 41 | tsvreader = csv.reader(f, delimiter="\t") 42 | for [topicid, _, docid, rel] in tsvreader: 43 | assert rel == "1" 44 | if topicid in qrel: 45 | qrel[topicid].append(docid) 46 | else: 47 | qrel[topicid] = [docid] 48 | return qrel 49 | 50 | def get_query(self, q): 51 | query_encoded = self.tokenizer.encode( 52 | self.queries[q], 53 | add_special_tokens=False, 54 | max_length=self.max_length, 55 | truncation=True 56 | ) 57 | return query_encoded 58 | 59 | def get_passage(self, p): 60 | entry = self.collection[int(p)] 61 | title = entry[self.title_field] 62 | title = "" if title is None else title 63 | body = entry[self.text_field] 64 | content = title + self.tokenizer.sep_token + body 65 | 66 | passage_encoded = self.tokenizer.encode( 67 | content, 68 | add_special_tokens=False, 69 | max_length=self.max_length, 70 | truncation=True 71 | ) 72 | 73 | return passage_encoded 74 | 75 | def process_one(self, train): 76 | q, pp, nn = train 77 | train_example = { 78 | 'query': self.get_query(q), 79 | 'positives': [self.get_passage(p) for p in pp], 80 | 'negatives': [self.get_passage(n) for n in nn], 81 | } 82 | 83 | return json.dumps(train_example) 84 | 85 | 86 | @dataclass 87 | class SimpleCollectionPreProcessor: 88 | tokenizer: PreTrainedTokenizer 89 | separator: str = '\t' 90 | max_length: int = 128 91 | 92 | def process_line(self, line: str): 93 | xx
= line.strip().split(self.separator) 94 | text_id, text = xx[0], xx[1:] 95 | text_encoded = self.tokenizer.encode( 96 | self.tokenizer.sep_token.join(text), 97 | add_special_tokens=False, 98 | max_length=self.max_length, 99 | truncation=True 100 | ) 101 | encoded = { 102 | 'text_id': text_id, 103 | 'text': text_encoded 104 | } 105 | return json.dumps(encoded) 106 | -------------------------------------------------------------------------------- /tevatron/tevax/__init__.py: -------------------------------------------------------------------------------- 1 | from .training import TiedParams, DualParams, RetrieverTrainState, retriever_train_step 2 | -------------------------------------------------------------------------------- /tevatron/tevax/loss.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import lax 3 | import optax 4 | import chex 5 | 6 | 7 | def _onehot(labels: chex.Array, num_classes: int) -> chex.Array: 8 | x = labels[..., None] == jnp.arange(num_classes).reshape((1,) * labels.ndim + (-1,)) 9 | x = lax.select(x, jnp.ones(x.shape), jnp.zeros(x.shape)) 10 | return x.astype(jnp.float32) 11 | 12 | 13 | def p_contrastive_loss(ss: chex.Array, tt: chex.Array, axis: str = 'device') -> chex.Array: 14 | per_shard_targets = tt.shape[0] 15 | per_sample_targets = int(tt.shape[0] / ss.shape[0]) 16 | labels = jnp.arange(0, per_shard_targets, per_sample_targets) + per_shard_targets * lax.axis_index(axis) 17 | 18 | tt = lax.all_gather(tt, axis).reshape((-1, ss.shape[-1])) 19 | scores = jnp.dot(ss, jnp.transpose(tt)) 20 | 21 | return optax.softmax_cross_entropy(scores, _onehot(labels, scores.shape[-1])) 22 | -------------------------------------------------------------------------------- /tevatron/tevax/training.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Tuple, Any, Union 3 | 4 | import jax 5 | from jax import numpy as jnp 6 | 7 | from flax.training.train_state import TrainState 8 | from flax.core import FrozenDict 9 | from flax.struct import PyTreeNode 10 | 11 | from .loss import p_contrastive_loss 12 | 13 | 14 | class TiedParams(PyTreeNode): 15 | params: FrozenDict[str, Any] 16 | 17 | @property 18 | def q_params(self): 19 | return self.params 20 | 21 | @property 22 | def p_params(self): 23 | return self.params 24 | 25 | @classmethod 26 | def create(cls, params): 27 | return cls(params=params) 28 | 29 | 30 | class DualParams(PyTreeNode): 31 | params: Tuple[FrozenDict[str, Any], FrozenDict[str, Any]] 32 | 33 | @property 34 | def q_params(self): 35 | return self.params[0] 36 | 37 | @property 38 | def p_params(self): 39 | return self.params[1] 40 | 41 | @classmethod 42 | def create(cls, *ps): 43 | if len(ps) == 1: 44 | return cls(params=ps*2) 45 | else: 46 | p_params, q_params = ps 47 | return cls(params=[p_params, q_params]) 48 | 49 | 50 | class RetrieverTrainState(TrainState): 51 | params: Union[TiedParams, DualParams] 52 | 53 | 54 | def retriever_train_step(state, queries, passages, dropout_rng, axis='device'): 55 | q_dropout_rng, p_dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, 3) 56 | 57 | def compute_loss(params): 58 | q_reps = state.apply_fn(**queries, params=params.q_params, dropout_rng=q_dropout_rng, train=True)[0][:, 0, :] 59 | p_reps = state.apply_fn(**passages, params=params.p_params, dropout_rng=p_dropout_rng, train=True)[0][:, 0, :] 60 | return jnp.mean(p_contrastive_loss(q_reps, p_reps, axis=axis)) 
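# p_contrastive_loss all-gathers passage representations across the pmap axis,
# so each device scores its query shard against the full passage batch; the
# loss and gradients are then averaged over the same axis (pmean) below so that
# every replica applies an identical update.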
61 | 62 | loss, grad = jax.value_and_grad(compute_loss)(state.params) 63 | loss, grad = jax.lax.pmean([loss, grad], axis) 64 | 65 | new_state = state.apply_gradients(grads=grad) 66 | 67 | return loss, new_state, new_dropout_rng 68 | 69 | 70 | def grad_cache_train_step(state, queries, passages, dropout_rng, axis='device', q_n_subbatch=1, p_n_subbatch=1): 71 | try: 72 | from grad_cache import cachex 73 | except ImportError: 74 | raise ModuleNotFoundError('the GradCache package needs to be installed to run grad_cache_train_step') 75 | 76 | def encode_query(params, **kwargs): 77 | return state.apply_fn(**kwargs, params=params.q_params, train=True)[0][:, 0, :] 78 | 79 | def encode_passage(params, **kwargs): 80 | return state.apply_fn(**kwargs, params=params.p_params, train=True)[0][:, 0, :] 81 | 82 | queries, passages = cachex.tree_chunk(queries, q_n_subbatch), cachex.tree_chunk(passages, p_n_subbatch) 83 | q_rngs, p_rngs, new_rng = jax.random.split(dropout_rng, 3) 84 | q_rngs = jax.random.split(q_rngs, q_n_subbatch) 85 | p_rngs = jax.random.split(p_rngs, p_n_subbatch) 86 | 87 | q_reps = cachex.chunk_encode(partial(encode_query, state.params))(**queries, dropout_rng=q_rngs) 88 | p_reps = cachex.chunk_encode(partial(encode_passage, state.params))(**passages, dropout_rng=p_rngs) 89 | 90 | @cachex.unchunk_args(axis=0, argnums=(0, 1)) 91 | def compute_loss(xx, yy): 92 | return jnp.mean(p_contrastive_loss(xx, yy, axis=axis)) 93 | 94 | loss, (q_grads, p_grads) = jax.value_and_grad(compute_loss, argnums=(0, 1))(q_reps, p_reps) 95 | 96 | grads = jax.tree_map(lambda v: jnp.zeros_like(v), state.params) 97 | grads = cachex.cache_grad(encode_query)(state.params, grads, q_grads, **queries, dropout_rng=q_rngs) 98 | grads = cachex.cache_grad(encode_passage)(state.params, grads, p_grads, **passages, dropout_rng=p_rngs) 99 | 100 | loss, grads = jax.lax.pmean([loss, grads], axis) 101 | new_state = state.apply_gradients(grads=grads) 102 | return loss, new_state, new_rng 103 | -------------------------------------------------------------------------------- /tevatron/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import repeat 3 | from typing import Dict, List, Tuple, Optional, Any, Union 4 | 5 | from transformers.trainer import Trainer 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torch.distributed as dist 10 | 11 | from .loss import SimpleContrastiveLoss, DistributedContrastiveLoss 12 | 13 | import logging 14 | logger = logging.getLogger(__name__) 15 | 16 | try: 17 | from grad_cache import GradCache 18 | _grad_cache_available = True 19 | except ModuleNotFoundError: 20 | _grad_cache_available = False 21 | 22 | 23 | class DenseTrainer(Trainer): 24 | def __init__(self, *args, **kwargs): 25 | super(DenseTrainer, self).__init__(*args, **kwargs) 26 | self._dist_loss_scale_factor = dist.get_world_size() if self.args.negatives_x_device else 1 27 | 28 | def _save(self, output_dir: Optional[str] = None): 29 | output_dir = output_dir if output_dir is not None else self.args.output_dir 30 | os.makedirs(output_dir, exist_ok=True) 31 | logger.info("Saving model checkpoint to %s", output_dir) 32 | self.model.save(output_dir) 33 | 34 | def _prepare_inputs( 35 | self, 36 | inputs: Tuple[Dict[str, Union[torch.Tensor, Any]], ...]
37 | ) -> List[Dict[str, Union[torch.Tensor, Any]]]: 38 | prepared = [] 39 | for x in inputs: 40 | if isinstance(x, torch.Tensor): 41 | prepared.append(x.to(self.args.device)) 42 | else: 43 | prepared.append(super()._prepare_inputs(x)) 44 | return prepared 45 | 46 | def get_train_dataloader(self) -> DataLoader: 47 | if self.train_dataset is None: 48 | raise ValueError("Trainer: training requires a train_dataset.") 49 | train_sampler = self._get_train_sampler() 50 | 51 | return DataLoader( 52 | self.train_dataset, 53 | batch_size=self.args.train_batch_size, 54 | sampler=train_sampler, 55 | collate_fn=self.data_collator, 56 | drop_last=True, 57 | num_workers=self.args.dataloader_num_workers, 58 | ) 59 | 60 | def compute_loss(self, model, inputs): 61 | query, passage, teacher_scores = inputs 62 | 63 | return model(query=query, passage=passage, teacher_scores=teacher_scores).loss 64 | 65 | def training_step(self, *args): 66 | return super(DenseTrainer, self).training_step(*args) / self._dist_loss_scale_factor 67 | 68 | 69 | def split_dense_inputs(model_input: dict, chunk_size: int): 70 | assert len(model_input) == 1 71 | arg_key = list(model_input.keys())[0] 72 | arg_val = model_input[arg_key] 73 | 74 | keys = list(arg_val.keys()) 75 | chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys] 76 | chunked_arg_val = [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))] 77 | 78 | return [{arg_key: c} for c in chunked_arg_val] 79 | 80 | 81 | def get_dense_rep(x): 82 | if x.q_reps is None: 83 | return x.p_reps 84 | else: 85 | return x.q_reps 86 | 87 | 88 | class GCTrainer(DenseTrainer): 89 | def __init__(self, *args, **kwargs): 90 | logger.info('Initializing Gradient Cache Trainer') 91 | if not _grad_cache_available: 92 | raise ValueError( 93 | 'Grad Cache package not available. 
You can obtain it from https://github.com/luyug/GradCache.') 94 | super(GCTrainer, self).__init__(*args, **kwargs) 95 | 96 | loss_fn_cls = DistributedContrastiveLoss if self.args.negatives_x_device else SimpleContrastiveLoss 97 | loss_fn = loss_fn_cls(self.model.data_args.train_n_passages) 98 | 99 | self.gc = GradCache( 100 | models=[self.model, self.model], 101 | chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size], 102 | loss_fn=loss_fn, 103 | split_input_fn=split_dense_inputs, 104 | get_rep_fn=get_dense_rep, 105 | fp16=self.args.fp16, 106 | scaler=self.scaler 107 | ) 108 | 109 | def training_step(self, model, inputs) -> torch.Tensor: 110 | model.train() 111 | queries, passages = self._prepare_inputs(inputs) 112 | queries, passages = {'query': queries}, {'passage': passages} 113 | 114 | _distributed = self.args.local_rank > -1 115 | self.gc.models = [model, model] 116 | loss = self.gc(queries, passages, no_sync_except_last=_distributed) 117 | 118 | return loss / self._dist_loss_scale_factor 119 | -------------------------------------------------------------------------------- /tevatron/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/utils/__init__.py -------------------------------------------------------------------------------- /tevatron/utils/convert_from_dpr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | 5 | from transformers import AutoConfig, AutoTokenizer 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--dpr_model', required=True) 10 | parser.add_argument('--save_to', required=True) 11 | args = parser.parse_args() 12 | 13 | dpr_model_ckpt = torch.load(args.dpr_model, map_location='cpu') 14 | config_name = dpr_model_ckpt['encoder_params']['pretrained_model_cfg'] 15 | dpr_model_dict = dpr_model_ckpt['model_dict'] 16 | 17 | AutoConfig.from_pretrained(config_name).save_pretrained(args.save_to) 18 | AutoTokenizer.from_pretrained(config_name).save_pretrained(args.save_to) 19 | 20 | question_keys = [k for k in dpr_model_dict.keys() if k.startswith('question_model')] 21 | ctx_keys = [k for k in dpr_model_dict.keys() if k.startswith('ctx_model')] 22 | 23 | question_dict = dict([(k[len('question_model')+1:], dpr_model_dict[k]) for k in question_keys]) 24 | ctx_dict = dict([(k[len('ctx_model')+1:], dpr_model_dict[k]) for k in ctx_keys]) 25 | 26 | os.makedirs(os.path.join(args.save_to, 'query_model'), exist_ok=True) 27 | os.makedirs(os.path.join(args.save_to, 'passage_model'), exist_ok=True) 28 | torch.save(question_dict, os.path.join(args.save_to, 'query_model', 'pytorch_model.bin')) 29 | torch.save(ctx_dict, os.path.join(args.save_to, 'passage_model', 'pytorch_model.bin')) 30 | 31 | 32 | if __name__ == '__main__': 33 | main() -------------------------------------------------------------------------------- /tevatron/utils/data_reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm import tqdm 3 | import os 4 | import json 5 | from typing import List, Tuple 6 | from collections import defaultdict 7 | logger = logging.getLogger(__name__) 8 | 9 | def create_dir(dir_: str): 10 | output_parent = '/'.join((dir_).split('/')[:-1]) 11 | if not os.path.exists(output_parent): 12 | logger.info(f'Create {output_parent}') 13 | 
os.mkdir(output_parent) 14 | if not os.path.exists(dir_): 15 | logger.info(f'Create {dir_}') 16 | os.mkdir(dir_) 17 | 18 | def read_tsv(path: str): 19 | id2info = {} 20 | with open(path, 'r') as f: 21 | for line in tqdm(f, desc=f"read {path}"): 22 | idx, info = line.strip().split('\t') 23 | id2info[idx] = info 24 | return id2info 25 | 26 | def read_json(path: str, 27 | id_key: str = 'id', 28 | content_key: str = 'content', 29 | meta_keys: List[str] = None, 30 | sep: str = ' '): 31 | id2info = {} 32 | with open(path, 'r') as f: 33 | for line in tqdm(f, desc=f"read {path}"): 34 | data = json.loads(line.strip()) 35 | idx = data[id_key] 36 | info = data[content_key] 37 | if meta_keys: 38 | info = [info] 39 | for meta_key in meta_keys: 40 | info.append(data[meta_key]) 41 | info = sep.join(info) 42 | id2info[idx] = info 43 | return id2info 44 | 45 | def read_trec(path: str): 46 | qid2psg = defaultdict(list) 47 | with open(path, 'r') as f: 48 | for line in tqdm(f, desc=f"read {path}"): 49 | try: 50 | data = line.strip().split('\t') 51 | qid = data[0] 52 | psg = data[2] 53 | except IndexError: 54 | data = line.strip().split(' ') 55 | qid = data[0] 56 | psg = data[2] 57 | qid2psg[qid].append(psg) 58 | 59 | 60 | return qid2psg 61 | 62 | def read_qrel(path: str): 63 | qid_pid2qrel = defaultdict(int) 64 | with open(path, 'r') as f: 65 | for line in tqdm(f, desc=f"read {path}"): 66 | qid, _, pid, rel = line.strip().split('\t') 67 | qid_pid2qrel[f'{qid}_{pid}'] = int(rel) 68 | return qid_pid2qrel -------------------------------------------------------------------------------- /tevatron/utils/format/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/utils/format/__init__.py -------------------------------------------------------------------------------- /tevatron/utils/format/convert_result_to_trec.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | parser = ArgumentParser() 4 | parser.add_argument('--input', type=str, required=True) 5 | parser.add_argument('--output', type=str, required=True) 6 | args = parser.parse_args() 7 | 8 | with open(args.input) as f_in, open(args.output, 'w') as f_out: 9 | cur_qid = None 10 | rank = 0 11 | for line in f_in: 12 | qid, docid, score = line.split() 13 | if cur_qid != qid: 14 | cur_qid = qid 15 | rank = 0 16 | rank += 1 17 | f_out.write(f'{qid} Q0 {docid} {rank} {score} dense\n') 18 | -------------------------------------------------------------------------------- /tevatron/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def average_precision(gt, pred): 5 | """ 6 | Computes the average precision. 7 | 8 | This function computes the average precision between two lists of 9 | items.
10 | 11 | Parameters 12 | ---------- 13 | gt: set 14 | A set of ground-truth elements (order doesn't matter) 15 | pred: list 16 | A list of predicted elements (order does matter) 17 | 18 | Returns 19 | ------- 20 | score: double 21 | The average precision over the input lists 22 | """ 23 | 24 | if not gt: 25 | return 0.0 26 | 27 | score = 0.0 28 | num_hits = 0.0 29 | for i, p in enumerate(pred): 30 | if p in gt and p not in pred[:i]: 31 | num_hits += 1.0 32 | score += num_hits / (i + 1.0) 33 | 34 | return score / max(1.0, len(gt)) 35 | 36 | 37 | def NDCG(gt, pred, use_graded_scores=False): 38 | score = 0.0 39 | for rank, item in enumerate(pred): 40 | if item in gt: 41 | if use_graded_scores: 42 | grade = 1.0 / (gt.index(item) + 1) 43 | else: 44 | grade = 1.0 45 | score += grade / np.log2(rank + 2) 46 | 47 | norm = 0.0 48 | for rank in range(len(gt)): 49 | if use_graded_scores: 50 | grade = 1.0 / (rank + 1) 51 | else: 52 | grade = 1.0 53 | norm += grade / np.log2(rank + 2) 54 | return score / max(0.3, norm) 55 | 56 | 57 | def metrics(gt, pred, metrics_map): 58 | ''' 59 | Returns a numpy array containing metrics specified by metrics_map. 60 | gt: ground-truth items 61 | pred: predicted items 62 | ''' 63 | out = np.zeros((len(metrics_map),), np.float32) 64 | 65 | if ('MAP' in metrics_map): 66 | avg_precision = average_precision(gt=gt, pred=pred) 67 | out[metrics_map.index('MAP')] = avg_precision 68 | 69 | if ('RPrec' in metrics_map): 70 | intersec = len(gt & set(pred[:len(gt)])) 71 | out[metrics_map.index('RPrec')] = intersec / max(1., float(len(gt))) 72 | 73 | if 'MRR' in metrics_map: 74 | score = 0.0 75 | for rank, item in enumerate(pred): 76 | if item in gt: 77 | score = 1.0 / (rank + 1.0) 78 | break 79 | out[metrics_map.index('MRR')] = score 80 | 81 | if 'MRR@10' in metrics_map: 82 | score = 0.0 83 | for rank, item in enumerate(pred[:10]): 84 | if item in gt: 85 | score = 1.0 / (rank + 1.0) 86 | break 87 | out[metrics_map.index('MRR@10')] = score 88 | 89 | if ('NDCG' in metrics_map): 90 | out[metrics_map.index('NDCG')] = NDCG(gt, pred) 91 | 92 | return out 93 | 94 | -------------------------------------------------------------------------------- /tevatron/utils/tokenize_corpus.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) 3 | import argparse 4 | from tqdm import tqdm 5 | import os 6 | import json 7 | from multiprocessing import Pool 8 | from transformers import AutoTokenizer 9 | from .data_reader import create_dir 10 | 11 | DATA_ITEM = {'msmarco-passage': {'id':'id', 'contents': ['contents']}, 12 | 'beir': {'id':'_id', 'contents': ['title', 'text']}} 13 | 14 | def tokenize_and_json_save(data_item, data_type, tokenizer, lines, jsonl_path, tokenize, encode): 15 | output = open(jsonl_path, 'w') 16 | for line in tqdm(lines, total=len(lines), desc=f"write {jsonl_path}"): 17 | if data_type == 'tsv': 18 | docid, contents = line.strip().split('\t') 19 | elif (data_type =='json') or (data_type =='jsonl'): 20 | line = json.loads(line.strip()) 21 | docid = line[data_item['id']] 22 | 23 | contents = [] 24 | for content in data_item['contents']: 25 | contents.append(line[content]) 26 | contents = ' '.join(contents) 27 | if tokenize: 28 | if encode: 29 | contents = tokenizer.encode(contents, add_special_tokens=False) 30 | # Fit the format of tevatron 31 | output_dict = {'text_id': docid, 'text': contents} 32 | else: 33 | contents = '
'.join(tokenizer.tokenize(contents)) 34 | output_dict = {'id': docid, 'contents': contents} 35 | else: 36 | output_dict = {'id': docid, 'contents': contents} 37 | output.write(json.dumps(output_dict) + '\n') 38 | output.close() 39 | 40 | def main(): 41 | parser = argparse.ArgumentParser( 42 | description='Transform corpus into wordpiece corpus') 43 | parser.add_argument('--corpus_path', required=True, help='TSV or json corpus file with format {docid}\t{document}.') 44 | parser.add_argument('--output_dir', required=True) 45 | parser.add_argument('--corpus_domain', required=False, default='msmarco-passage') 46 | parser.add_argument('--tokenizer', required=False, default='bert-base-uncased', help='tokenizer model name') 47 | parser.add_argument('--tokenize', action='store_true') 48 | parser.add_argument('--encode', action='store_true') 49 | parser.add_argument('--num_workers', type=int, required=False, default=None) 50 | parser.add_argument('--max_line_per_file', type=int, required=False, default=300000, help='use the default for max length 150; use 300000 for max length 512') 51 | args = parser.parse_args() 52 | 53 | if args.encode: 54 | if not args.tokenize: 55 | raise ValueError('the --encode option also requires --tokenize to be set') 56 | 57 | create_dir(args.output_dir) 58 | 59 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 60 | 61 | data_type = (args.corpus_path).split('.')[-1] 62 | if (data_type != 'tsv') and (data_type != 'json') and (data_type != 'jsonl'): 63 | raise ValueError('--corpus_path should be in tsv, json or jsonl format') 64 | 65 | with open(args.corpus_path, 'r') as f: 66 | print("read {}".format(args.corpus_path)) 67 | lines = f.readlines() 68 | total_num_docs = len(lines) 69 | print("total {} lines".format(total_num_docs)) 70 | 71 | ## for debug 72 | # tokenize_and_json_save(DATA_ITEM[args.corpus_domain], data_type, tokenizer, lines, os.path.join(jsonl_dir, 'split.json'), args.tokenize ) 73 | if args.num_workers is None: 74 | num_docs_per_worker = args.max_line_per_file 75 | args.num_workers = total_num_docs // num_docs_per_worker 76 | if (total_num_docs%num_docs_per_worker ) != 0: 77 | args.num_workers+=1 78 | else: 79 | num_docs_per_worker = total_num_docs//args.num_workers 80 | if (total_num_docs%args.num_workers) != 0: 81 | args.num_workers+=1 82 | 83 | logging.info(f'Run with {args.num_workers} workers on {total_num_docs} documents') 84 | pool = Pool(args.num_workers) 85 | for i in range(args.num_workers): 86 | f_out = os.path.join(args.output_dir, 'split%02d.json'%i) 87 | start = i*num_docs_per_worker 88 | if i==(args.num_workers-1): 89 | pool.apply_async(tokenize_and_json_save ,(DATA_ITEM[args.corpus_domain], data_type, tokenizer,\ 90 | lines[start:], f_out, args.tokenize, args.encode)) 91 | else: 92 | pool.apply_async(tokenize_and_json_save ,(DATA_ITEM[args.corpus_domain], data_type, tokenizer,\ 93 | lines[start:(start+num_docs_per_worker)], f_out, args.tokenize, args.encode)) 94 | 95 | pool.close() 96 | pool.join() 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /tevatron/utils/tokenize_query.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) 3 | import argparse 4 | from tqdm import tqdm 5 | import os 6 | import json 7 | from collections import defaultdict 8 | from transformers import AutoTokenizer 9 | import sys 10 | from
.data_reader import read_tsv, create_dir 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser( 14 | description='Tokenize query') 15 | parser.add_argument('--qry_file', required=True, help='format {qid}\t{qry}') 16 | parser.add_argument('--output_dir', required=True) 17 | parser.add_argument('--tokenizer', required=False, default='bert-base-uncased', help='tokenizer model name') 18 | args = parser.parse_args() 19 | 20 | create_dir(args.output_dir) 21 | 22 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 23 | qid2qry = read_tsv(args.qry_file) 24 | 25 | query_name = args.qry_file.split('/')[-1].replace('.tsv','.json') 26 | output_path = os.path.join(args.output_dir, query_name) 27 | output = open(output_path, 'w') 28 | # iterate over the queries parsed above instead of reading the file a second time 29 | for qid, qry in tqdm(qid2qry.items(), desc=f"tokenize query: {output_path}"): 30 | qry = tokenizer.encode(qry, add_special_tokens=False) 31 | output_dict = {"text_id": qid, "text": qry} 32 | output.write(json.dumps(output_dict) + '\n') 33 | output.close() 34 | 35 | if __name__ == "__main__": 36 | main() --------------------------------------------------------------------------------
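As a usage reference for the retrieval components above, here is a minimal sketch of driving `BaseFaissIPRetriever` directly. The synthetic float32 vectors stand in for real encoder output (which would normally come from `tevatron/driver/encode.py`); note that the constructor only uses its argument to size the index, so passage vectors must be added explicitly, just as `faiss_retriever/__main__.py` does shard by shard.

```
import numpy as np
from tevatron.faiss_retriever import BaseFaissIPRetriever

# synthetic stand-ins for encoded passages and queries (faiss expects float32)
p_reps = np.random.rand(1000, 768).astype(np.float32)
q_reps = np.random.rand(4, 768).astype(np.float32)

retriever = BaseFaissIPRetriever(p_reps)  # only reads the dimensionality (768)
retriever.add(p_reps)                     # vectors must be added explicitly

# top-10 inner-product neighbours, searching two queries per batch
scores, indices = retriever.batch_search(q_reps, 10, 2)
print(scores.shape, indices.shape)  # (4, 10) (4, 10)
```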