├── README.md
├── __init__.py
├── densify
│   ├── __init__.py
│   ├── densify_corpus.py
│   ├── densify_query.py
│   └── output_vector.py
├── docs
│   ├── aggretriever
│   │   ├── beir-eval.md
│   │   └── msmarco-passage-train-eval.md
│   └── dhr
│       ├── beir-eval.md
│       ├── densify_exp.md
│       └── msmarco-passage-train-eval.md
├── fig
│   ├── aggretriever.png
│   ├── aggretriever_teaser.png
│   ├── densification.png
│   └── single_model_fusion.png
├── retrieval
│   ├── __init__.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   └── custom_metrics.py
│   ├── gip_retrieval.py
│   ├── index.py
│   ├── merge.result.py
│   ├── quantize_index.py
│   ├── rcap_eval.py
│   └── util.py
└── tevatron
    ├── Aggretriever
    │   ├── __init__.py
    │   ├── modeling.py
    │   └── utils.py
    ├── ColBERT
    │   └── modeling.py
    ├── DHR
    │   ├── __init__.py
    │   ├── modeling.py
    │   └── utils.py
    ├── Dense
    │   ├── __init__.py
    │   └── modeling.py
    ├── __init__.py
    ├── arguments.py
    ├── data.py
    ├── datasets
    │   ├── __init__.py
    │   ├── beir
    │   │   ├── __init__.py
    │   │   ├── encode_and_retrieval.py
    │   │   ├── preprocess.py
    │   │   └── sentence_bert.py
    │   ├── dataset.py
    │   └── preprocessor.py
    ├── driver
    │   ├── __init__.py
    │   ├── encode.py
    │   ├── eval.py
    │   ├── jax_encode.py
    │   ├── jax_train.py
    │   └── train.py
    ├── faiss_retriever
    │   ├── __init__.py
    │   ├── __main__.py
    │   ├── reducer.py
    │   └── retriever.py
    ├── loss.py
    ├── preprocessor
    │   ├── __init__.py
    │   └── preprocessor_tsv.py
    ├── tevax
    │   ├── __init__.py
    │   ├── loss.py
    │   └── training.py
    ├── trainer.py
    └── utils
        ├── __init__.py
        ├── convert_from_dpr.py
        ├── data_reader.py
        ├── format
        │   ├── __init__.py
        │   └── convert_result_to_trec.py
        ├── metrics.py
        ├── tokenize_corpus.py
        └── tokenize_query.py
/README.md:
--------------------------------------------------------------------------------
1 | # Dense Hybrid Retrieval
2 | In this repo, we introduce two approaches to training transformers to capture semantic and lexical text representations for robust dense passage retrieval.
3 | 1. *[Aggretriever: A Simple Approach to Aggregate Textual Representation for Robust Dense Passage Retrieval](https://arxiv.org/abs/2208.00511)* Sheng-Chieh Lin, Minghan Li and Jimmy Lin. (TACL 2023)
4 | 2. *[A Dense Representation Framework for Lexical and Semantic Matching](https://dl.acm.org/doi/10.1145/3582426)* Sheng-Chieh Lin and Jimmy Lin. (TOIS 2023)
5 |
6 | This repo contains three parts: (1) densify, (2) training (tevatron), and (3) retrieval.
7 | Our training code is mainly from [Tevatron](https://github.com/texttron/tevatron) with minor revisions.
8 |
9 | ## Requirements
10 | ```
11 | pip install "torch>=1.7.0"
12 | pip install transformers==4.15.0
13 | pip install pyserini
14 | pip install beir
15 | ```
16 |
17 | ## Huggingface Checkpoints
18 | Model | Initialization | MARCO Dev (MRR@10) | BEIR (nDCG@10, 13 public datasets) | Huggingface Path | Document
19 | |---|---|---|---|---|---
20 | DeLADE+[CLS] plus | [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) | 37.1 | 49.8 | [jacklin/DeLADE-CLS-P](https://huggingface.co/jacklin/DeLADE-CLS-P) | [Read Me](https://github.com/castorini/dhr/tree/main/docs/dhr)
21 | DeLADE+[CLS] | [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) | 35.7 | 48.5 | [jacklin/DeLADE-CLS](https://huggingface.co/jacklin/DeLADE-CLS) | [Read Me](https://github.com/castorini/dhr/tree/main/docs/dhr)
22 | Aggretriever | [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) | 34.1 | 46.0 | [jacklin/DistilBERT-AGG](https://huggingface.co/jacklin/DistilBERT-AGG) | [Read Me](https://github.com/castorini/dhr/tree/main/docs/aggretriever)
23 |
24 | # Aggretriever
25 |
26 |
27 |
28 | In this paper, we introduce a simple approach to aggregating token-level information into a single-vector dense representation, sketched below. We provide instructions for model training and evaluation on the MS MARCO passage ranking dataset in this [document](https://github.com/castorini/dhr/blob/main/docs/aggretriever/msmarco-passage-train-eval.md), and instructions for evaluation on the BEIR datasets in this [document](https://github.com/castorini/dhr/blob/main/docs/aggretriever/beir-eval.md).
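To make the idea concrete, here is a rough, hypothetical sketch of the aggregation step: fold the vocabulary-sized term weights into a fixed number of slots by max pooling and concatenate the result with the projected [CLS] vector. The function name and the plain fold-and-max scheme are ours for illustration only; the actual learned model lives in `tevatron/Aggretriever/modeling.py`.

```python
import torch

def aggregate_sketch(term_weights: torch.Tensor, cls_vec: torch.Tensor, agg_dim: int = 640) -> torch.Tensor:
    # hypothetical sketch: fold a vocabulary-sized term-weight vector into
    # agg_dim slots by max pooling, then concatenate the projected [CLS] vector
    vocab_size = term_weights.shape[-1]
    pad = (-vocab_size) % agg_dim                 # pad so the length is a multiple of agg_dim
    padded = torch.nn.functional.pad(term_weights, (0, pad))
    folded = padded.view(-1, agg_dim)             # (vocab_size / agg_dim, agg_dim)
    agg = folded.max(dim=0).values                # keep the strongest weight per slot
    return torch.cat([cls_vec, agg], dim=-1)      # e.g., 128 + 640 = 768 dimensions
```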
29 |
30 | # A Dense Representation Framework for Lexical and Semantic Matching
31 | In this paper, we introduce a unified representation framework for lexical and semantic matching. We first introduce how to use our framework to conduct retrieval with high-dimensional (lexical) representations, and then how to combine them with single-vector dense (semantic) representations for hybrid search.
32 | ## Dense Lexical Retrieval
33 |
34 |
35 |
36 | We can densify any existing lexical matching model and conduct lexical matching on the GPU, as sketched below. In the [document](https://github.com/jacklin64/DHR/blob/main/docs/densify_exp.md), we demonstrate how to conduct BM25 and uniCOIL end-to-end retrieval under our framework. A detailed description can be found in our [paper](https://arxiv.org/pdf/2112.04666.pdf).
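The core densification step (implemented in `densify/densify_corpus.py`) maps each vocabulary id to a slot by modulo, remembers which vocabulary "page" the slot holds, and keeps the larger weight when two tokens collide. A minimal sketch, omitting the model-specific skipping of the first few hundred token ids:

```python
import numpy as np

def densify_sketch(vector: dict, token2id: dict, dim: int = 768):
    # fold a sparse {token: weight} vector into paired value/index arrays
    value = np.zeros(dim, dtype=np.float16)
    index = np.zeros(dim, dtype=np.int16)
    for token, weight in vector.items():
        tid = token2id[token]
        slot = tid % dim          # slot of the dense vector this token falls into
        page = tid // dim         # which vocabulary "page" the slot holds
        if weight > value[slot]:  # on a collision, keep the larger weight
            value[slot] = weight
            index[slot] = page
    return value, index
```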
37 |
38 | ## Dense Hybrid Retrieval
39 | With the densified lexical representations, we can easily conduct lexical and semantic hybrid retrieval using independent neural models, e.g., by fusing their scores as sketched below. A document for hybrid retrieval is coming soon.
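For instance, given a run from a densified lexical model and a run from a dense semantic model, the two can be fused at the score level. The helper below is a hypothetical illustration (not part of this repo):

```python
def fuse_runs(lexical: dict, semantic: dict, alpha: float = 0.5, topk: int = 1000):
    # score-level fusion of two runs (docid -> score); a document missing
    # from one run simply contributes 0 from that run
    docids = set(lexical) | set(semantic)
    fused = {d: alpha * lexical.get(d, 0.0) + (1 - alpha) * semantic.get(d, 0.0)
             for d in docids}
    return sorted(fused.items(), key=lambda kv: kv[1], reverse=True)[:topk]
```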
40 |
41 | ## Dense Hybrid Representation Model
42 |
43 |
44 |
45 | In our paper, we propose a single-model fusion approach that jointly trains the lexical and semantic components of a transformer; at inference time, we combine the densified lexical representations and the dense semantic representations into dense hybrid representations, as sketched below. Instead of training the model yourself, you can also download our trained [DeLADE-CLS-P](https://huggingface.co/jacklin/DeLADE-CLS-P), [DeLADE-CLS](https://huggingface.co/jacklin/DeLADE-CLS) and [DeLADE](https://huggingface.co/jacklin/DeLADE) models and directly perform inference on the MSMARCO Passage dataset (see [document](https://github.com/jacklin64/DHR/blob/main/docs/dhr/msmarco-passage-train-eval.md)) or the BEIR datasets (see [document](https://github.com/jacklin64/DHR/blob/main/docs/dhr/beir-eval.md)).
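At retrieval time, the densified lexical part is scored with a gated inner product (GIP), where a slot counts only if the query and document hold the same vocabulary page there, while the [CLS] part is scored with a plain inner product. A minimal single-vector sketch of what `GIP_retrieval` in `retrieval/gip_retrieval.py` computes in batched form (`lamda` mirrors the `--lamda` option):

```python
import torch

def gip_score(q_val, q_idx, d_val, d_idx):
    # gated inner product: a slot contributes only when query and document
    # store the same vocabulary "page" in that slot
    return ((q_idx == d_idx) * q_val * d_val).sum()

def dhr_score(q_val, q_idx, d_val, d_idx, q_cls, d_cls, lamda=1.0):
    # dense hybrid score: GIP over the densified lexical part plus a weighted
    # inner product over the [CLS] part
    return gip_score(q_val, q_idx, d_val, d_idx) + lamda * (q_cls @ d_cls)
```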
46 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/__init__.py
--------------------------------------------------------------------------------
/densify/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/densify/__init__.py
--------------------------------------------------------------------------------
/densify/densify_corpus.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | import glob
4 | import numpy as np
5 | import gzip
6 | import json
7 | import argparse
8 | from pyserini.index import IndexReader
9 | from multiprocessing import Pool, Manager, Queue
10 | from transformers import AutoModelForMaskedLM, AutoTokenizer
11 | import multiprocessing
12 | import os
13 | from tqdm import tqdm
14 | logger = logging.getLogger(__name__)
15 | logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)
16 |
17 | omission_num = \
18 | {'bm25': 472,
19 | 'deepimpact': 502,
20 | 'unicoil': 570,
21 | 'splade': 570}
22 |
23 | whole_word_matching = \
24 | {'bm25': True,
25 | 'deepimpact': True,
26 | 'unicoil': False,
27 | 'splade': False}
28 |
29 | def densify(data, dim, whole_word_matching, token2id, args):  # fold one sparse {token: weight} vector into dense (value, index) arrays
30 | value = np.zeros((dim), dtype=np.float16)
31 | if whole_word_matching:
32 | index = np.zeros((dim), dtype=np.int16)
33 | else:
34 | index = np.zeros((dim), dtype=np.int8)
35 | collision_num = 0
36 | for i, (token, weight) in enumerate(data['vector'].items()):
37 | token_id = token2id[token]
38 | if token_id < omission_num[args.model]:
39 | continue
40 | else:
41 |             slice_num = (token_id - omission_num[args.model])%dim  # slot in the densified vector
42 |             index_num = (token_id - omission_num[args.model])//dim  # vocabulary "page" stored at that slot
43 | if value[slice_num]==0:
44 | value[slice_num] = weight
45 | index[slice_num] = index_num
46 | else:
47 |                 # collision: another token already maps to this slot; keep the larger weight
48 | collision_num += 1
49 | if value[slice_num] < weight:
50 | value[slice_num] = weight
51 | index[slice_num] = index_num
52 | return value, index, collision_num
53 |
54 |
55 | def vectorize_and_densify(files, file_type, dim, whole_word_matching, token2id, output_path, args):
56 | data_num = 0
57 | logger.info('count line number')
58 | for file in files:
59 | if file_type == 'jsonl.gz':
60 | f = gzip.open(file, "rb")
61 | else:
62 | f = open(file, 'r')
63 | for line in f:
64 | data_num+=1
65 | f.close()
66 |
67 | logger.info('initialize numpy array with {}X{}'.format(data_num, dim))
68 | value_encoded = np.zeros((data_num, dim), dtype=np.float16)
69 | if whole_word_matching:
70 | index_encoded = np.zeros((data_num, dim), dtype=np.int16)
71 | else:
72 | index_encoded = np.zeros((data_num, dim), dtype=np.int8)
73 | docids =[]
74 | total_collision_num = 0
75 | counter = 0
76 | for file in files:
77 | if file_type == 'jsonl.gz':
78 | f = gzip.open(file, "rb")
79 | else:
80 | f = open(file, 'r')
81 | for i, line in tqdm(enumerate(f), desc=f"densify {file}"):
82 | data = json.loads(line)
83 | docids.append(data['id'])
84 | value, index, collision_num = densify(data, dim, whole_word_matching, token2id, args)
85 | total_collision_num += collision_num
86 | value_encoded[counter] = value
87 | index_encoded[counter] = index
88 | counter += 1
89 | f.close()
90 |
91 | print('Total {} collisions with {} passages'.format(total_collision_num, data_num))
92 | with open(output_path, 'wb') as f_out:
93 | pickle.dump([value_encoded, index_encoded, docids], f_out, protocol=4)
94 |
95 |
96 | def get_files(directory):
97 | files = glob.glob(os.path.join(directory, '*.json'))
98 | if len(files) == 0:
99 | files = glob.glob(os.path.join(directory, '*.jsonl.gz'))
100 | file_type = 'jsonl.gz'
101 | else:
102 | file_type = 'json'
103 | if len(files) == 0:
104 | raise ValueError('There is no json or jsonl.gz files in {}'.format(directory))
105 | return files, file_type
106 |
107 | def main():
108 | parser = argparse.ArgumentParser(description='Densify corpus')
109 | parser.add_argument('--model', required=True, help='bm25, deepimpact, unicoil or splade')
110 | parser.add_argument('--tokenizer', required=False, default="bert-base-uncased", help='anserini index path or transformer tokenizer')
111 | parser.add_argument('--vector_dir', required=True, help='directory with json files')
112 | parser.add_argument('--output_dir', required=True, help='output pickle directory')
113 | parser.add_argument('--output_dims', type=int, required=False, default=768)
114 | parser.add_argument('--num_workers', type=int, required=False, default=None)
115 | parser.add_argument('--prefix', required=True, help='index name prefix')
116 | args = parser.parse_args()
117 |     args.model = args.model.lower()  # accept --model in any letter case (e.g., BM25 or bm25)
118 | token2id = {}
119 | if (args.model == 'bm25') or (args.model == 'deepimpact'):
120 | tokenizer = IndexReader(args.tokenizer)
121 | for idx, token in tqdm(enumerate(tokenizer.terms()), desc=f"read index terms"):
122 | token2id[token.term] = idx
123 | elif (args.model == 'unicoil') or (args.model == 'splade'):
124 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
125 | token2id = tokenizer.vocab
126 | else:
127 |         raise ValueError('We cannot handle your input model: {}'.format(args.model))
128 |
129 |
130 | if not os.path.exists(args.output_dir):
131 | os.mkdir(args.output_dir)
132 |
133 | densified_vector_dir = os.path.join(args.output_dir, f'encoding')
134 | if not os.path.exists(densified_vector_dir):
135 | os.mkdir(densified_vector_dir)
136 |
137 | files, file_type = get_files(args.vector_dir)
138 |
139 | total_num_files = len(files)
140 | if args.num_workers is None:
141 | args.num_workers = total_num_files
142 | num_files_per_worker = 1
143 | else:
144 | num_files_per_worker = total_num_files//args.num_workers
145 | if (total_num_files%args.num_workers) != 0:
146 | args.num_workers+=1
147 |
148 | pool = Pool(args.num_workers)
149 | for i in range(args.num_workers):
150 | start = i*num_files_per_worker
151 | output_path = os.path.join(densified_vector_dir, f"{args.prefix}.split{i}.pt")
152 |
153 | if i==(args.num_workers-1):
154 | pool.apply_async(vectorize_and_densify ,(files[start:], file_type, args.output_dims, whole_word_matching[args.model], token2id, output_path, args))
155 | else:
156 | pool.apply_async(vectorize_and_densify ,(files[start:(start+num_files_per_worker)], file_type, args.output_dims, whole_word_matching[args.model], token2id, output_path, args))
157 |
158 | # for debug
159 | # vectorize_and_densify(files[start:(start+num_files_per_worker)], file_type, args.output_dims, whole_word_matching[args.model], token2id, output_path, args)
160 |
161 | pool.close()
162 | pool.join()
163 |
164 |
165 |
166 | if __name__ == '__main__':
167 | main()
168 |
--------------------------------------------------------------------------------
/densify/densify_query.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | import glob
4 | import numpy as np
5 | import json
6 | import argparse
7 | from collections import defaultdict
8 | from pyserini.index import IndexReader
9 | from pyserini.analysis import Analyzer, get_lucene_analyzer
10 | from multiprocessing import Pool, Manager, Queue
11 | import multiprocessing
12 | import os
13 | from tqdm import tqdm
14 | from transformers import AutoModelForMaskedLM, AutoTokenizer
15 | from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder
16 | from .densify_corpus import densify, whole_word_matching
17 |
18 |
19 | def main():
20 | parser = argparse.ArgumentParser(
21 |         description='Densify sparse query vectors')
22 | parser.add_argument('--model', required=True, help='bm25, deepimpact, unicoil or splade')
23 | parser.add_argument('--tokenizer', required=False, default="bert-base-uncased", help='anserini index path or transformer tokenizer')
24 | parser.add_argument('--query_path', required=True, help='query tsv file')
25 | parser.add_argument('--output_dims', type=int, required=False, default=768)
26 | parser.add_argument('--output_dir', required=True, help='output pickle directory')
27 | parser.add_argument('--prefix', required=True, help='index name prefix')
28 | args = parser.parse_args()
29 |
30 |
31 | if not os.path.exists(args.output_dir):
32 | os.mkdir(args.output_dir)
33 |
34 | args.output_dir = os.path.join(args.output_dir, f'encoding')
35 | if not os.path.exists(args.output_dir):
36 | os.mkdir(args.output_dir)
37 |
38 | densified_vector_dir = os.path.join(args.output_dir, f"queries")
39 | if not os.path.exists(densified_vector_dir):
40 | os.mkdir(densified_vector_dir)
41 |
42 |
43 | args.model = args.model.lower()
44 | token2id = {}
45 | if (args.model == 'bm25') or (args.model == 'deepimpact'):
46 | analyzer = Analyzer(get_lucene_analyzer())
47 | tokenizer = IndexReader(args.tokenizer)
48 | for idx, token in tqdm(enumerate(tokenizer.terms()), desc=f"read index terms"):
49 | token2id[token.term] = idx
50 | if args.model == 'bm25':
51 | analyze = True
52 | else:
53 | analyze = False
54 | query_encoder = None
55 | elif (args.model == 'unicoil') or (args.model == 'splade'):
56 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
57 | token2id = tokenizer.vocab
58 | if args.model == 'unicoil':
59 | query_encoder = UniCoilQueryEncoder('castorini/unicoil-msmarco-passage')
60 | else:
61 |             raise ValueError('query-side densification for {} is not supported'.format(args.model))
62 |
63 |
64 | f = open(args.query_path, 'r')
65 | print('count line number')
66 | data_num = 0
67 | for line in f:
68 | data_num+=1
69 | f.close()
70 |
71 | print('initialize numpy array with {}X{}'.format(data_num, args.output_dims))
72 | value_encoded = np.zeros((data_num, args.output_dims), dtype=np.float16)
73 | index_encoded = np.zeros((data_num, args.output_dims), dtype=np.int16)
74 |
75 | qids = []
76 | data = {}
77 | total_collision_num = 0
78 | with open(args.query_path, 'r') as f:
79 | for i, line in enumerate(f):
80 | qid, query = line.strip().split('\t')
81 | if query_encoder is None:
82 | if analyze:
83 | analyzed_query_terms = analyzer.analyze(query)
84 | else:
85 | analyzed_query_terms = query.split(' ')
86 | # use tf as term weight
87 | vector = defaultdict(int)
88 | for analyzed_query_term in analyzed_query_terms:
89 | vector[analyzed_query_term] += 1
90 | else:
91 | vector = query_encoder.encode(query)
92 |
93 | data['vector'] = vector
94 |
95 | qids.append(qid)
96 | value, index, collision_num = densify(data, args.output_dims, whole_word_matching[args.model] , token2id, args)
97 | total_collision_num += collision_num
98 | value_encoded[i] = value
99 | index_encoded[i] = index
100 |
101 |
102 |
103 |     print('Total {} collisions with {} queries'.format(total_collision_num, data_num))
104 | file_name = args.prefix + '.' + (args.query_path).split('/')[-1].replace('tsv','pt')
105 | output_path = os.path.join(densified_vector_dir, file_name)
106 | with open(output_path, 'wb') as f_out:
107 | pickle.dump([value_encoded, index_encoded, qids], f_out, protocol=4)
108 |
109 |
110 |
111 | if __name__ == '__main__':
112 | main()
113 |
--------------------------------------------------------------------------------
/densify/output_vector.py:
--------------------------------------------------------------------------------
1 | from pyserini.search import SimpleSearcher
2 | from pyserini.index import IndexReader
3 | import json
4 | from tqdm import tqdm
5 | import argparse
6 | import itertools
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser(
9 | description='Extract text contents from anserini index')
10 | parser.add_argument('--index_path', required=True, help='anserini index path')
11 | parser.add_argument('--output_path', required=True, help='Output file in the anserini jsonl format.')
12 | parser.add_argument('--tf_only' , action='store_true')
13 | args = parser.parse_args()
14 |
15 | index_reader = IndexReader(args.index_path)
16 | searcher = SimpleSearcher(args.index_path)
17 | total_num_docs = searcher.num_docs
18 |
19 | # term_dict = {}
20 | # for idx, term in tqdm(enumerate(index_reader.terms()), desc=f"read index terms"):
21 | # term_dict[term.term] = idx
22 |
23 | fout = open(args.output_path, 'w')
24 | for i in tqdm(range(total_num_docs), total=total_num_docs, desc=f"compute bm25 vector"):
25 | docid = searcher.doc(i).docid()
26 | tf = index_reader.get_document_vector(docid)
27 | vector = {}
28 | for term in tf:
29 | vector[term] = index_reader.compute_bm25_term_weight(docid, term, analyzer=None)
30 | output_dict = {'id': docid, 'vector': vector}
31 | fout.write(json.dumps(output_dict) + '\n')
32 | fout.close()
--------------------------------------------------------------------------------
/docs/aggretriever/beir-eval.md:
--------------------------------------------------------------------------------
1 | # BEIR Evaluation
2 | ## Evaluation with Sentence Transformer
3 | We use the [BEIR](https://github.com/beir-cellar/beir) API to conduct brute-force search.
4 | ```
5 | git clone https://huggingface.co/jacklin/DistilBERT-AGG
6 | export MODEL_DIR=DistilBERT-AGG
7 | export CUDA_VISIBLE_DEVICES=0
8 | export MODEL=AGG
9 | export AGGDIM=640
10 | export CORPUS=scifact
11 | python -m tevatron.datasets.beir.encode_and_retrieval --dataset ${CORPUS} --model_name_or_path ${MODEL_DIR} --model ${MODEL} --agg_dim ${AGGDIM}
12 | ```
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/docs/aggretriever/msmarco-passage-train-eval.md:
--------------------------------------------------------------------------------
1 | # Training and Inference on MSMARCO Passage ranking
2 | In the following, we describe how to train, encode and retrieve with Aggretriever on MS MARCO passage-v1.
3 | 1. [MS MARCO Passage-v1 Data Preparation](#msmarco_data_prep)
4 | 1. [Training](#training)
5 | 1. [Generate Passage and Query Embeddings](#generate_embeddings)
6 | 1. [End-To-End Retrieval](#retrieval)
7 | 1. [Evaluation](#evaluation)
8 |
9 | ## MS MARCO Passage-v1 Data Preparation
10 | We first preprocess the corpus, development queries and official training data into the json format. Each passage in the corpus is a line with the format: `{"text_id": passage_id, "text": [vocab_ids]}`. Similarly, each query in the development set is a line with the format: `{"text_id": query_id, "text": [vocab_ids]}`. As for training data, we rearrange the official training data in the format: `{"query": [vocab_ids], "positive_pids": [positive_passage_id0, positive_passage_id1, ...], "negative_pids": [negative_passage_id0, negative_passage_id1, ...]}`. Note that passage and query ids are strings. You can also download our preprocessed data from the Huggingface hub: [official_train](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_corpus), [queries](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_queries) and [corpus](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_corpus). A minimal formatting sketch is shown below.
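As a reference for the line format, here is a minimal sketch that writes one corpus line; the tokenizer choice and helper name are ours for illustration, and the repo's own preprocessing lives in `tevatron/utils/tokenize_corpus.py` and `tevatron/utils/tokenize_query.py`:

```python
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def to_corpus_line(text_id: str, text: str) -> str:
    # one line of the preprocessed corpus: {"text_id": ..., "text": [vocab_ids]}
    vocab_ids = tokenizer.encode(text, add_special_tokens=False)
    return json.dumps({"text_id": text_id, "text": vocab_ids})

print(to_corpus_line("0", "The Manhattan Project produced the first nuclear weapons."))
```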
11 |
12 | ## Training
13 | The script below is the Aggretriever training setup from our paper. Here we use distilbert-base-uncased as an example; you can switch to any backbone via `--model_name_or_path`.
14 | ```shell=bash
15 | export CUDA_VISIBLE_DEVICES=0
16 | export MODEL=AGG
17 | export CLSDIM=128
18 | export AGGDIM=640
19 | export MODEL_DIR=${MODEL}_CLS${CLSDIM}XAGG${AGGDIM}
20 | export DATA_DIR=need_your_assignment
21 |
22 | python -m tevatron.driver.train \
23 | --output_dir ${MODEL_DIR} \
24 | --train_dir ${DATA_DIR}/official_train \
25 | --corpus_dir ${DATA_DIR}/corpus \
26 | --model_name_or_path distilbert-base-uncased \
27 | --do_train \
28 | --save_steps 20000 \
29 | --fp16 \
30 | --per_device_train_batch_size 8 \
31 | --learning_rate 5e-6 \
32 | --q_max_len 32 \
33 | --p_max_len 128 \
34 | --num_train_epochs 3 \
35 | --add_pooler \
36 | --model ${MODEL} \
37 | --projection_out_dim ${CLSDIM} \
38 | 	--agg_dim ${AGGDIM} \
39 | 	--train_n_passages 8 \
40 | 	--dataloader_num_workers 8
41 | ```
42 |
43 | ## Generate Passage and Query Embeddings
44 | ```
45 | export CUDA_VISIBLE_DEVICES=0
46 | export CORPUS=msmarco-passage
47 | export SPLIT=dev.small
48 | export INDEX_DIR=${MODEL_DIR}/encoding
49 | export DATA_DIR=need_your_assignment
50 |
51 | # Corpus
52 | for i in $(seq -f "%02g" 0 10)
53 | do
54 | echo '============= Inference doc.split '${i} ' ============='
55 | srun --gres=gpu:p100:1 --mem=16G --cpus-per-task=2 --time=1:40:00 \
56 | python -m tevatron.driver.encode \
57 | --output_dir ${MODEL_DIR} \
58 | --model_name_or_path ${MODEL_DIR} \
59 | --add_pooler \
60 | --projection_out_dim ${CLSDIM} \
61 | --agg_dim ${AGGDIM} \
62 | --model ${MODEL} \
63 | --fp16 \
64 | --p_max_len 128 \
65 | --per_device_eval_batch_size 128 \
66 | --encode_in_path ${DATA_DIR}/corpus/split${i}.json \
67 | --encoded_save_path ${INDEX_DIR}/${CORPUS}.split${i}.pt &
68 | done
69 |
70 | # Merge index
71 | python -m retrieval.index \
72 | --index_path ${INDEX_DIR} \
73 | --index_prefix ${CORPUS}
74 | mkdir ${INDEX_DIR}/index
75 | mv ${INDEX_DIR}/${CORPUS}.index.pt ${INDEX_DIR}/index/
76 |
77 | # Queries
78 | for SPLIT in dev.small
79 | do
80 | mkdir ${INDEX_DIR}/queries
81 | python -m tevatron.driver.encode \
82 | --output_dir ${MODEL_DIR} \
83 | --model_name_or_path ${MODEL_DIR} \
84 | --fp16 \
85 | --q_max_len 32 \
86 | --model ${MODEL} \
87 | --encode_is_qry \
88 | --add_pooler \
89 | --projection_out_dim ${CLSDIM} \
90 | --agg_dim ${AGGDIM} \
91 | --per_device_eval_batch_size 128 \
92 | --encode_in_path ${DATA_DIR}/queries/queries.${SPLIT}.json \
93 | --encoded_save_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt
94 | done
95 | ```
96 |
97 | ## End-to-End Retrieval
98 | ```
99 | # IP retrieval
100 | for shrad in 0
101 | do
102 | echo 'run shrad'$shrad
103 | python -m retrieval.gip_retrieval \
104 | --query_emb_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt \
105 | --index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
106 | --topk 1000 \
107 | --total_shrad 1 \
108 | --shrad $shrad \
109 | --IP \
110 | 	--use_gpu
111 | done
112 | ```
113 |
114 | ## Evaluation
115 | The run file, result.trec, is in the trec format so that you can directly evaluate the result using pyserini.
116 | ```
117 | python -m pyserini.eval.trec_eval -c -M 10 -m recip_rank ${QREL_PATH} result.trec
118 | python -m pyserini.eval.trec_eval -c -m recall.1000 ${QREL_PATH} result.trec
119 | ```
120 |
121 |
122 |
--------------------------------------------------------------------------------
/docs/dhr/beir-eval.md:
--------------------------------------------------------------------------------
1 | # BEIR Evaluation
2 | We provide two scripts for BEIR evaluation and use the model, [DeLADE-CLS-P](https://huggingface.co/jacklin/DeLADE-CLS-P), and the dataset, trec-covid, as an example.
3 | 1. [Evaluation with GIP Retrieval](#evaluation_with_gip)
4 | 1. [Evaluation with Sentence Transformer](#evaluation_with_sentence_transformer)
5 |
6 | ## Evaluation with GIP Retrieval
7 | We first download our model and the BEIR dataset.
8 | ```
9 | git clone https://huggingface.co/jacklin/DeLADE-CLS-P
10 | export MODEL_DIR=DeLADE-CLS-P
11 | export CORPUS=trec-covid
12 | export SPLIT=test
13 | python -m tevatron.datasets.beir.preprocess --dataset ${CORPUS}
14 | ```
15 | Then we tokenize the queries and the corpus.
16 | ```
17 | python -m tevatron.utils.tokenize_corpus \
18 | --corpus_path ./dataset/${CORPUS}/corpus/collection.json \
19 | --output_dir ./dataset/${CORPUS}/tokenized_data/corpus \
20 | --corpus_domain beir \
21 | --tokenize --encode --num_workers 10
22 |
23 | python -m tevatron.utils.tokenize_query \
24 | --qry_file ./dataset/${CORPUS}/queries/queries.${SPLIT}.tsv \
25 | --output_dir ./dataset/${CORPUS}/tokenized_data/queries
26 |
27 | ```
28 | Following the [inference scripts](https://github.com/castorini/DHR/blob/main/docs/msmarco-passage-train-eval.md#inference-msmarco-passage-for-retrieval) for msmarco-passage data, we run inference, GIP retrieval and evaluation on the BEIR dataset.
29 | ```
30 | export CUDA_VISIBLE_DEVICES=0
31 | export MODEL=DHR #change to DLR if you use DLR model
32 | export CLSDIM=128
33 | export DLRDIM=768
34 | export CORPUS=trec-covid
35 | export SPLIT=test
36 | export INDEX_DIR=${MODEL_DIR}/encoding${DLRDIM}
37 | export DATA_DIR=./dataset/${CORPUS}/tokenized_data
38 |
39 | # Corpus
40 | for file in ${DATA_DIR}/corpus/split*.json
41 | do
42 | i=$(echo $file |rev | cut -c -7 |rev | cut -c -2 )
43 | echo "===========inference ${file}==========="
44 | python -m tevatron.driver.encode \
45 | --output_dir ${MODEL_DIR} \
46 | --model_name_or_path ${MODEL_DIR} \
47 | --projection_out_dim ${CLSDIM} \
48 | --dlr_out_dim ${DLRDIM} \
49 | --model ${MODEL} \
50 | --add_pooler \
51 | --combine_cls \
52 | --fp16 \
53 | --p_max_len 512 \
54 | --per_device_eval_batch_size 32 \
55 | --encode_in_path ${file} \
56 | --encoded_save_path ${INDEX_DIR}/${CORPUS}.split${i}.pt
57 | done
58 |
59 | # Merge index
60 | python -m retrieval.index \
61 | --index_path ${INDEX_DIR} \
62 | --index_prefix ${CORPUS}
63 | mkdir ${INDEX_DIR}/index
64 | mv ${INDEX_DIR}/${CORPUS}.index.pt ${INDEX_DIR}/index/
65 |
66 | # QUERY
67 | mkdir ${INDEX_DIR}/queries
68 | python -m tevatron.driver.encode \
69 | --output_dir ${MODEL_DIR} \
70 | --model_name_or_path ${MODEL_DIR} \
71 | --fp16 \
72 | --q_max_len 512 \
73 | --model ${MODEL} \
74 | --encode_is_qry \
75 | --combine_cls \
76 | --add_pooler \
77 | --projection_out_dim ${CLSDIM} \
78 | --dlr_out_dim ${DLRDIM} \
79 | --per_device_eval_batch_size 128 \
80 | --encode_in_path ${DATA_DIR}/queries/queries.${SPLIT}.json \
81 | --encoded_save_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt
82 |
83 | ```
84 | ```
85 | # GIP retrieval
86 | for shrad in 0
87 | do
88 | echo 'run shrad'$shrad
89 | python -m retrieval.gip_retrieval \
90 | --query_emb_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt \
91 | --emb_dim ${DLRDIM} \
92 | --index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
93 | --topk 1000 \
94 | --total_shrad 1 \
95 | --shrad $shrad \
96 | --theta 0.3 \
97 | --rerank \
98 | --use_gpu \
99 | --combine_cls
100 | done
101 | ```
102 | ```
103 | # Evaluation
104 | python -m pyserini.eval.trec_eval -c -mndcg_cut.10 -mrecall.100 ./dataset/${CORPUS}/qrels/qrels.${SPLIT}.tsv result.trec
105 | python -m retrieval.rcap_eval --qrel_file_path ./dataset/${CORPUS}/qrels/qrels.${SPLIT}.tsv --run_file_path result.trec
106 |
107 | ```
108 | ## Evaluation with Sentence Transformer
109 | The second option is to directly use the [BEIR](https://github.com/beir-cellar/beir) API to conduct brute-force search. No densification is performed before retrieval; thus, the results differ slightly from the numbers reported in our paper. Note that, for this script, we currently only support our DHR models, [DeLADE-CLS](https://huggingface.co/jacklin/DeLADE-CLS) and [DeLADE-CLS-P](https://huggingface.co/jacklin/DeLADE-CLS-P).
110 | ```
111 | git clone https://huggingface.co/jacklin/DeLADE-CLS-P
112 | export MODEL_DIR=DeLADE-CLS-P
113 | python -m tevatron.datasets.beir.encode_and_retrieval --dataset trec-covid --model_name_or_path ${MODEL_DIR}
114 | ```
115 |
116 |
117 |
118 |
--------------------------------------------------------------------------------
/docs/dhr/densify_exp.md:
--------------------------------------------------------------------------------
1 | # Densify Sparse Vector
2 | This document demonstrates how to densify existing sparse lexical retrievers for dense search. We use [pyserini](https://github.com/castorini/pyserini) to get the sparse vectors from the models, and we show how to densify BM25 on the MS MARCO passage ranking dataset.
3 | 1. [Densifying BM25](#densifying_bm25)
4 | 1. [Densifying uniCOIL](#densifying_uniCOIL)
5 |
6 | # Densifying BM25
7 | ## Data Prepare
8 | Following the [instruction](https://github.com/castorini/anserini/blob/master/docs/experiments-msmarco-passage.md), we first download the MSMARCO passage collection and query files. Then, convert collection.tsv into json files in $COLLECTION_PATH for the pyserini index, and put the queries.dev.small.tsv file into $Q_DIR.
9 | ```shell=bash
10 | export COLLECTION_PATH=need_your_assignment
11 | export INDEX_PATH=need_your_assignment
12 | export VECTOR_DIR=need_your_assignment
13 | export Q_DIR=need_your_assignment
14 | export MODEL=BM25
15 | export DLRDIM=768
16 | export CORPUS=msmarco-passage
17 | export DLR_PATH=${MODEL}_DIM${DLRDIM}
18 | export SPLIT=dev.small
19 | ```
20 | ## Output BM25 Vector from index
21 | We first index the json corpus using BM25.
22 | ```shell=bash
23 | python -m pyserini.index.lucene \
24 |   --collection JsonCollection \
25 | --input ${COLLECTION_PATH} \
26 | --index ${INDEX_PATH} \
27 | --generator DefaultLuceneDocumentGenerator \
28 | --threads 12 \
29 | --storeDocvectors --storeRaw --optimize
30 | ```
31 | Then, we output the sparse vectors into a json file and split it into multiple files so that the next step can run with multiple processes; the per-line format is shown after the commands below.
32 | ```shell=bash
33 | python -m densify.output_vector \
34 | --index_path ${INDEX_PATH} \
35 | --output_path ${VECTOR_DIR}/split.json
36 |
37 | split -a 2 -dl 1000000 --additional-suffix=.json ${VECTOR_DIR}/split.json ${VECTOR_DIR}/split
38 | rm ${VECTOR_DIR}/split.json
39 | ```
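For reference, each line written by `densify.output_vector` is a JSON object holding a document id and its term-to-weight map; the weights below are illustrative only:

```
{"id": "7187158", "vector": {"manhattan": 2.61, "project": 1.92, "atomic": 1.75}}
```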
40 | ## Sparse vector densification
41 | We now start to densify corpus and queries.
42 | ```shell=bash
43 | python -m densify.densify_corpus \
44 | --model ${MODEL} \
45 | --prefix ${CORPUS} \
46 | --tokenizer ${INDEX_PATH} \
47 | --vector_dir ${VECTOR_DIR} \
48 | --output_dir ${DLR_PATH} \
49 | --output_dims ${DLRDIM}
50 |
51 | python -m densify.densify_query \
52 | --model bm25 \
53 | --prefix ${CORPUS} \
54 | --tokenizer ${INDEX_PATH} \
55 |    --query_path ${Q_DIR}/queries.${SPLIT}.tsv \
56 |    --output_dir ${DLR_PATH} \
57 |    --output_dims ${DLRDIM}
58 | ```
59 | ## BM25 search on GPU
60 | We then merge index and start DLR search.
61 | ```shell=bash
62 | # Merge index
63 | python -m retrieval.index \
64 | --index_path ${DLR_PATH}/encoding \
65 |    --index_prefix ${CORPUS}
66 |
67 | mkdir ${DLR_PATH}/encoding/index
68 | mv ${DLR_PATH}/encoding/${CORPUS}.index.pt ${DLR_PATH}/encoding/index/
69 |
70 | # Search
71 | python -m retrieval.gip_retrieval \
72 | --query_emb_path ${DLR_PATH}/encoding/queries/queries.${CORPUS}.${SPLIT}.pt \
73 | --emb_dim ${DLRDIM} \
74 | --index_path ${DLR_PATH}/encoding/index/${CORPUS}.index.pt \
75 | --theta 1 \
76 | --rerank \
77 |    --use_gpu
78 | ```
79 |
80 | # Densifying uniCOIL
81 | ## Data Prepare
82 | Following the [instruction](https://github.com/castorini/pyserini/blob/master/docs/experiments-unicoil.md), we download the pre-encoded uniCOIL passage collection.
83 | ```shell=bash
84 | wget https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/data/msmarco-passage-unicoil.tar -P collections/
85 |
86 | tar xvf collections/msmarco-passage-unicoil.tar -C collections/
87 | ```
88 | ```shell=bash
89 | export MODEL=uniCOIL
90 | export DLRDIM=768
91 | export CORPUS=msmarco-passage
92 | export VECTOR_DIR=./collections/msmarco-passage-unicoil
93 | export DLR_PATH=${MODEL}_DIM${DLRDIM}
94 | export SPLIT=dev.small
95 | ```
96 | ## Sparse vector densification
97 | We now start to densify corpus and queries.
98 | ```shell=bash
99 | python -m densify.densify_corpus \
100 | --model ${MODEL} \
101 | --prefix ${CORPUS} \
102 | --vector_dir ${VECTOR_DIR} \
103 | --output_dir ${DLR_PATH} \
104 | --output_dims ${DLRDIM}
105 |
106 | python -m densify.densify_query \
107 | --model ${MODEL} \
108 | --prefix ${CORPUS} \
109 |    --query_path ${Q_DIR}/queries.${SPLIT}.tsv \
110 |    --output_dir ${DLR_PATH} \
111 |    --output_dims ${DLRDIM}
112 | ```
113 |
114 | We then merge index and start DLR search.
115 | ```shell=bash
116 | # Merge index
117 | python -m retrieval.index \
118 | --index_path ${DLR_PATH}/encoding \
119 |    --index_prefix ${CORPUS}
120 |
121 | mkdir ${DLR_PATH}/encoding/index
122 | mv ${DLR_PATH}/encoding/${CORPUS}.index.pt ${DLR_PATH}/encoding/index/
123 |
124 | # Search
125 | python -m retrieval.gip_retrieval \
126 | --query_emb_path ${DLR_PATH}/encoding/queries/queries.${CORPUS}.${SPLIT}.pt \
127 | --emb_dim ${DLRDIM} \
128 | --index_path ${DLR_PATH}/encoding/index/${CORPUS}.index.pt \
129 | --theta 1 \
130 | --rerank \
131 |    --use_gpu
132 | ```
--------------------------------------------------------------------------------
/docs/dhr/msmarco-passage-train-eval.md:
--------------------------------------------------------------------------------
1 | # Training and Inference on MSMARCO Passage ranking
2 | In the following, we describe how to train, encode and retrieve with DHR on MS MARCO passage-v1.
3 | 1. [MS MARCO Passage-v1 Data Preparation](#msmarco_data_prep)
4 | 1. [Training](#training)
5 | 1. [Generate Passage and Query Embeddings](#generate_embeddings)
6 | 1. [End-To-End Retrieval](#retrieval)
7 | 1. [Retrieval on GPU](#retrieval_on_gpu)
8 | 1. [Retrieval on CPU](#retrieval_on_cpu)
9 | 1. [Evaluation](#evaluation)
10 |
11 |
12 | ## MS MARCO Passage-v1 Data Preparation
13 | We first preprocess the corpus, development queries and official training data into the json format. Each passage in the corpus is a line with the format: `{"text_id": passage_id, "text": [vocab_ids]}`. Similarly, each query in the development set is a line with the format: `{"text_id": query_id, "text": [vocab_ids]}`. As for training data, we rearrange the official training data in the format: `{"query": [vocab_ids], "positive_pids": [positive_passage_id0, positive_passage_id1, ...], "negative_pids": [negative_passage_id0, negative_passage_id1, ...]}`. Note that passage and query ids are strings. You can also download our preprocessed data from the Huggingface hub: [official_train](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_corpus), [queries](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_queries) and [corpus](https://huggingface.co/datasets/jacklin/msmarco_passage_ranking_corpus).
14 |
15 | ## Training
16 | The script below is the DHR (DLR) training setup from our paper. You can simply switch ${MODEL} from DHR to DLR, and the option `--combine_cls` will be turned off automatically.
17 | ```shell=bash
18 | export CUDA_VISIBLE_DEVICES=0
19 | export MODEL=DHR
20 | export CLSDIM=128
21 | export DLRDIM=768
22 | export MODEL_DIR=${MODEL}_CLS${CLSDIM}
23 | export DATA_DIR=need_your_assignment
24 |
25 | python -m tevatron.driver.train \
26 | --output_dir ${MODEL_DIR} \
27 | --train_dir ${DATA_DIR}/official_train \
28 | --corpus_dir ${DATA_DIR}/corpus \
29 | --model_name_or_path distilbert-base-uncased \
30 | --do_train \
31 | --save_steps 20000 \
32 | --fp16 \
33 | --per_device_train_batch_size 24 \
34 | --learning_rate 7e-6 \
35 | --q_max_len 32 \
36 | --p_max_len 150 \
37 | --num_train_epochs 6 \
38 | --add_pooler \
39 | --model ${MODEL} \
40 | --projection_out_dim ${CLSDIM} \
41 | --train_n_passages 8 \
42 | --dataloader_num_workers 8 \
43 | 	--combine_cls
44 | ```
45 |
46 | ## Generate Passage and Query Embeddings
47 | ```
48 | export CUDA_VISIBLE_DEVICES=0
49 | export MODEL=DHR # use DHR for DeLADE+[CLS] and DLR for DeLADE
50 | export CLSDIM=128
51 | export DLRDIM=768
52 | export MODEL_DIR=${MODEL}_CLS${CLSDIM}
53 | export CORPUS=msmarco-passage
54 | export SPLIT=dev.small
55 | export INDEX_DIR=${MODEL_DIR}/encoding${DLRDIM}
56 | export DATA_DIR=need_your_assignment
57 |
58 | # Corpus
59 | for i in $(seq -f "%02g" 0 10)
60 | do
61 | echo '============= Inference doc.split '${i} ' ============='
62 | srun --gres=gpu:p100:1 --mem=16G --cpus-per-task=2 --time=1:40:00 \
63 | python -m tevatron.driver.encode \
64 | --output_dir ${MODEL_DIR} \
65 | --model_name_or_path ${MODEL_DIR} \
66 | --add_pooler \
67 | --projection_out_dim ${CLSDIM} \
68 | --dlr_out_dim ${DLRDIM} \
69 | --combine_cls \
70 | --model ${MODEL} \
71 | --fp16 \
72 | --p_max_len 150 \
73 | --per_device_eval_batch_size 128 \
74 | --encode_in_path ${DATA_DIR}/corpus/split${i}.json \
75 | --encoded_save_path ${INDEX_DIR}/${CORPUS}.split${i}.pt &
76 | done
77 |
78 | # Merge index
79 | python -m retrieval.index \
80 | --index_path ${INDEX_DIR} \
81 | --index_prefix ${CORPUS}
82 | mkdir ${INDEX_DIR}/index
83 | mv ${INDEX_DIR}/${CORPUS}.index.pt ${INDEX_DIR}/index/
84 |
85 | # Queries
86 | for SPLIT in dev.small
87 | do
88 | mkdir ${INDEX_DIR}/queries
89 | python -m tevatron.driver.encode \
90 | --output_dir ${MODEL_DIR} \
91 | --model_name_or_path ${MODEL_DIR} \
92 | --fp16 \
93 | --q_max_len 32 \
94 | --model ${MODEL} \
95 | --encode_is_qry \
96 | --combine_cls \
97 | --add_pooler \
98 | --projection_out_dim ${CLSDIM} \
99 | --dlr_out_dim ${DLRDIM} \
100 | --per_device_eval_batch_size 128 \
101 | --encode_in_path ${DATA_DIR}/queries/queries.${SPLIT}.json \
102 | --encoded_save_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt
103 | done
104 | ```
105 |
106 | ## End-to-end Retrieval
107 | ### Retrieval on GPU
108 | If you want to use a GPU for retrieval, we suggest using our two-stage retrieval implementation; a sketch of what it computes follows the commands below.
109 | ```
110 | # GIP retrieval
111 | for shrad in 0
112 | do
113 | echo 'run shrad'$shrad
114 | python -m retrieval.gip_retrieval \
115 | --query_emb_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt \
116 | --emb_dim ${DLRDIM} \
117 | --index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
118 | --topk 1000 \
119 | --total_shrad 1 \
120 | --shrad $shrad \
121 | --theta 0.3 \
122 | --rerank \
123 | --use_gpu \
124 | 	--combine_cls
125 | done
126 | ```
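The two-stage search above first approximates GIP using only the query dimensions whose weight exceeds `--theta`, then (with `--rerank`) re-scores the top `--agip_topk` candidates with the exact gated inner product. Below is a compact single-query sketch of what `GIP_retrieval` in `retrieval/gip_retrieval.py` does; the `min()` guards are ours:

```python
import torch

def two_stage_gip(q_val, q_idx, D_val, D_idx, theta=0.3, agip_topk=10000, topk=1000):
    # stage 1: approximate GIP over the query dims with weight > theta
    keep = torch.topk(q_val, int((q_val > theta).sum())).indices
    approx = ((D_idx[:, keep] == q_idx[keep]) * D_val[:, keep]) @ q_val[keep]
    cand = torch.topk(approx, min(agip_topk, approx.numel())).indices
    # stage 2: exact gated inner product over the shortlisted candidates
    exact = ((D_idx[cand] == q_idx) * D_val[cand]) @ q_val
    best = torch.topk(exact, min(topk, exact.numel())).indices
    return cand[best], exact[best]
```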
127 |
128 | ### Retrieval on CPU
129 | If you only have CPUs, we suggest first quantizing the index and then using our two-stage retrieval implementation.
130 | ```
131 | # index quantization
132 | python -m retrieval.quantize_index \
133 | 	--index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
134 | 	--output_index_path ${INDEX_DIR}/index/${CORPUS}.pq64.faiss.index \
135 | --qauntized_dim 64
136 |
137 | # GIP retrieval
138 | python -m retrieval.gip_retrieval \
139 | 	--query_emb_path ${INDEX_DIR}/queries/queries.${CORPUS}.${SPLIT}.pt \
140 | 	--index_path ${INDEX_DIR}/index/${CORPUS}.index.pt \
141 | 	--faiss_pq_index_path ${INDEX_DIR}/index/${CORPUS}.pq64.faiss.index \
142 | --emb_dim ${DLRDIM} \
143 | --topk 1000 \
144 | --lamda 1 \
145 | --batch 1 \
146 | --PQIP \
147 | --rerank
148 | ```
149 |
150 | ## Evaluation
151 | The run file, result.trec, is in the trec format so that you can directly evaluate the result using pyserini.
152 | ```
153 | python -m pyserini.eval.trec_eval -c -M 10 -m recip_rank ${QREL_PATH} result.trec
154 | python -m pyserini.eval.trec_eval -c -m recall.1000 ${QREL_PATH} result.trec
155 | ```
156 |
157 |
158 |
--------------------------------------------------------------------------------
/fig/aggretriever.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/fig/aggretriever.png
--------------------------------------------------------------------------------
/fig/aggretriever_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/fig/aggretriever_teaser.png
--------------------------------------------------------------------------------
/fig/densification.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/fig/densification.png
--------------------------------------------------------------------------------
/fig/single_model_fusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/fig/single_model_fusion.png
--------------------------------------------------------------------------------
/retrieval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/retrieval/__init__.py
--------------------------------------------------------------------------------
/retrieval/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/retrieval/evaluation/__init__.py
--------------------------------------------------------------------------------
/retrieval/evaluation/custom_metrics.py:
--------------------------------------------------------------------------------
1 | ## copy from https://github.com/beir-cellar/beir/blob/main/beir/retrieval/custom_metrics.py
2 | import logging
3 | from typing import List, Dict, Union, Tuple
4 |
5 | def mrr(qrels: Dict[str, Dict[str, int]],
6 | results: Dict[str, Dict[str, float]],
7 | k_values: List[int]) -> Tuple[Dict[str, float]]:
8 |
9 | MRR = {}
10 |
11 | for k in k_values:
12 | MRR[f"MRR@{k}"] = 0.0
13 |
14 | k_max, top_hits = max(k_values), {}
15 | logging.info("\n")
16 |
17 | for query_id, doc_scores in results.items():
18 | top_hits[query_id] = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
19 |
20 | for query_id in top_hits:
21 | query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
22 | for k in k_values:
23 | for rank, hit in enumerate(top_hits[query_id][0:k]):
24 | if hit[0] in query_relevant_docs:
25 | MRR[f"MRR@{k}"] += 1.0 / (rank + 1)
26 | break
27 |
28 | for k in k_values:
29 | MRR[f"MRR@{k}"] = round(MRR[f"MRR@{k}"]/len(qrels), 5)
30 | logging.info("MRR@{}: {:.4f}".format(k, MRR[f"MRR@{k}"]))
31 |
32 | return MRR
33 |
34 | def recall_cap(qrels: Dict[str, Dict[str, int]],
35 | results: Dict[str, Dict[str, float]],
36 | k_values: List[int]) -> Tuple[Dict[str, float]]:
37 |
38 | capped_recall = {}
39 |
40 | for k in k_values:
41 | capped_recall[f"R_cap@{k}"] = 0.0
42 |
43 | k_max = max(k_values)
44 | logging.info("\n")
45 |
46 | for query_id, doc_scores in results.items():
47 | top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
48 | query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
49 | for k in k_values:
50 | retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0]
51 | denominator = min(len(query_relevant_docs), k)
52 | capped_recall[f"R_cap@{k}"] += (len(retrieved_docs) / denominator)
53 |
54 | for k in k_values:
55 | capped_recall[f"R_cap@{k}"] = round(capped_recall[f"R_cap@{k}"]/len(qrels), 5)
56 | logging.info("R_cap@{}: {:.4f}".format(k, capped_recall[f"R_cap@{k}"]))
57 |
58 | return capped_recall
59 |
60 |
61 | def hole(qrels: Dict[str, Dict[str, int]],
62 | results: Dict[str, Dict[str, float]],
63 | k_values: List[int]) -> Tuple[Dict[str, float]]:
64 |
65 | Hole = {}
66 |
67 | for k in k_values:
68 | Hole[f"Hole@{k}"] = 0.0
69 |
70 | annotated_corpus = set()
71 | for _, docs in qrels.items():
72 | for doc_id, score in docs.items():
73 | annotated_corpus.add(doc_id)
74 |
75 | k_max = max(k_values)
76 | logging.info("\n")
77 |
78 | for _, scores in results.items():
79 | top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
80 | for k in k_values:
81 | hole_docs = [row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus]
82 | Hole[f"Hole@{k}"] += len(hole_docs) / k
83 |
84 | for k in k_values:
85 | Hole[f"Hole@{k}"] = round(Hole[f"Hole@{k}"]/len(qrels), 5)
86 | logging.info("Hole@{}: {:.4f}".format(k, Hole[f"Hole@{k}"]))
87 |
88 | return Hole
89 |
90 | def top_k_accuracy(
91 | qrels: Dict[str, Dict[str, int]],
92 | results: Dict[str, Dict[str, float]],
93 | k_values: List[int]) -> Tuple[Dict[str, float]]:
94 |
95 | top_k_acc = {}
96 |
97 | for k in k_values:
98 | top_k_acc[f"Accuracy@{k}"] = 0.0
99 |
100 | k_max, top_hits = max(k_values), {}
101 | logging.info("\n")
102 |
103 | for query_id, doc_scores in results.items():
104 | top_hits[query_id] = [item[0] for item in sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]]
105 |
106 | for query_id in top_hits:
107 | query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
108 | for k in k_values:
109 | for relevant_doc_id in query_relevant_docs:
110 | if relevant_doc_id in top_hits[query_id][0:k]:
111 | top_k_acc[f"Accuracy@{k}"] += 1.0
112 | break
113 |
114 | for k in k_values:
115 | top_k_acc[f"Accuracy@{k}"] = round(top_k_acc[f"Accuracy@{k}"]/len(qrels), 5)
116 | logging.info("Accuracy@{}: {:.4f}".format(k, top_k_acc[f"Accuracy@{k}"]))
117 |
118 | return top_k_acc
--------------------------------------------------------------------------------
/retrieval/gip_retrieval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import glob
4 | import numpy as np
5 | import math
6 | from tqdm import tqdm
7 | from multiprocessing import Pool, Manager
8 | import pickle5 as pickle
9 | import torch
10 | import torch.nn as nn
11 | import time
12 | import faiss
13 | from progressbar import ProgressBar, Percentage, Bar, Timer, ETA, FileTransferSpeed  # progress display used in faiss_search
14 | def faiss_search(query_embs, corpus_embs, batch=1, topk=1000):
15 | print('start faiss index')
16 | query_embs = np.concatenate([query_embs,query_embs], axis=1)
17 | corpus_embs = np.concatenate([corpus_embs,corpus_embs], axis=1)
18 |
19 | dimension = query_embs.shape[1]
20 | res = faiss.StandardGpuResources()
21 | res.noTempMemory()
22 | # res.setTempMemory(1000 * 1024 * 1024) # 1G GPU memory for serving query
23 | flat_config = faiss.GpuIndexFlatConfig()
24 | flat_config.device = 0
25 | flat_config.useFloat16=True
26 | index = faiss.GpuIndexFlatIP(res, dimension, flat_config)
27 |
28 | print("Load index to GPU...")
29 | index.add(corpus_embs)
30 |
31 | Distance = []
32 | Index = []
33 | print("Search with batch size %d"%(batch))
34 | widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(),
35 | ' ', ETA(), ' ', FileTransferSpeed()]
36 | pbar = ProgressBar(widgets=widgets, maxval=query_embs.shape[0]//batch).start()
37 | start_time = time.time()
38 |
39 | for i in range(query_embs.shape[0]//batch):
40 | D,I=index.search(query_embs[i*batch:(i+1)*batch], topk)
41 |
42 |
43 | Distance.append(D)
44 | Index.append(I)
45 | pbar.update(i + 1)
46 |
47 |
48 | D,I=index.search(query_embs[(i+1)*batch:], topk)
49 |
50 |
51 | Distance.append(D)
52 | Index.append(I)
53 |
54 | time_per_query = (time.time() - start_time)/query_embs.shape[0]
55 | print('Retrieving {} queries ({:0.3f} s/query)'.format(query_embs.shape[0], time_per_query))
56 | Distance = np.concatenate(Distance, axis=0)
57 | Index = np.concatenate(Index, axis=0)
58 | return Distance, Index
59 |
60 | def IP_retrieval(qids, query_embs, corpus_embs, args):
61 |
62 | description = 'Brute force IP search'
63 |
64 |
65 | all_results = {}
66 | all_scores = {}
67 |
68 | start_time = time.time()
69 | total_num_idx = 0
70 | for i, (query_emb) in tqdm(enumerate(query_embs), total=len(query_embs), desc=description):
71 |
72 |
73 |
74 | scores = torch.einsum('ij,j->i',(corpus_embs, query_emb))
75 | sort_candidates = torch.argsort(scores, descending=True)[:args.topk]
76 | sort_scores = scores[sort_candidates]
77 |
78 | all_scores[qids[i]]=sort_scores.cpu().tolist()
79 | all_results[qids[i]]=sort_candidates.cpu().tolist()
80 |
81 | average_num_idx = total_num_idx/query_embs.shape[0]
82 | time_per_query = (time.time() - start_time)/query_embs.shape[0]
83 | print('Retrieving {} queries ({:0.3f} s/query), average number of index use {}'.format(query_embs.shape[0], time_per_query, average_num_idx))
84 |
85 | return all_results, all_scores
86 |
87 |
88 | def GIP_retrieval(qids, query_embs, query_arg_idxs, corpus_embs, corpus_arg_idxs, args):
89 | if args.brute_force:
90 | args.theta = 0
91 | description = 'Brute force GIP search'
92 | else:
93 | if not args.IP:
94 | if args.rerank:
95 | description = 'GIP (\u03F4={}) retrieval w/ GIP rerank'.format(args.theta)
96 | else:
97 | description = 'GIP (\u03F4={}) retrieval w/o GIP rerank'.format(args.theta)
98 | else:
99 | if args.rerank:
100 | description = 'IP retrieval w/ GIP rerank'
101 | else:
102 | description = 'IP retrieval w/o GIP rerank'
103 |
104 | all_results = {}
105 | all_scores = {}
106 |
107 | start_time = time.time()
108 | total_num_idx = 0
109 |
110 | cls_dim = query_embs.shape[1] - args.emb_dim
111 | if cls_dim > 0:
112 | query_arg_idxs = torch.nn.functional.pad(query_arg_idxs, (0, cls_dim), mode='constant', value=1)
113 | corpus_arg_idxs = torch.nn.functional.pad(corpus_arg_idxs, (0, cls_dim), mode='constant', value=1)
114 |
115 | for i, (query_emb, query_arg_idx) in tqdm(enumerate(zip(query_embs, query_arg_idxs)), total=len(query_embs), desc=description):
116 |
117 | if args.theta==0:
118 | total_num_idx += args.emb_dim
119 | candidate_sparse_embs = ((corpus_arg_idxs==query_arg_idx)*corpus_embs)
120 | scores = torch.einsum('ij,j->i',(candidate_sparse_embs, query_emb))
121 | del candidate_sparse_embs
122 |
123 | sort_idx = torch.topk(scores, args.topk, dim=0).indices
124 | sort_candidates = sort_idx
125 | sort_scores = scores[sort_idx]
126 | torch.cuda.empty_cache()
127 |
128 | else:
129 |
130 | num_idx = int((query_emb > args.theta).sum())
131 | important_idx = torch.topk(query_emb, num_idx, dim=0).indices.tolist()
132 |
133 | if not args.IP:
134 | # Approximate GIP
135 | candidate_sparse_embs = ( (corpus_arg_idxs[:,important_idx]==query_arg_idx[important_idx]) * corpus_embs[:,important_idx] )
136 | partial_scores = torch.einsum('ij,j->i',(candidate_sparse_embs, query_emb[important_idx]))
137 | else:
138 |                 # IP as an approximation
139 | partial_scores = torch.einsum('ij,j->i',(corpus_embs, query_emb))
140 |
141 | if args.rerank:
142 | candidates = torch.topk(partial_scores, args.agip_topk, dim=0).indices
143 |
144 | candidate_sparse_embs = ((corpus_arg_idxs[candidates,:]==query_arg_idx)*corpus_embs[candidates])
145 |
146 | scores = torch.einsum('ij,j->i',(candidate_sparse_embs, query_emb))
147 |
148 | sort_idx = torch.topk(scores, args.topk, dim=0).indices
149 | sort_candidates = candidates[sort_idx]
150 | sort_scores = scores[sort_idx]
151 |
152 | del important_idx, candidates, candidate_sparse_embs, scores, sort_idx
153 | torch.cuda.empty_cache()
154 | else:
155 | sort_candidates = torch.topk(partial_scores, args.topk, dim=0).indices
156 | sort_scores = partial_scores[sort_candidates]
157 |
158 | all_scores[qids[i]]=sort_scores.cpu().tolist()
159 | all_results[qids[i]]=sort_candidates.cpu().tolist()
160 |
161 | average_num_idx = total_num_idx/query_embs.shape[0]
162 | time_per_query = (time.time() - start_time)/query_embs.shape[0]
163 | print('Retrieving {} queries ({:0.3f} s/query), average number of index use {}'.format(query_embs.shape[0], time_per_query, average_num_idx))
164 |
165 | return all_results, all_scores
166 |
167 | def PQ_IP_retrieval(qids, query_embs, query_arg_idxs, corpus_embs, corpus_arg_idxs, args):
168 |     assert args.faiss_pq_index_path is not None, 'you did not specify your PQ index through --faiss_pq_index_path'
169 | print('Load PQ index ...')
170 | faiss_index = faiss.read_index(args.faiss_pq_index_path)
171 |
172 | if args.rerank:
173 | description = 'IP (Product Quantization) search w/ GIP rerank'
174 | else:
175 | description = 'IP (Product Quantization) search w/o GIP rerank'
176 |
177 | all_results = {}
178 | all_scores = {}
179 |
180 | cls_dim = query_embs.shape[1] - query_arg_idxs.shape[1]
181 | if cls_dim > 0:
182 | query_arg_idxs = torch.nn.functional.pad(query_arg_idxs, (0, cls_dim), mode='constant', value=1)
183 | corpus_arg_idxs = torch.nn.functional.pad(corpus_arg_idxs, (0, cls_dim), mode='constant', value=1)
184 |
185 | if len(query_embs)%args.batch == 0:
186 | total_batch = len(query_embs)//args.batch
187 | else:
188 | total_batch = len(query_embs)//args.batch + 1
189 |
190 | start_time = time.time()
191 | for i in tqdm(range(total_batch), total=total_batch, desc=description):
192 |
193 | if i == (total_batch -1):
194 | batch_query_embs = query_embs[i*args.batch:]
195 | batch_query_arg_idxs = query_arg_idxs[i*args.batch:]
196 | batch_qids = qids[i*args.batch:]
197 | else:
198 | batch_query_embs = query_embs[i*args.batch:(i+1)*args.batch]
199 | batch_query_arg_idxs = query_arg_idxs[i*args.batch:(i+1)*args.batch]
200 | batch_qids = qids[i*args.batch:(i+1)*args.batch]
201 |
202 | scores, candidates = faiss_index.search(batch_query_embs.numpy(), args.agip_topk)
203 |
204 |
205 |         for j, (qid, query_emb, query_arg_idx, candidate) in enumerate(zip(batch_qids, batch_query_embs, batch_query_arg_idxs, candidates)):
206 |             if args.rerank:
207 |                 candidate_sparse_embs = ((corpus_arg_idxs[candidate,:]==query_arg_idx)*corpus_embs[candidate])
208 |                 gip_scores = torch.einsum('ij,j->i',(candidate_sparse_embs, query_emb))
209 | 
210 |                 sort_idx = torch.topk(gip_scores, args.topk, dim=0).indices
211 |                 sort_candidates = candidate[sort_idx]
212 |                 sort_scores = gip_scores[sort_idx]
213 | 
214 |                 all_scores[qid] = sort_scores.tolist()
215 |                 all_results[qid] = sort_candidates.tolist()
216 | 
217 |                 # del candidates, candidate_sparse_embs, gip_scores, sort_idx
218 |             else:
219 |                 # no rerank: take the top-k PQ scores and candidates directly
220 |                 all_scores[qid] = scores[j, :args.topk].tolist()
221 |                 all_results[qid] = candidates[j, :args.topk].tolist()
222 |
223 |
224 | torch.cuda.empty_cache()
225 |
226 |
227 |
228 | time_per_query = (time.time() - start_time)/query_embs.shape[0]
229 | print('Retrieving {} queries ({:0.3f} s/query)'.format(query_embs.shape[0], time_per_query))
230 |
231 | return all_results, all_scores
232 |
233 | def main():
234 | parser = argparse.ArgumentParser()
235 | parser.add_argument("--query_emb_path", type=str, required=True)
236 | parser.add_argument("--index_path", type=str, required=True)
237 | parser.add_argument("--faiss_pq_index_path", type=str, default=None)
238 | parser.add_argument("--emb_dim", type=int, default=768, help='DLR dimension')
239 | parser.add_argument("--theta", type=float, default=0.1)
240 | parser.add_argument("--topk", type=int, default=1000)
241 | parser.add_argument("--agip_topk", type=int, default=10000)
242 | parser.add_argument("--combine_cls", action='store_true')
243 | parser.add_argument("--IP", action='store_true')
244 | parser.add_argument("--PQIP", action='store_true')
245 | parser.add_argument("--batch", type=int, default=1)
246 | parser.add_argument("--brute_force", action='store_true')
247 | parser.add_argument("--use_gpu", action='store_true')
248 | parser.add_argument("--rerank", action='store_true')
249 |     parser.add_argument("--lamda", type=float, default=1, help='weight for [CLS] for concatenation')
250 | parser.add_argument("--total_shrad", type=int, default=1)
251 | parser.add_argument("--shrad", type=int, default=0)
252 | parser.add_argument("--run_name", type=str, default='h2oloo')
253 | args = parser.parse_args()
254 |
255 | if not args.use_gpu:
256 | if args.batch > 1:
257 | torch.set_num_threads(72)
258 | else:
259 | torch.set_num_threads(1)
260 | else:
261 | torch.cuda.set_device(0)
262 |
263 | # load query embeddings
264 | print('Load query embeddings ...')
265 | with open(args.query_emb_path, 'rb') as f:
266 | query_embs, query_arg_idxs, qids=pickle.load(f)
267 |
268 | if args.use_gpu:
269 | query_embs = torch.from_numpy(query_embs).cuda(0)
270 | try:
271 | query_arg_idxs = torch.from_numpy(query_arg_idxs).cuda(0)
272 | except:
273 | query_arg_idxs = None
274 | else:
275 | query_embs = torch.from_numpy(query_embs.astype(np.float32))
276 | try:
277 | query_arg_idxs = torch.from_numpy(query_arg_idxs)
278 | except:
279 | query_arg_idxs = None
280 |
281 | cls_dim = query_embs.shape[1] - args.emb_dim
282 | if cls_dim > 0:
283 | query_embs[:,-cls_dim:] = args.lamda * query_embs[:,-cls_dim:]
284 |
285 |
286 |
287 | # load index
288 | print('Load index ...')
289 | with open(args.index_path, 'rb') as f:
290 | corpus_embs, corpus_arg_idxs, docids=pickle.load(f)
291 |
292 | doc_num_per_shrad = len(docids)//args.total_shrad
293 | if args.shrad==(args.total_shrad-1):
294 | corpus_embs = corpus_embs[doc_num_per_shrad*args.shrad:]
295 | try:
296 | corpus_arg_idxs = corpus_arg_idxs[doc_num_per_shrad*args.shrad:]
297 | except:
298 | corpus_arg_idxs = None
299 | docids = docids[doc_num_per_shrad*args.shrad:]
300 | else:
301 | corpus_embs = corpus_embs[doc_num_per_shrad*args.shrad:doc_num_per_shrad*(args.shrad+1)]
302 | try:
303 | corpus_arg_idxs = corpus_arg_idxs[doc_num_per_shrad*args.shrad:doc_num_per_shrad*(args.shrad+1)]
304 | except:
305 | corpus_arg_idxs = None
306 | docids = docids[doc_num_per_shrad*args.shrad:doc_num_per_shrad*(args.shrad+1)]
307 |
308 | if args.use_gpu:
309 | corpus_embs = torch.from_numpy(corpus_embs).cuda(0)
310 | if corpus_arg_idxs is not None:
311 | corpus_arg_idxs = torch.from_numpy(corpus_arg_idxs).cuda(0)
312 | else:
313 | corpus_embs = torch.from_numpy(corpus_embs.astype(np.float32))
314 | if corpus_arg_idxs is not None:
315 | corpus_arg_idxs = torch.from_numpy(corpus_arg_idxs)
316 | # density = corpus_embs!=0
317 | # density = density.sum(axis=1)
318 | # print(torch.sum(density)/8841823/args.emb_dim)
319 |
320 |
321 | if query_arg_idxs is not None:
322 | if not args.PQIP:
323 | results, scores = GIP_retrieval(qids, query_embs, query_arg_idxs, corpus_embs, corpus_arg_idxs ,args)
324 | else:
325 | results, scores = PQ_IP_retrieval(qids, query_embs, query_arg_idxs, corpus_embs, corpus_arg_idxs ,args)
326 | else:
327 | results, scores = IP_retrieval(qids, query_embs, corpus_embs, args)
328 |
329 | if args.total_shrad==1:
330 | fout = open('result.trec', 'w')
331 | else:
332 | fout = open('result{}.trec'.format(args.shrad), 'w')
333 | for i, query_id in tqdm(enumerate(results), total=len(results), desc=f"write results"):
334 | result = results[query_id]
335 | score = scores[query_id]
336 |
337 | for rank, docidx in enumerate(result):
338 |
339 | docid = docids[docidx]
340 | if (docid!=query_id):
341 | fout.write('{} Q0 {} {} {} {}\n'.format(query_id, docid, rank+1, score[rank], args.run_name))
342 | fout.close()
343 |
344 | print('finish')
345 |
346 |
347 | if __name__ == "__main__":
348 | main()
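349 | 
350 | # Usage sketch (illustrative; paths and file names are hypothetical): run GIP
351 | # retrieval over two index shards, then merge the per-shard TREC runs.
352 | #
353 | #   python retrieval/gip_retrieval.py --query_emb_path queries.emb.pt \
354 | #       --index_path msmarco-passage.index.pt --emb_dim 768 --rerank \
355 | #       --total_shrad 2 --shrad 0
356 | #   python retrieval/gip_retrieval.py ... --total_shrad 2 --shrad 1
357 | #   python retrieval/merge.result.py --total_shrad 2   # merges result0.trec, result1.trec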
--------------------------------------------------------------------------------
/retrieval/index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import glob
4 | # os.environ['OMP_NUM_THREADS'] = str(32)
5 | import numpy as np
6 | import math
7 | from progressbar import *
8 | # from util import load_tfrecords_and_index, read_id_dict, faiss_index
9 | from multiprocessing import Pool, Manager
10 | import pickle
11 | import torch
12 | import torch.nn as nn
13 | import time
14 |
15 |
16 |
17 |
18 | def main():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--index_prefix", type=str, default='msmarco-passage')
21 | parser.add_argument("--emb_dim", type=int, default=768)
22 | parser.add_argument("--index_path", type=str, required=True)
23 | args = parser.parse_args()
24 |
25 | ## Merge index
26 | corpus_files = glob.glob(os.path.join(args.index_path, args.index_prefix + '.split*.pt'))
27 |
28 | corpus_embs = []
29 | corpus_arg_idxs = []
30 | docids = []
31 | for corpus_file in corpus_files:
32 | with open(corpus_file, 'rb') as f:
33 | print('Load index: {}...'.format(corpus_file))
34 | corpus_emb, corpus_arg_idx, docid=pickle.load(f)
35 | corpus_embs.append(corpus_emb)
36 | corpus_arg_idxs.append(corpus_arg_idx)
37 | docids += docid
38 |
39 | print('Merge index ...')
40 |     try:
41 |         corpus_arg_idxs = np.concatenate(corpus_arg_idxs, axis=0)
42 |     except (TypeError, ValueError):
43 |         corpus_arg_idxs = 0  # splits carry no argmax indices (e.g., pure dense embeddings)
44 | corpus_embs = np.concatenate(corpus_embs, axis=0)
45 |
46 | with open(os.path.join(args.index_path, args.index_prefix + '.index.pt'), 'wb') as f:
47 | pickle.dump([corpus_embs, corpus_arg_idxs, docids], f, protocol=4)
48 |
49 |
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
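55 | # Usage sketch (hypothetical paths): merge the per-split embedding pickles
56 | # <index_prefix>.split*.pt under --index_path into a single pickle
57 | # <index_prefix>.index.pt in the same directory.
58 | #
59 | #   python retrieval/index.py --index_path indexes/msmarco --index_prefix msmarco-passage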
--------------------------------------------------------------------------------
/retrieval/merge.result.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import glob
4 | import os
5 | import numpy as np
6 | from collections import defaultdict
7 | from progressbar import *
8 |
9 |
10 |
11 |
12 |
13 | def main():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--total_shrad", type=int, default=1)
16 | parser.add_argument("--topk", type=int, default=1000)
17 | parser.add_argument("--run_name", default='dhr')
18 |
19 | args = parser.parse_args()
20 | results = defaultdict(list)
21 | scores = defaultdict(list)
22 | for shrad in range(args.total_shrad):
23 |         with open('result{}.trec'.format(shrad), 'r') as f:  # matches the file names written by gip_retrieval.py
24 | for line in f:
25 | query_id, _, docid, rank, score, _ = line.strip().split(' ')
26 | score = float(score)
27 | results[query_id].append(docid)
28 | scores[query_id].append(score)
29 |
30 |
31 | print('write results ...')
32 | widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(),
33 | ' ', ETA(), ' ', FileTransferSpeed()]
34 | pbar = ProgressBar(widgets=widgets, maxval=10*len(results)).start()
35 | fout = open('result.trec', 'w')
36 | for i, query_id in enumerate(results):
37 | score = scores[query_id]
38 | result = results[query_id]
39 | sort_idx = np.array(score).argsort()[::-1][:args.topk]
40 | for rank, idx in enumerate(sort_idx):
41 | fout.write('{} Q0 {} {} {} {}\n'.format(query_id, result[idx], rank+1, score[idx], args.run_name))
42 | pbar.update(10 * i + 1)
43 | fout.close()
44 |
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
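49 | 
50 | # Usage sketch: run after sharded retrieval has produced result0.trec,
51 | # result1.trec, ... (one run file per shard, as written by gip_retrieval.py);
52 | # the merged, re-sorted run is written to result.trec.
53 | #
54 | #   python retrieval/merge.result.py --total_shrad 2 --topk 1000 --run_name dhr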
--------------------------------------------------------------------------------
/retrieval/quantize_index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import pickle5 as pickle
5 | import faiss
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--index_path", type=str, required=True)
11 | parser.add_argument("--output_index_path", type=str, default=None)
12 |     parser.add_argument("--quantized_dim", type=int, default=64)
13 | parser.add_argument("--n_bits", type=int, default=8)
14 | args = parser.parse_args()
15 |
16 | if args.output_index_path is None:
17 |         # default to the directory of the input index
18 |         index_dir = os.path.dirname(args.index_path)
19 |         args.output_index_path = os.path.join(index_dir, 'pq{}_index'.format(args.quantized_dim))
20 |
21 | # load index
22 | print('Load index ...')
23 | with open(args.index_path, 'rb') as f:
24 | corpus_embs, corpus_arg_idxs, docids=pickle.load(f)
25 | corpus_embs = corpus_embs.astype(np.float32)
26 |
27 | faiss.omp_set_num_threads(36)
28 | print('build PQ index...')
29 |     index = faiss.IndexPQ(corpus_embs.shape[1], args.quantized_dim, args.n_bits, faiss.METRIC_INNER_PRODUCT)
30 | index.verbose = True
31 |
32 | print('train PQ...')
33 | index.train(corpus_embs)
34 | print('build index...')
35 | index.add(corpus_embs)
36 | print('write index to {}'.format(args.output_index_path))
37 | faiss.write_index(index, args.output_index_path)
38 | print('finish')
39 |
40 |
41 | if __name__ == "__main__":
42 | main()
43 |
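44 | # Usage sketch (hypothetical path): compress a flat index pickle into a faiss
45 | # product-quantization index with 64 sub-quantizers of 8 bits each.
46 | #
47 | #   python retrieval/quantize_index.py --index_path indexes/msmarco-passage.index.pt \
48 | #       --quantized_dim 64 --n_bits 8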
--------------------------------------------------------------------------------
/retrieval/rcap_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .evaluation.custom_metrics import recall_cap
3 |
4 | def main():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--qrel_file_path", type=str, required=True)
7 | parser.add_argument("--run_file_path", type=str, required=True)
8 | parser.add_argument("--cutoff", type=int, default=100, required=False)
9 | args = parser.parse_args()
10 |
11 | qrels = {}
12 | with open(args.qrel_file_path, 'r') as f:
13 | for line in f:
14 | qid, _, docid, rel = line.strip().split('\t')
15 | if qid not in qrels:
16 | qrels[qid] = {docid: int(rel)}
17 | else:
18 | qrels[qid][docid] = int(rel)
19 |
20 | results = {}
21 | with open(args.run_file_path, 'r') as f:
22 | for line in f:
23 | qid, _, docid, rank, score, _ = line.strip().split(' ')
24 | if qid not in results:
25 | results[qid] = {docid: float(score)}
26 | else:
27 | results[qid][docid] = float(score)
28 |
29 | print(recall_cap(qrels, results, [args.cutoff]))
30 |
31 | if __name__ == "__main__":
32 | main()
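33 | 
34 | # Input format sketch (hypothetical ids): the qrel file is tab-separated
35 | # "qid _ docid rel"; the run file is a space-separated TREC run
36 | # "qid Q0 docid rank score run_name".
37 | #
38 | #   qrels.tsv:  1048585<TAB>0<TAB>7187158<TAB>1
39 | #   run.trec:   1048585 Q0 7187158 1 17.42 dhr
40 | #
41 | # Because of the relative import above, run this as a module:
42 | #   python -m retrieval.rcap_eval --qrel_file_path qrels.tsv --run_file_path run.trec --cutoff 100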
--------------------------------------------------------------------------------
/retrieval/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import faiss  # used by faiss_index() below
4 | # import mkl
5 | # mkl.set_num_threads(16)
6 | import numpy as np
7 | import tensorflow.compat.v1 as tf
8 | from numpy import linalg as LA
9 | from progressbar import *
10 | from collections import defaultdict
11 | import glob
12 | from scipy.sparse import csc_matrix
13 | import gzip
14 | import json
15 |
16 | def read_pickle(filename):
17 | with open(filename, 'rb') as f:
18 | Distance, Index=pickle.load(f)
19 | return Distance, Index
20 |
21 |
22 | def read_id_dict(path):
23 | if os.path.isdir(path):
24 | files = glob.glob(os.path.join(path, '*.id'))
25 | else:
26 | files = [path]
27 |
28 | idx_to_id = {}
29 | id_to_idx = {}
30 | for file in files:
31 | f = open(file, 'r')
32 | for i, line in enumerate(f):
33 | try:
34 | idx, Id =line.strip().split('\t')
35 | idx_to_id[int(idx)] = Id
36 | id_to_idx[Id] = int(idx)
37 | except:
38 | Id = line.strip()
39 | idx_to_id[i] = Id
40 | # if len(Id.split(' '))==1:
41 |
42 | # else:
43 | # print(line+' has no id')
44 | return idx_to_id, id_to_idx
45 |
46 | def write_result(qidxs, Index, Score, file, idx_to_qid, idx_to_docid, topk=None, run_name='Faiss'):
47 | print('write results...')
48 | with open(file, 'w') as fout:
49 | for i, qidx in enumerate(qidxs):
50 | try:
51 | qid = idx_to_qid[qidx]
52 | except:
53 | qid = qidx
54 | if topk==None:
55 | docidxs=Index[i]
56 | scores=Score[i]
57 | for rank, docidx in enumerate(docidxs):
58 | try:
59 | docid = idx_to_docid[docidx]
60 | except:
61 | docid = docidx
62 | fout.write('{} Q0 {} {} {} {}\n'.format(qid, docid, rank + 1, scores[rank], run_name))
63 | else:
64 |                 try:
65 |                     hit=min(topk, len(Index[i]))
66 |                 except TypeError:
67 |                     hit = topk  # Index[i] has no len(); fall back to topk
68 |
69 | docidxs=Index[i]
70 | scores=Score[i]
71 | for rank, docidx in enumerate(docidxs[:hit]):
72 | try:
73 | docid = idx_to_docid[docidx]
74 | except:
75 | docid = docidx
76 | fout.write('{} Q0 {} {} {} {}\n'.format(qid, docid, rank + 1, scores[rank], run_name))
77 |
78 |
79 | def faiss_index(corpus_embs, docids, save_path, index_method):
80 |
81 | dimension=corpus_embs.shape[1]
82 | print("Indexing ...")
83 | if index_method==None or index_method=='flatip':
84 | cpu_index = faiss.IndexFlatIP(dimension)
85 |
86 | elif index_method=='hsw':
87 | cpu_index = faiss.IndexHNSWFlat(dimension, 256, faiss.METRIC_INNER_PRODUCT)
88 | cpu_index.hnsw.efConstruction = 256
89 |     elif index_method=='quantize': # TODO: find a better trade-off between efficiency and effectiveness
90 | cpu_index = faiss.IndexHNSWPQ(dimension, 192, 256)
91 | cpu_index.hnsw.efConstruction = 256
92 | cpu_index.metric_type = faiss.METRIC_INNER_PRODUCT
93 | # ncentroids = 1000
94 | # code_size = dimension//4
95 | # cpu_index = faiss.IndexIVFPQ(cpu_index, dimension, ncentroids, code_size, 8)
96 | # cpu_index = faiss.IndexPQ(dimension, code_size, 8)
97 | # cpu_index = faiss.index_factory(768, "OPQ128,IVF4096,PQ128", faiss.METRIC_INNER_PRODUCT)
98 | # cpu_index = faiss.IndexIDMap(cpu_index)
99 | # cpu_index = faiss.GpuIndexScalarQuantizer(dimension, faiss.ScalarQuantizer.QT_16bit_direct, faiss.METRIC_INNER_PRODUCT)
100 |
101 |
102 | cpu_index.verbose = True
103 |     if index_method=='quantize':
104 |         print("Train index...")  # PQ-based indexes must be trained before add()
105 |         cpu_index.train(corpus_embs)
106 |     cpu_index.add(corpus_embs)
107 | print("Save Index {}...".format(save_path))
108 | faiss.write_index(cpu_index, save_path)
109 |
110 | def save_pickle(corpus_embs, arg_idxs, docids, filename):
111 | print('save pickle...')
112 | with open(filename, 'wb') as f:
113 | pickle.dump([corpus_embs, arg_idxs, docids], f, protocol=4)
114 |
115 | def load_tfrecords_and_index(srcfiles, data_num, word_num, dim, data_type, add_cls, index=False, save_path=None, batch=10000):
116 | def _parse_function(example_proto):
117 | features = {'doc_emb': tf.FixedLenFeature([],tf.string) , #tf.FixedLenSequenceFeature([],tf.string, allow_missing=True),
118 | 'argx_id_id': tf.FixedLenFeature([],tf.string) ,
119 | 'docid': tf.FixedLenFeature([],tf.int64)}
120 | parsed_features = tf.parse_single_example(example_proto, features)
121 | arg_idx = tf.decode_raw(parsed_features['argx_id_id'], tf.uint8)
122 | if data_type=='16':
123 | corpus = tf.decode_raw(parsed_features['doc_emb'], tf.float16)
124 | elif data_type=='32':
125 | corpus = tf.decode_raw(parsed_features['doc_emb'], tf.float32)
126 | docid = tf.cast(parsed_features['docid'], tf.int32)
127 | return corpus, arg_idx, docid
128 | print('Read embeddings...')
129 | widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(),
130 | ' ', ETA(), ' ', FileTransferSpeed()]
131 | pbar = ProgressBar(widgets=widgets, maxval=10*data_num*len(srcfiles)).start()
132 | with tf.Session() as sess:
133 | docids=[]
134 | if add_cls:
135 | segment=2
136 | else:
137 | segment=1
138 |         # preallocate the arrays up front so we avoid a costly concatenate at the end
139 |         arg_idxs = np.zeros((word_num*data_num*len(srcfiles) , dim), dtype=np.uint8)
140 |         if (data_type=='16'): # note: faiss only supports indexing float32 arrays
141 | corpus_embs = np.zeros((word_num*data_num*len(srcfiles) , dim*segment), dtype=np.float16)
142 | elif data_type=='32':
143 | corpus_embs = np.zeros((word_num*data_num*len(srcfiles) , dim*segment), dtype=np.float32)
144 | # else:
145 | # raise Exception('Please assign datatype 16 or 32 bits')
146 | counter = 0
147 | i = 0
148 |
149 | for srcfile in srcfiles:
150 | try:
151 | dataset = tf.data.TFRecordDataset(srcfile) # load tfrecord file
152 | except:
153 | print('Cannot find data')
154 | continue
155 | dataset = dataset.map(_parse_function) # parse data into tensor
156 | dataset = dataset.repeat(1)
157 | dataset = dataset.batch(batch)
158 | iterator = dataset.make_one_shot_iterator()
159 | next_data = iterator.get_next()
160 |
161 | while True:
162 | try:
163 | corpus_emb, arg_idx, docid = sess.run(next_data)
164 |
165 | corpus_emb = corpus_emb.reshape(-1, dim*segment)
166 |
167 | sent_num = corpus_emb.shape[0]
168 | corpus_embs[counter:(counter+sent_num)] = corpus_emb
169 | arg_idxs[counter:(counter+sent_num)] = arg_idx
170 |
171 | docids+=docid.tolist()
172 | counter+=sent_num
173 | pbar.update(10 * i + 1)
174 | i+=sent_num
175 | except tf.errors.OutOfRangeError:
176 | break
177 |
178 | docids = np.array(docids).reshape(-1)
179 | corpus_embs = (corpus_embs[:len(docids)])
180 | arg_idxs = (arg_idxs[:len(docids)])
181 | mask = docids!=-1
182 | docids = docids[mask]
183 | corpus_embs = corpus_embs[mask]
184 | arg_idxs = arg_idxs[mask]
185 | if index:
186 | save_pickle(corpus_embs, arg_idxs, docids, save_path)
187 | else:
188 | return corpus_embs, arg_idxs, docids
189 |
190 | def load_jsonl_and_index(srcfiles, data_num, dim, vocab_dict, data_type, add_cls, index=False, save_path=None, batch=10000):
191 | print('Count line...')
192 | data_num = 0
193 | for srcfile in srcfiles:
194 | with gzip.open(srcfile, 'rb') as f:
195 | for l in f:
196 | data_num+=1
197 | print('Total {} lines'.format(data_num))
198 | widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(),
199 | ' ', ETA(), ' ', FileTransferSpeed()]
200 | pbar = ProgressBar(widgets=widgets, maxval=10*data_num*len(srcfiles)).start()
201 | docids=[]
202 | if add_cls:
203 | segment=2
204 | else:
205 | segment=1
206 |     # preallocate the arrays up front so we avoid a costly concatenate at the end
207 |     arg_idxs = np.zeros((data_num , dim), dtype=np.uint8)
208 |     if (data_type=='16'): # note: faiss only supports indexing float32 arrays
209 | corpus_embs = np.zeros((data_num , dim*segment), dtype=np.float16)
210 | elif data_type=='32':
211 | corpus_embs = np.zeros((data_num , dim*segment), dtype=np.float32)
212 | # else:
213 | # raise Exception('Please assign datatype 16 or 32 bits')
214 | counter = 0
215 | i = 0
216 |
217 |
218 | for srcfile in srcfiles:
219 |
220 | with gzip.open(srcfile, "rb") as f:
221 | for line in f:
222 | data = json.loads(line.strip())
223 | embedding =np.zeros((30522), dtype=np.float16)
224 | for vocab, term_weight in data['vector'].items():
225 | embedding[vocab_dict[vocab]] = term_weight/100
226 |
227 | embedding = np.reshape(embedding[570:],(-1, dim))
228 | corpus_emb = embedding.max(0)
229 | arg_idx = embedding.argmax(0)
230 | docid = int(data['id'])
231 |
232 |
233 | corpus_emb = corpus_emb.reshape(-1, dim*segment)
234 |
235 | sent_num = corpus_emb.shape[0]
236 | corpus_embs[counter:(counter+sent_num)] = corpus_emb
237 | arg_idxs[counter:(counter+sent_num)] = arg_idx
238 |
239 | docids+=[docid]
240 | counter+=sent_num
241 | pbar.update(10 * i + 1)
242 | i+=sent_num
243 |
244 |
245 | docids = np.array(docids).reshape(-1)
246 | corpus_embs = (corpus_embs[:len(docids)])
247 | arg_idxs = (arg_idxs[:len(docids)])
248 | mask = docids!=-1
249 | docids = docids[mask]
250 | corpus_embs = corpus_embs[mask]
251 | arg_idxs = arg_idxs[mask]
252 | if index:
253 | save_pickle(corpus_embs, arg_idxs, docids, save_path)
254 | else:
255 | return corpus_embs, arg_idxs, docids
256 |
257 | def load_tfrecords_and_analyze(srcfiles, data_num, word_num, dim, data_type, batch=1):
258 | def _parse_function(example_proto):
259 | features = {#'doc_emb': tf.FixedLenFeature([],tf.string) , #tf.FixedLenSequenceFeature([],tf.string, allow_missing=True),
260 | 'id_p1': tf.FixedLenSequenceFeature([],tf.int64, allow_missing=True) ,
261 | 'docid': tf.FixedLenFeature([],tf.int64)}
262 | parsed_features = tf.parse_single_example(example_proto, features)
263 | vocab_ids = tf.cast(parsed_features['id_p1'], tf.int32)
264 | docid = tf.cast(parsed_features['docid'], tf.int32)
265 | return vocab_ids, docid
266 | print('Read embeddings...')
267 | widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(),
268 | ' ', ETA(), ' ', FileTransferSpeed()]
269 | pbar = ProgressBar(widgets=widgets, maxval=10*data_num*len(srcfiles)).start()
270 | with tf.Session() as sess:
271 | docids=[]
272 | segment=1
273 | # else:
274 | # raise Exception('Please assign datatype 16 or 32 bits')
275 | counter = 0
276 | i = 0
277 | vocab_adj =np.zeros((30522,30522), dtype=np.uint32)
278 | vocab_freq = np.zeros((30522), dtype=np.uint32)
279 | for srcfile in srcfiles:
280 | try:
281 | dataset = tf.data.TFRecordDataset(srcfile) # load tfrecord file
282 | except:
283 | print('Cannot find data')
284 | continue
285 | dataset = dataset.map(_parse_function) # parse data into tensor
286 | dataset = dataset.repeat(1)
287 | dataset = dataset.batch(batch)
288 | iterator = dataset.make_one_shot_iterator()
289 | next_data = iterator.get_next()
290 |
291 | while True:
292 | try:
293 | vocab_ids, docid = sess.run(next_data)
294 |
295 | vocab_id_list = vocab_ids.squeeze().tolist()
296 | try:
297 | num_vocab_id = len(vocab_id_list)
298 | if num_vocab_id >1:
299 | for m in range(num_vocab_id):
300 | vocab_freq[vocab_id_list[m]]+=1
301 | for n in range(m+1,num_vocab_id,1):
302 | vocab_adj[vocab_id_list[m], vocab_id_list[n]]+=1
303 |
304 | except:
305 | vocab_freq[vocab_id_list]+=1
306 |
307 |
308 |
309 | pbar.update(10 * i + 1)
310 | i+=1
311 | # if i>=20000:
312 | # break
313 | except tf.errors.OutOfRangeError:
314 | break
315 |
316 |
317 | return vocab_freq, vocab_adj
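318 | 
319 | # Format sketch for read_id_dict() (ids are hypothetical): each *.id file holds
320 | # either "idx<TAB>id" per line, or a bare id per line, in which case the line
321 | # number is used as the index.
322 | #
323 | #   0<TAB>MARCO_D0
324 | #   1<TAB>MARCO_D1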
--------------------------------------------------------------------------------
/tevatron/Aggretriever/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/Aggretriever/__init__.py
--------------------------------------------------------------------------------
/tevatron/Aggretriever/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 |
5 | # remove_dim_dict = {768: -198, 640: -198, 512: 826, 256: 314, 128: 58}
6 | # remove_dim_dict1 = {768: 570, 640: 442, 512: 314, 256: 58, 128: 58}
7 |
8 | def cal_remove_dim(dims, vocab_size=30522):
9 |
10 | remove_dims = vocab_size % dims
11 |     if remove_dims > 1000: # the first ~1000 ids in the BERT vocabulary are unused or special tokens, so dropping them is safe
12 | remove_dims -= dims
13 |
14 | return remove_dims
15 |
16 | def aggregate(lexical_reps: Tensor,
17 | dims: int = 640,
18 | remove_dims: int = -198,
19 | full: bool = True
20 | ):
21 |
22 | if full:
23 | remove_dims = cal_remove_dim(dims*2)
24 | batch_size = lexical_reps.shape[0]
25 | if remove_dims >= 0:
26 | lexical_reps = lexical_reps[:, remove_dims:].view(batch_size, -1, dims*2)
27 | else:
28 | lexical_reps = torch.nn.functional.pad(lexical_reps, (0, -remove_dims), "constant", 0).view(batch_size, -1, dims*2)
29 |
30 | tok_reps, _ = lexical_reps.max(1)
31 |
32 | positive_tok_reps = tok_reps[:, 0:2*dims:2]
33 | negative_tok_reps = tok_reps[:, 1:2*dims:2]
34 |
35 | positive_mask = positive_tok_reps > negative_tok_reps
36 | negative_mask = positive_tok_reps <= negative_tok_reps
37 | tok_reps = positive_tok_reps * positive_mask - negative_tok_reps * negative_mask
38 | else:
39 | remove_dims = cal_remove_dim(dims)
40 | batch_size = lexical_reps.shape[0]
41 | lexical_reps = lexical_reps[:, remove_dims:].view(batch_size, -1, dims)
42 | tok_reps, index_reps = lexical_reps.max(1)
43 |
44 | return tok_reps
45 |
46 |
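47 | if __name__ == "__main__":
48 |     # Minimal sanity check (illustrative only, not part of the original module):
49 |     # aggregate a random batch of BERT-vocab-sized (30522) lexical vectors. With
50 |     # full=True, the interleaved positive/negative slots are folded back into
51 |     # `dims` values, so both settings return (batch, dims).
52 |     reps = torch.rand(2, 30522)
53 |     print(aggregate(reps, dims=640, full=True).shape)   # torch.Size([2, 640])
54 |     print(aggregate(reps, dims=640, full=False).shape)  # torch.Size([2, 640])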
--------------------------------------------------------------------------------
/tevatron/DHR/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/DHR/__init__.py
--------------------------------------------------------------------------------
/tevatron/DHR/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 |
5 | def densify(lexical_reps: Tensor,
6 | dims: int = 768,
7 | strategy: str = 'stride',
8 | remove_dims: int = 570
9 | ):
10 |
11 |     if len(lexical_reps.shape) != 2:
12 |         raise ValueError('Input lexical representations should be 2-dimensional (batch, vocab), but got {} dimensions'.format(len(lexical_reps.shape)))
13 |
14 | orig_dims = lexical_reps.shape[-1]
15 |     if (orig_dims - remove_dims) % dims != 0:
16 |         raise ValueError('Input lexical representation cannot be densified; adjust dims or remove_dims')
17 | 
18 |     # TODO: support densification strategies other than 'stride'
19 | batch_size = lexical_reps.shape[0]
20 | lexical_reps = lexical_reps[:, remove_dims:].view(batch_size, -1, dims)
21 | value_reps, index_reps = lexical_reps.max(1)
22 | return value_reps, index_reps
23 |
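24 | if __name__ == "__main__":
25 |     # Minimal sanity check (illustrative only, not part of the original module):
26 |     # densify a random batch of BERT-vocab-sized (30522) lexical vectors;
27 |     # (30522 - 570) / 768 = 39 slots are max-pooled per output dimension.
28 |     reps = torch.rand(2, 30522)
29 |     values, indices = densify(reps, dims=768, remove_dims=570)
30 |     print(values.shape, indices.shape)  # torch.Size([2, 768]) torch.Size([2, 768])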
--------------------------------------------------------------------------------
/tevatron/Dense/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/Dense/__init__.py
--------------------------------------------------------------------------------
/tevatron/Dense/modeling.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import copy
4 | from dataclasses import dataclass
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch import Tensor
9 | import torch.distributed as dist
10 |
11 | from transformers import AutoModel, PreTrainedModel, AutoModelForMaskedLM
12 | from transformers.modeling_outputs import ModelOutput
13 |
14 |
15 | from typing import Optional, Dict
16 |
17 | from ..arguments import ModelArguments, DataArguments, \
18 | DenseTrainingArguments as TrainingArguments
19 | import logging
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | @dataclass
25 | class DenseOutput(ModelOutput):
26 | q_reps: Tensor = None
27 | p_reps: Tensor = None
28 | loss: Tensor = None
29 | scores: Tensor = None
30 |
31 |
32 | class LinearPooler(nn.Module):
33 | def __init__(
34 | self,
35 | input_dim: int = 768,
36 | output_dim: int = 768,
37 | tied=True,
38 | elementwise=False,
39 | name='pooler'
40 | ):
41 | super(LinearPooler, self).__init__()
42 | self.name = name
43 | self.elementwise=elementwise
44 | self.linear_q = nn.Linear(input_dim, output_dim)
45 | if tied:
46 | self.linear_p = self.linear_q
47 | else:
48 | self.linear_p = nn.Linear(input_dim, output_dim)
49 |
50 | self._config = {'input_dim': input_dim, 'output_dim': output_dim, 'tied': tied}
51 |
52 | def forward(self, q: Tensor = None, p: Tensor = None):
53 | if q is not None:
54 | return self.linear_q(q)
55 | elif p is not None:
56 | return self.linear_p(p)
57 | else:
58 | raise ValueError
59 |
60 | def load(self, ckpt_dir: str):
61 | if ckpt_dir is not None:
62 | _pooler_path = os.path.join(ckpt_dir, '{}.pt'.format(self.name))
63 | if os.path.exists(_pooler_path):
64 | logger.info(f'Loading Pooler from {ckpt_dir}')
65 | state_dict = torch.load(os.path.join(ckpt_dir, '{}.pt'.format(self.name)), map_location='cpu')
66 | self.load_state_dict(state_dict)
67 | return
68 | logger.info("Training {} from scratch".format(self.name))
69 | return
70 |
71 | def save_pooler(self, save_path):
72 | torch.save(self.state_dict(), os.path.join(save_path, '{}.pt'.format(self.name)))
73 |         with open(os.path.join(save_path, '{}_config.json'.format(self.name)), 'w') as f:
74 | json.dump(self._config, f)
75 |
76 |
77 | class DenseModel(nn.Module):
78 | def __init__(
79 | self,
80 | lm_q: PreTrainedModel,
81 | lm_p: PreTrainedModel,
82 | pooler: nn.Module = None,
83 | model_args: ModelArguments = None,
84 | data_args: DataArguments = None,
85 | train_args: TrainingArguments = None,
86 | ):
87 | super().__init__()
88 |
89 | self.lm_q = lm_q
90 | self.lm_p = lm_p
91 | self.pooler = pooler
92 |
93 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
94 |
95 | self.model_args = model_args
96 | self.train_args = train_args
97 | self.data_args = data_args
98 |
99 | if train_args.negatives_x_device:
100 | if not dist.is_initialized():
101 | raise ValueError('Distributed training has not been initialized for representation all gather.')
102 | self.process_rank = dist.get_rank()
103 | self.world_size = dist.get_world_size()
104 |
105 | def forward(
106 | self,
107 | query: Dict[str, Tensor] = None,
108 | passage: Dict[str, Tensor] = None,
109 | teacher_scores: Tensor = None,
110 | ):
111 |
112 |
113 | q_hidden, q_reps = self.encode_query(query, self.model_args.pooling_method)
114 | p_hidden, p_reps = self.encode_passage(passage, self.model_args.pooling_method)
115 |
116 | if q_reps is None or p_reps is None:
117 | return DenseOutput(
118 | q_reps=q_reps,
119 | p_reps=p_reps
120 | )
121 |
122 | if self.training:
123 | if self.train_args.negatives_x_device:
124 | q_reps = self.dist_gather_tensor(q_reps)
125 | p_reps = self.dist_gather_tensor(p_reps)
126 |
127 | effective_bsz = self.train_args.per_device_train_batch_size * self.world_size \
128 | if self.train_args.negatives_x_device \
129 | else self.train_args.per_device_train_batch_size
130 |
131 | scores = torch.matmul(q_reps, p_reps.transpose(0, 1))
132 | scores = scores.view(effective_bsz, -1)
133 |
134 | target = torch.arange(
135 | scores.size(0),
136 | device=scores.device,
137 | dtype=torch.long
138 | )
139 | target = target * self.data_args.train_n_passages
140 | loss = self.cross_entropy(scores, target)
141 | if self.train_args.negatives_x_device:
142 | loss = loss * self.world_size # counter average weight reduction
143 | return DenseOutput(
144 | loss=loss,
145 | scores=scores,
146 | q_reps=q_reps,
147 | p_reps=p_reps
148 | )
149 |
150 | else:
151 | loss = None
152 | if query and passage:
153 | scores = (q_reps * p_reps).sum(1)
154 | else:
155 | scores = None
156 |
157 | return DenseOutput(
158 | loss=loss,
159 | scores=scores,
160 | q_reps=q_reps,
161 | p_reps=p_reps
162 | )
163 |
164 | def encode_passage(self, psg, pooling_method):
165 | if psg is None:
166 | return None, None
167 |
168 | psg_out = self.lm_p(**psg, return_dict=True)
169 | p_hidden = psg_out.last_hidden_state
170 |
171 | if pooling_method == 'cls':
172 | p_reps = p_hidden[:, 0]
173 | elif pooling_method == 'average':
174 | attention_mask = psg['attention_mask']
175 | p_hidden = p_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
176 | p_reps = p_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
177 |
178 | if self.pooler is not None:
179 | p_reps = self.pooler(p=p_reps) # D * d
180 |
181 | return p_hidden, p_reps
182 |
183 | def encode_query(self, qry, pooling_method):
184 | if qry is None:
185 | return None, None
186 | qry_out = self.lm_q(**qry, return_dict=True)
187 | q_hidden = qry_out.last_hidden_state
188 |
189 | if pooling_method == 'cls':
190 | q_reps = q_hidden[:, 0]
191 | elif pooling_method == 'average':
192 | attention_mask = qry['attention_mask']
193 | q_hidden = q_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
194 | q_reps = q_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
195 |
196 |
197 | if self.pooler is not None:
198 | q_reps = self.pooler(q=q_reps)
199 |
200 | return q_hidden, q_reps
201 |
202 | @staticmethod
203 | def build_pooler(model_args):
204 | pooler = LinearPooler(
205 | model_args.projection_in_dim,
206 | model_args.projection_out_dim,
207 | tied=not model_args.untie_encoder
208 | )
209 | pooler.load(model_args.model_name_or_path)
210 | return pooler
211 |
212 | @classmethod
213 | def build(
214 | cls,
215 | model_args: ModelArguments,
216 | data_args: DataArguments,
217 | train_args: TrainingArguments,
218 | **hf_kwargs,
219 | ):
220 | # load local
221 | if os.path.isdir(model_args.model_name_or_path):
222 | if model_args.untie_encoder:
223 | _qry_model_path = os.path.join(model_args.model_name_or_path, 'query_model')
224 | _psg_model_path = os.path.join(model_args.model_name_or_path, 'passage_model')
225 | if not os.path.exists(_qry_model_path):
226 | _qry_model_path = model_args.model_name_or_path
227 | _psg_model_path = model_args.model_name_or_path
228 | logger.info(f'loading query model weight from {_qry_model_path}')
229 | lm_q = AutoModel.from_pretrained(
230 | _qry_model_path,
231 | **hf_kwargs
232 | )
233 | logger.info(f'loading passage model weight from {_psg_model_path}')
234 | lm_p = AutoModel.from_pretrained(
235 | _psg_model_path,
236 | **hf_kwargs
237 | )
238 | else:
239 | lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
240 | lm_p = lm_q
241 | # load pre-trained
242 | else:
243 | lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
244 | lm_p = copy.deepcopy(lm_q) if model_args.untie_encoder else lm_q
245 |
246 | if model_args.add_pooler:
247 | pooler = cls.build_pooler(model_args)
248 | else:
249 | pooler = None
250 |
251 | model = cls(
252 | lm_q=lm_q,
253 | lm_p=lm_p,
254 | pooler=pooler,
255 | model_args=model_args,
256 | data_args=data_args,
257 | train_args=train_args
258 | )
259 | return model
260 |
261 | def save(self, output_dir: str):
262 | if self.model_args.untie_encoder:
263 | os.makedirs(os.path.join(output_dir, 'query_model'))
264 | os.makedirs(os.path.join(output_dir, 'passage_model'))
265 | self.lm_q.save_pretrained(os.path.join(output_dir, 'query_model'))
266 | self.lm_p.save_pretrained(os.path.join(output_dir, 'passage_model'))
267 | else:
268 | self.lm_q.save_pretrained(output_dir)
269 |
270 | if self.model_args.add_pooler:
271 | self.pooler.save_pooler(output_dir)
272 |
273 | def dist_gather_tensor(self, t: Optional[torch.Tensor]):
274 | if t is None:
275 | return None
276 | t = t.contiguous()
277 |
278 | all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
279 | dist.all_gather(all_tensors, t)
280 |
281 | all_tensors[self.process_rank] = t
282 | all_tensors = torch.cat(all_tensors, dim=0)
283 |
284 | return all_tensors
285 |
286 |
287 | class DenseModelForInference(DenseModel):
288 | POOLER_CLS = LinearPooler
289 |
290 | def __init__(
291 | self,
292 | model_args,
293 | lm_q: PreTrainedModel,
294 | lm_p: PreTrainedModel,
295 | pooler: nn.Module = None,
296 | **kwargs,
297 | ):
298 | nn.Module.__init__(self)
299 | self.lm_q = lm_q
300 | self.lm_p = lm_p
301 | self.pooler = pooler
302 | self.model_args = model_args
303 |
304 | @torch.no_grad()
305 | def encode_passage(self, psg, pooling_method):
306 | return super(DenseModelForInference, self).encode_passage(psg, pooling_method)
307 |
308 | @torch.no_grad()
309 | def encode_query(self, qry, pooling_method):
310 | return super(DenseModelForInference, self).encode_query(qry, pooling_method)
311 |
312 | # def forward(
313 | # self,
314 | # query: Dict[str, Tensor] = None,
315 | # passage: Dict[str, Tensor] = None,
316 | # ):
317 | # q_hidden, q_reps = self.encode_query(query)
318 | # p_hidden, p_reps = self.encode_passage(passage)
319 | # return DenseOutput(q_reps=q_reps, p_reps=p_reps)
320 |
321 | @classmethod
322 | def build(
323 | cls,
324 | model_name_or_path: str = None,
325 | model_args: ModelArguments = None,
326 | data_args: DataArguments = None,
327 | train_args: TrainingArguments = None,
328 | **hf_kwargs,
329 | ):
330 | assert model_name_or_path is not None or model_args is not None
331 | if model_name_or_path is None:
332 | model_name_or_path = model_args.model_name_or_path
333 |
334 | # load local
335 | if os.path.isdir(model_name_or_path):
336 | _qry_model_path = os.path.join(model_name_or_path, 'query_model')
337 | _psg_model_path = os.path.join(model_name_or_path, 'passage_model')
338 | if os.path.exists(_qry_model_path):
339 | logger.info(f'found separate weight for query/passage encoders')
340 | logger.info(f'loading query model weight from {_qry_model_path}')
341 | lm_q = AutoModel.from_pretrained(
342 | _qry_model_path,
343 | **hf_kwargs
344 | )
345 | logger.info(f'loading passage model weight from {_psg_model_path}')
346 | lm_p = AutoModel.from_pretrained(
347 | _psg_model_path,
348 | **hf_kwargs
349 | )
350 | else:
351 | logger.info(f'try loading tied weight')
352 | logger.info(f'loading model weight from {model_name_or_path}')
353 | lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
354 | lm_p = lm_q
355 | else:
356 | logger.info(f'try loading tied weight')
357 | logger.info(f'loading model weight from {model_name_or_path}')
358 | lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
359 | lm_p = lm_q
360 |
361 | pooler_weights = os.path.join(model_name_or_path, 'pooler.pt')
362 | pooler_config = os.path.join(model_name_or_path, 'pooler_config.json')
363 | if os.path.exists(pooler_weights) and os.path.exists(pooler_config):
364 | logger.info(f'found pooler weight and configuration')
365 | with open(pooler_config) as f:
366 | pooler_config_dict = json.load(f)
367 | pooler = cls.POOLER_CLS(**pooler_config_dict)
368 | pooler.load(model_name_or_path)
369 | else:
370 | pooler = None
371 |
372 | model = cls(
373 | model_args=model_args,
374 | lm_q=lm_q,
375 | lm_p=lm_p,
376 | pooler=pooler
377 |
378 | )
379 | return model
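380 | 
381 | # Usage sketch (hypothetical checkpoint path): build an inference-only encoder;
382 | # pooler weights are picked up automatically if pooler.pt and pooler_config.json
383 | # sit next to the model weights.
384 | #
385 | #   model = DenseModelForInference.build(model_name_or_path='ckpt/dense-encoder')
386 | #   _, q_reps = model.encode_query(tokenized_query, pooling_method='cls')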
--------------------------------------------------------------------------------
/tevatron/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/__init__.py
--------------------------------------------------------------------------------
/tevatron/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from typing import Optional, List
4 | from transformers import TrainingArguments
5 |
6 |
7 | @dataclass
8 | class ModelArguments:
9 | model_name_or_path: str = field(
10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
11 | )
12 | target_model_path: str = field(
13 | default=None,
14 | metadata={"help": "Path to pretrained reranker target model"}
15 | )
16 | config_name: Optional[str] = field(
17 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
18 | )
19 | tokenizer_name: Optional[str] = field(
20 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
21 | )
22 | cache_dir: Optional[str] = field(
23 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
24 | )
25 |
26 | # modeling
27 | model: str = field(
28 | default='DHR',
29 | metadata={"help": "ColBERT, DHR, AGG, Dense"}
30 | )
31 | untie_encoder: bool = field(
32 | default=False,
33 | metadata={"help": "no weight sharing between qry passage encoders"}
34 | )
35 |
36 | # knowledge distillation
37 | teacher_model_name_or_path: str = field(
38 | default=None,
39 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
40 | )
41 | tct: bool = field(default=False)
42 | kd: bool = field(default=False)
43 |
44 | # out projection
45 | combine_cls: bool = field(default=False)
46 | add_pooler: bool = field(default=False)
47 | projection_in_dim: int = field(default=768)
48 | projection_out_dim: int = field(default=768)
49 |
50 | # Dense
51 | pooling_method: str = field(
52 | default='cls',
53 | metadata={"help": "cls, average"}
54 | )
55 |
56 | # dlr option
57 | dlr_out_dim: int = field(default=768)
58 |
59 | # agg option
60 | agg_dim: int = field(default=640)
61 | semi_aggregate: bool = field(default=False)
62 | skip_mlm: bool = field(default=False)
63 |
64 |
65 | # for Jax training
66 | dtype: Optional[str] = field(
67 | default="float32",
68 | metadata={
69 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one "
70 | "of `[float32, float16, bfloat16]`. "
71 | },
72 | )
73 |
74 |
75 | @dataclass
76 | class ColBERTModelArguments:
77 | config_name: Optional[str] = field(
78 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
79 | )
80 | tokenizer_name: Optional[str] = field(
81 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
82 | )
83 | cache_dir: Optional[str] = field(
84 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
85 | )
86 |
87 | # modeling
88 | model: str = field(
89 | default='ColBERT',
90 | metadata={"help": "ColBERT"}
91 | )
92 | untie_encoder: bool = field(
93 | default=False,
94 | metadata={"help": "no weight sharing between qry passage encoders"}
95 | )
96 |
97 | # out projection
98 | combine_cls: bool = field(default=False)
99 | add_pooler: bool = field(default=True)
100 | projection_in_dim: int = field(default=768)
101 | projection_out_dim: int = field(default=768)
102 |
103 | # for Jax training
104 | dtype: Optional[str] = field(
105 | default="float32",
106 | metadata={
107 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one "
108 | "of `[float32, float16, bfloat16]`. "
109 | },
110 | )
111 |
112 |
113 | @dataclass
114 | class DataArguments:
115 | train_dir: str = field(
116 | default=None, metadata={"help": "Path to train directory"}
117 | )
118 | corpus_dir: str = field(
119 | default=None, metadata={"help": "Path to corpus directory"}
120 | )
121 | query_cluster_dir: str = field(
122 |         default=None, metadata={"help": "Path to query cluster directory"}
123 | )
124 | dataset_name: str = field(
125 | default=None, metadata={"help": "huggingface dataset name"}
126 | )
127 | passage_field_separator: str = field(default=' ')
128 | dataset_proc_num: int = field(
129 | default=12, metadata={"help": "number of proc used in dataset preprocess"}
130 | )
131 | train_n_passages: int = field(default=8)
132 | positive_passage_no_shuffle: bool = field(
133 | default=False, metadata={"help": "always use the first positive passage"})
134 | negative_passage_no_shuffle: bool = field(
135 | default=False, metadata={"help": "always use the first negative passages"})
136 |
137 | tasb_sampling: bool = field(
138 | default=False, metadata={"help": "use topic-aware balanced sampling"})
139 |
140 | encode_in_path: List[str] = field(default=None, metadata={"help": "Path to data to encode"})
141 | encoded_save_path: str = field(default=None, metadata={"help": "where to save the encode"})
142 | encode_is_qry: bool = field(default=False)
143 | encode_num_shard: int = field(default=1)
144 | encode_shard_index: int = field(default=0)
145 |
146 | q_max_len: int = field(
147 | default=32,
148 | metadata={
149 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
150 | "than this will be truncated, sequences shorter will be padded."
151 | },
152 | )
153 | p_max_len: int = field(
154 | default=128,
155 | metadata={
156 | "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
157 | "than this will be truncated, sequences shorter will be padded."
158 | },
159 | )
160 | data_cache_dir: Optional[str] = field(
161 | default=None, metadata={"help": "Where do you want to store the data downloaded from huggingface"}
162 | )
163 |
164 | def __post_init__(self):
165 | if self.dataset_name is not None:
166 | info = self.dataset_name.split('/')
167 | self.dataset_split = info[-1] if len(info) == 3 else 'train'
168 | self.dataset_name = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)
169 | self.dataset_language = 'default'
170 | if ':' in self.dataset_name:
171 | self.dataset_name, self.dataset_language = self.dataset_name.split(':')
172 | else:
173 | self.dataset_name = 'json'
174 | self.dataset_split = 'train'
175 | self.dataset_language = 'default'
176 | if self.train_dir is not None:
177 | files = sorted(os.listdir(self.train_dir))
178 | self.train_path = [
179 | os.path.join(self.train_dir, f)
180 | for f in files
181 | if f.endswith('jsonl') or f.endswith('json')
182 | ]
183 | else:
184 | self.train_path = None
185 | if self.corpus_dir is not None:
186 | files = sorted(os.listdir(self.corpus_dir))
187 | self.corpus_path = [
188 | os.path.join(self.corpus_dir, f)
189 | for f in files
190 | if f.endswith('jsonl') or f.endswith('json')
191 | ]
192 | else:
193 | self.corpus_path = None
194 |
195 | if self.query_cluster_dir is not None:
196 | files = sorted(os.listdir(self.query_cluster_dir))
197 | self.query_cluster_path = [
198 | os.path.join(self.query_cluster_dir, f)
199 | for f in files
200 | if f.endswith('jsonl') or f.endswith('json')
201 | ]
202 | else:
203 | self.query_cluster_path = None
204 |
205 |
206 |
207 | @dataclass
208 | class DenseTrainingArguments(TrainingArguments):
209 | warmup_ratio: float = field(default=0.1)
210 | negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
211 | do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
212 | ddp_find_unused_parameters: bool = field(default=True, metadata={"help": "set find unused parameters"})
213 | grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"})
214 | gc_q_chunk_size: int = field(default=4)
215 | gc_p_chunk_size: int = field(default=32)
216 |
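217 | # Note on --dataset_name parsing in DataArguments.__post_init__ (dataset names
218 | # here are hypothetical): "org/dataset/dev" yields dataset_name="org/dataset"
219 | # with dataset_split="dev"; "org/dataset" defaults the split to "train"; and a
220 | # suffix such as "org/dataset:zh" sets dataset_language="zh".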
--------------------------------------------------------------------------------
/tevatron/data.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass
3 | from typing import List, Tuple
4 |
5 | from tqdm import tqdm
6 | import glob
7 | import os
8 | import json
9 |
10 | import datasets
11 | from torch.utils.data import Dataset
12 | from transformers import PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding
13 | import torch
14 |
15 | from .arguments import DataArguments
16 | from .trainer import DenseTrainer
17 |
18 | import logging
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class TrainDataset(Dataset):
23 | def __init__(
24 | self,
25 | data_args: DataArguments,
26 | dataset: datasets.Dataset,
27 | tokenizer: PreTrainedTokenizer,
28 | trainer: DenseTrainer = None,
29 | ):
30 | self.train_data = dataset
31 | self.tok = tokenizer
32 | self.trainer = trainer
33 |
34 | self.data_args = data_args
35 | self.total_len = len(self.train_data)
36 |
37 | def create_one_example(self, text_encoding: List[int], is_query=False):
38 | item = self.tok.encode_plus(
39 | text_encoding,
40 | truncation='only_first',
41 | max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len,
42 | padding=False,
43 | return_attention_mask=False,
44 | return_token_type_ids=False,
45 | )
46 | return item
47 |
48 | def __len__(self):
49 | return self.total_len
50 |
51 | def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
52 | group = self.train_data[item]
53 | epoch = int(self.trainer.state.epoch)
54 |
55 | _hashed_seed = hash(item + self.trainer.args.seed)
56 |
57 | qry = group['query']
58 | encoded_query = self.create_one_example(qry, is_query=True)
59 |
60 | encoded_passages = []
61 | group_positives = group['positives']
62 | group_negatives = group['negatives']
63 |
64 | if self.data_args.positive_passage_no_shuffle:
65 | pos_psg = group_positives[0]
66 | else:
67 | pos_psg = group_positives[(_hashed_seed + epoch) % len(group_positives)]
68 | encoded_passages.append(self.create_one_example(pos_psg))
69 |
70 | negative_size = self.data_args.train_n_passages - 1
71 | if len(group_negatives) < negative_size:
72 | negs = random.choices(group_negatives, k=negative_size)
73 | elif self.data_args.train_n_passages == 1:
74 | negs = []
75 | elif self.data_args.negative_passage_no_shuffle:
76 | negs = group_negatives[:negative_size]
77 | else:
78 | _offset = epoch * negative_size % len(group_negatives)
79 | negs = [x for x in group_negatives]
80 | random.Random(_hashed_seed).shuffle(negs)
81 | negs = negs * 2
82 | negs = negs[_offset: _offset + negative_size]
83 |
84 | for neg_psg in negs:
85 | encoded_passages.append(self.create_one_example(neg_psg))
86 |
87 | return encoded_query, encoded_passages
88 |
89 | class TrainTASBDataset(Dataset):
90 |     # Currently msmarco-passage only: passage ids start from 0 and index the corpus directly; revise before using other datasets.
91 | def __init__(
92 | self,
93 | data_args: DataArguments,
94 | kd,
95 | dataset: datasets.Dataset,
96 | corpus: datasets.Dataset,
97 | tokenizer: PreTrainedTokenizer,
98 | trainer: DenseTrainer = None,
99 | ):
100 | self.train_data, self.qidx_cluster = dataset
101 | self.corpus = corpus
102 | self.tok = tokenizer
103 | self.trainer = trainer
104 | self.data_args = data_args
105 | self.tasb_sampling = data_args.tasb_sampling
106 | self.kd = kd
107 |
108 | if self.data_args.corpus_dir is None:
109 | raise ValueError('You should input --corpus_dir with files split*.json')
110 |
111 | # if (self.data_args.train_n_passages!=2) and (self.tasb_sampling):
112 | # raise ValueError('--train_n_passages should be 2 if you use tasb sampling')
113 |
114 | if (self.qidx_cluster is None) and (self.tasb_sampling):
115 | raise ValueError('You should input --query_cluster_dir for tasb sampling')
116 |
117 | self.data_args = data_args
118 | self.total_len = len(self.train_data)
119 | if self.qidx_cluster:
120 | self.cluster_num = len(self.qidx_cluster)
121 |
122 |
123 | def create_one_example(self, text_encoding: List[int], is_query=False):
124 | item = self.tok.encode_plus(
125 | text_encoding,
126 | truncation='only_first',
127 | max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len,
128 | padding=False,
129 | return_attention_mask=False,
130 | return_token_type_ids=False,
131 | )
132 | return item
133 |
134 | def output_qp(self, group, _hashed_seed):
135 | epoch = int(self.trainer.state.epoch)
136 | qry = group['query']
137 | encoded_query = self.create_one_example(qry, is_query=True)
138 |
139 | encoded_passages = []
140 | group_positives = group['positive_pids']
141 | group_negatives = group['negative_pids']
142 |
143 | if self.data_args.positive_passage_no_shuffle:
144 | pos_psg_id = group_positives[0]
145 | else:
146 | pos_psg_id = group_positives[(_hashed_seed + epoch) % len(group_positives)]
147 | pos_psg = self.corpus[int(pos_psg_id)]['text']
148 | encoded_passages.append(self.create_one_example(pos_psg))
149 |
150 | negative_size = self.data_args.train_n_passages - 1
151 | if len(group_negatives) < negative_size:
152 | negs = random.choices(group_negatives, k=negative_size)
153 | elif self.data_args.train_n_passages == 1:
154 | negs = []
155 | elif self.data_args.negative_passage_no_shuffle:
156 | negs = group_negatives[:negative_size]
157 | else:
158 | _offset = epoch * negative_size % len(group_negatives)
159 | negs = [x for x in group_negatives]
160 | random.Random(_hashed_seed).shuffle(negs)
161 | negs = negs * 2
162 | negs = negs[_offset: _offset + negative_size]
163 |
164 | for neg_psg_pid in negs:
165 | neg_psg = self.corpus[int(neg_psg_pid)]['text']
166 | encoded_passages.append(self.create_one_example(neg_psg))
167 |
168 | return encoded_query, encoded_passages, None
169 |
170 | def output_qp_with_score(self, group, _hashed_seed):
171 | qry = group['query']
172 | encoded_query = self.create_one_example(qry, is_query=True)
173 |
174 | encoded_passages = []
175 | scores = []
176 | qids_bin_pairs = group['bin_pairs']
177 | bins_pairs = random.choices(qids_bin_pairs, k=1)[0]
178 |
179 | pairs = []
180 | negative_size = self.data_args.train_n_passages - 1
181 |
182 | for i in range(negative_size):
183 | bin_pairs = random.choices(bins_pairs, k=1)[0]
184 | pairs.append(random.choices(bin_pairs, k=1)[0])
185 |
186 | pos_psg_idx = int(pairs[0][0])
187 | pos_psg_id = group['positive_pids'][pos_psg_idx]
188 | pos_psg = self.corpus[int(pos_psg_id)]['text']
189 | encoded_passages.append(self.create_one_example(pos_psg))
190 |
191 | for pair in pairs:
192 | neg_psg_idx = int(pair[1])
193 | neg_psg_id = group['negative_pids'][neg_psg_idx]
194 | neg_psg = self.corpus[int(neg_psg_id)]['text']
195 | encoded_passages.append(self.create_one_example(neg_psg))
196 | scores.append(-pair[2])
197 |
198 | return encoded_query, encoded_passages, scores
199 |
200 | def __len__(self):
201 | return self.total_len
202 |
203 | def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
204 | _hashed_seed = hash(item + self.trainer.args.seed)
205 | if self.tasb_sampling:
206 |             # seed with the global step so all examples in a batch draw from the same clusters
207 |             random.seed(self.trainer.state.global_step)
208 |             cluster_list = random.choices(self.qidx_cluster, k=24)
209 | 
210 |             # then sample a different query per example within the batch
211 |             random.seed(_hashed_seed)
212 |             cluster = random.choices(cluster_list, k=1)[0]
213 |             item = random.choices(cluster['qidx'])[0]
214 |
215 | group = self.train_data[item]
216 | else:
217 | group = self.train_data[item]
218 |
219 | if self.kd:
220 | return self.output_qp_with_score(group, _hashed_seed)
221 | else:
222 | return self.output_qp(group, _hashed_seed)
223 |
224 |
225 |
226 |
227 | class EncodeDataset(Dataset):
228 | input_keys = ['text_id', 'text']
229 |
230 | def __init__(self, dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer, max_len=128):
231 | self.encode_data = dataset
232 | self.tok = tokenizer
233 | self.max_len = max_len
234 |
235 | def __len__(self):
236 | return len(self.encode_data)
237 |
238 | def __getitem__(self, item) -> Tuple[str, BatchEncoding]:
239 | text_id, text = (self.encode_data[item][f] for f in self.input_keys)
240 | if len(text)==0:
241 | text = [0]
242 | encoded_text = self.tok.encode_plus(
243 | text,
244 | max_length=self.max_len,
245 | truncation='only_first',
246 | padding=False,
247 | return_token_type_ids=False,
248 | )
249 | return text_id, encoded_text
250 |
251 | class EvalDataset(Dataset):
252 | input_keys = ['qry_text_id', 'qry_text', 'psg_text_id', 'psg_text', 'rel']
253 |
254 | def __init__(self,
255 | data_args: DataArguments,
256 | dataset: datasets.Dataset,
257 | tokenizer: PreTrainedTokenizer):
258 | self.encode_data = dataset
259 | self.tok = tokenizer
260 | self.data_args = data_args
261 |
262 | def __len__(self):
263 | return len(self.encode_data)
264 |
265 | def __getitem__(self, item) -> Tuple[str, BatchEncoding]:
266 | qry_text_id, qry_text, psg_text_id, psg_text, rel = (self.encode_data[item][f] for f in self.input_keys)
267 | encoded_qry_text = self.tok.encode_plus(
268 | qry_text,
269 | max_length=self.data_args.q_max_len,
270 | truncation='only_first',
271 | padding=False,
272 | return_token_type_ids=False,
273 | )
274 | if len(psg_text)==0:
275 | psg_text = [0]
276 | encoded_psg_text = self.tok.encode_plus(
277 | psg_text,
278 | max_length=self.data_args.p_max_len,
279 | truncation='only_first',
280 | padding=False,
281 | return_token_type_ids=False,
282 | )
283 | return qry_text_id, encoded_qry_text, psg_text_id, encoded_psg_text, rel
284 |
285 |
286 | @dataclass
287 | class QPCollator(DataCollatorWithPadding):
288 | """
289 | Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg]
290 | and pass batch separately to the actual collator.
291 | Abstract out data detail for the model.
292 | """
293 | max_q_len: int = 32
294 | max_p_len: int = 128
295 |
296 | def __call__(self, features):
297 | qq = [f[0] for f in features]
298 | dd = [f[1] for f in features]
299 |
300 | if isinstance(qq[0], list):
301 | qq = sum(qq, [])
302 | if isinstance(dd[0], list):
303 | dd = sum(dd, [])
304 |
305 | q_collated = self.tokenizer.pad(
306 | qq,
307 | padding='max_length',
308 | max_length=self.max_q_len,
309 | return_tensors="pt",
310 | )
311 | d_collated = self.tokenizer.pad(
312 | dd,
313 | padding='max_length',
314 | max_length=self.max_p_len,
315 | return_tensors="pt",
316 | )
317 |
318 | if features[0][2] is not None:
319 | scores = [[0]+f[2] for f in features]
320 | scores_collated = torch.tensor(scores)
321 | else:
322 | scores_collated = None
323 |
324 | return q_collated, d_collated, scores_collated
325 |
326 |
327 | @dataclass
328 | class EncodeCollator(DataCollatorWithPadding):
329 | def __call__(self, features):
330 | text_ids = [x[0] for x in features]
331 | text_features = [x[1] for x in features]
332 | collated_features = super().__call__(text_features)
333 | return text_ids, collated_features
334 |
335 | @dataclass
336 | class EvalCollator(DataCollatorWithPadding):
337 | max_q_len: int = 32
338 | max_p_len: int = 128
339 | def __call__(self, features):
340 | qry_text_ids = [x[0] for x in features]
341 | qry_text_features = [x[1] for x in features]
342 | psg_text_ids = [x[2] for x in features]
343 | psg_text_features = [x[3] for x in features]
344 | rels = [x[4] for x in features]
345 | if isinstance(qry_text_features[0], list):
346 | qry_text_features = sum(qry_text_features, [])
347 | if isinstance(psg_text_features[0], list):
348 | psg_text_features = sum(psg_text_features, [])
349 |
350 | qry_collated_features = self.tokenizer.pad(
351 | qry_text_features,
352 | padding='max_length',
353 | max_length=self.max_q_len,
354 | return_tensors="pt",
355 | )
356 | psg_collated_features = self.tokenizer.pad(
357 | psg_text_features,
358 | padding='max_length',
359 | max_length=self.max_p_len,
360 | return_tensors="pt",
361 | )
362 | return qry_text_ids, qry_collated_features, psg_text_ids, psg_collated_features, rels
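363 | 
364 | # Note on negative sampling in TrainDataset.__getitem__: with train_n_passages=8,
365 | # each item yields 1 positive and 7 negatives. The negatives are shuffled once
366 | # with a per-item seed, duplicated, and a window of 7 starting at
367 | # (epoch * 7) % len(negatives) is taken, so successive epochs cycle through
368 | # different negatives deterministically.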
--------------------------------------------------------------------------------
/tevatron/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import HFTrainDataset, HFQueryDataset, HFCorpusDataset, HFEvalDataset
2 | from .preprocessor import TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor
3 |
--------------------------------------------------------------------------------
/tevatron/datasets/beir/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/datasets/beir/__init__.py
--------------------------------------------------------------------------------
/tevatron/datasets/beir/encode_and_retrieval.py:
--------------------------------------------------------------------------------
1 | ################################################################################################################
2 | # The evaluation code is adapted from the SPLADE repo: https://github.com/naver/splade/blob/main/src/beir_eval.py
3 |
4 |
5 | import argparse
6 | from .sentence_bert import Retriever, SentenceTransformerModel
7 | from transformers import AutoModelForMaskedLM, AutoTokenizer
8 |
9 |
10 | from ...arguments import ModelArguments
11 |
12 |
13 | from beir.datasets.data_loader import GenericDataLoader
14 | from beir.retrieval.evaluation import EvaluateRetrieval
15 | from beir import util, LoggingHandler
16 |
17 | def main():
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--dataset", type=str, required=True)
20 | parser.add_argument("--model_name_or_path", type=str, required=True)
21 | parser.add_argument("--max_length", type=int, default=512)
22 | parser.add_argument("--model", type=str, default='dhr', help='dhr, agg, dense')
23 | parser.add_argument("--agg_dim", type=int, default=640, help='for agg model')
24 | parser.add_argument("--semi_aggregate", action='store_true', help='for agg model')
25 | parser.add_argument("--skip_mlm", action='store_true', help='for agg model')
26 | parser.add_argument("--pooling_method", type=str, default='cls', help='for dense model')
27 | args = parser.parse_args()
28 |
29 |
30 | model_type_or_dir = args.model_name_or_path
31 |     model_args = ModelArguments  # used as a plain attribute namespace; the fields are set below
32 | model_args.model = args.model.lower()
33 | # agg method
34 | model_args.agg_dim = args.agg_dim
35 | model_args.semi_aggregate = args.semi_aggregate
36 | model_args.skip_mlm = args.skip_mlm
37 | model_args.pooling_method = args.pooling_method
38 | # loading model and tokenizer
39 | model = Retriever(model_type_or_dir, model_args)
40 |
41 | model.eval()
42 | tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir, use_fast=False)
43 | sentence_transformer = SentenceTransformerModel(model, tokenizer, args.max_length)
44 |
45 |
46 | dataset = args.dataset
47 |
48 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
49 | out_dir = "dataset/{}".format(dataset)
50 | data_path = util.download_and_unzip(url, out_dir)
51 |
52 | #### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
53 | # data folder would contain these files:
54 | # (1) nfcorpus/corpus.jsonl (format: jsonlines)
55 | # (2) nfcorpus/queries.jsonl (format: jsonlines)
56 | # (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
57 |
58 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
59 |
60 |     from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
61 |     # EvaluateRetrieval is already imported at module level
62 |
63 | dres = DRES(sentence_transformer)
64 | retriever = EvaluateRetrieval(dres, score_function="dot")
65 | results = retriever.retrieve(corpus, queries)
66 | ndcg, map_, recall, p = EvaluateRetrieval.evaluate(qrels, results, [1, 10, 100, 1000])
67 | results2 = EvaluateRetrieval.evaluate_custom(qrels, results, [1, 10, 100, 1000], metric="r_cap")
68 | res = {"NDCG@10": ndcg["NDCG@10"],
69 | "Recall@100": recall["Recall@100"],
70 | "R_cap@100": results2["R_cap@100"]}
71 | print("res for {}:".format(dataset), res, flush=True)
72 |
73 |
74 | if __name__ == "__main__":
75 | main()
--------------------------------------------------------------------------------
/tevatron/datasets/beir/preprocess.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)
3 | import argparse
4 | import pathlib, os
5 | from beir import util, LoggingHandler
6 | from beir.datasets.data_loader import GenericDataLoader
7 | logger = logging.getLogger(__name__)
8 | from ...utils.data_reader import create_dir
9 |
10 |
11 | def main():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--output_dir", required=False, default='./dataset', type=str)
14 | parser.add_argument("--dataset", required=True, type=str, help="beir dataset name")
15 |     parser.add_argument("--split", default='test', type=str, help="beir dataset split")
16 | args = parser.parse_args()
17 |
18 |     #### Download the dataset zip archive and unzip it
19 | create_dir(os.path.join('./download'))
20 | create_dir(os.path.join(args.output_dir))
21 | dataset = args.dataset
22 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
23 | data_path = util.download_and_unzip(url, './download')
24 |
25 |     #### Provide the data_path where the dataset has been downloaded and unzipped
26 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=args.split)
27 |
28 | create_dir(os.path.join(args.output_dir, args.dataset, 'corpus'))
29 | os.rename(os.path.join('./download', args.dataset, 'corpus.jsonl'), os.path.join(args.output_dir, args.dataset, 'corpus', 'collection.json'))
30 |
31 | create_dir(os.path.join(args.output_dir, args.dataset,'qrels'))
32 | qrel_fout = open(os.path.join(args.output_dir, args.dataset,'qrels', 'qrels.' + args.split + '.tsv'), 'w')
33 |
34 | create_dir(os.path.join(args.output_dir, args.dataset,'queries'))
35 | query_fout = open(os.path.join(args.output_dir, args.dataset, 'queries', 'queries.' + args.split + '.tsv'), 'w')
36 |
37 | for qid, answer in qrels.items():
38 | for docid, rel in answer.items():
39 | qrel_fout.write('{}\tQ0\t{}\t{}\n'.format(qid, docid, rel))
40 | query_fout.write('{}\t{}\n'.format(qid, queries[qid]))
41 |
42 | qrel_fout.close()
43 | query_fout.close()
44 |
45 | if __name__ == "__main__":
46 | main()
--------------------------------------------------------------------------------
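For reference, the two files written by `preprocess.py` contain one tab-separated record per line. A small illustration of the formats produced by the write calls above (the ids and query text are made up):

```python
qid, docid, rel, query = "q1", "doc42", 1, "what causes seasons"
qrel_line = '{}\tQ0\t{}\t{}\n'.format(qid, docid, rel)  # "q1\tQ0\tdoc42\t1\n"
query_line = '{}\t{}\n'.format(qid, query)              # "q1\twhat causes seasons\n"
```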
/tevatron/datasets/beir/sentence_bert.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Dict, Union
3 |
4 | import numpy as np
5 | import torch
6 | from numpy import ndarray
7 | from torch import Tensor
8 | from tqdm.autonotebook import trange
9 | from transformers import AutoModelForMaskedLM
10 |
11 |
12 | try:
13 | import sentence_transformers
14 | from sentence_transformers.util import batch_to_device
15 | except ImportError:
16 | print("Import Error: could not load sentence_transformers... proceeding")
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class SentenceTransformerModel:
21 | def __init__(self, model, tokenizer, max_length=512):
22 | self.max_length = max_length
23 | self.tokenizer = tokenizer
24 | self.model = model
25 | self.sep = ' '
26 |
27 |     # Query-encoding hook required by BEIR (returns query embeddings as a numpy array)
28 | def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray:
29 | X = self.model.encode_sentence_bert(self.tokenizer, queries, is_q=True, maxlen=self.max_length)
30 | return X
31 |
32 |     # Corpus-encoding hook required by BEIR (returns document embeddings as a numpy array)
33 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs) -> np.ndarray:
34 | sentences = [(doc["title"] + self.sep + doc["text"]).strip() for doc in corpus]
35 | return self.model.encode_sentence_bert(self.tokenizer, sentences, maxlen=self.max_length)
36 |
37 |
38 |
39 | class Retriever(torch.nn.Module):
40 |
41 | def __init__(self, model_type_or_dir, model_args):
42 | super().__init__()
43 | self.model_args = model_args
44 | if self.model_args.model.lower() == 'dhr':
45 | from ...DHR.modeling import DHRModelForInference
46 | from ...DHR.modeling import DHROutput as output
47 | self.transformer = DHRModelForInference.build(model_name_or_path=model_type_or_dir, model_args=model_args)
48 | elif self.model_args.model.lower() == 'agg':
49 | from ...Aggretriever.modeling import DenseModelForInference
50 | from ...Aggretriever.modeling import DenseOutput as output
51 | self.transformer = DenseModelForInference.build(model_name_or_path=model_type_or_dir, model_args=model_args)
52 | elif self.model_args.model.lower() == 'dense':
53 | from ...Dense.modeling import DenseModelForInference
54 |             from ...Dense.modeling import DenseOutput as output
55 | self.transformer = DenseModelForInference.build(model_name_or_path=model_type_or_dir, model_args=model_args)
56 | else:
57 |             raise ValueError("--model can only be 'dhr', 'dense' (CLS), or 'agg'.")
58 | def forward(self, features, is_q):
59 | if is_q:
60 |             if self.model_args.model == 'dhr':
61 | out = self.transformer(query=features)
62 | return [out.q_lexical_reps, out.q_semantic_reps]
63 | if self.model_args.model == 'agg':
64 | out = self.transformer(query=features)
65 | return out.q_reps
66 | elif self.model_args.model == 'dense':
67 | out = self.transformer(query=features)
68 | return out.q_reps
69 | else:
70 | if self.model_args.model == 'dhr':
71 | out = self.transformer(passage=features)
72 | return [out.p_lexical_reps, out.p_semantic_reps]
73 | if self.model_args.model == 'agg':
74 | out = self.transformer(passage=features)
75 | return out.p_reps
76 | elif self.model_args.model == 'dense':
77 | out = self.transformer(passage=features)
78 | return out.p_reps
79 |
80 | def _text_length(self, text: Union[List[int], List[List[int]]]):
81 | """helper function to get the length for the input text. Text can be either
82 | a list of ints (which means a single text as input), or a tuple of list of ints
83 | (representing several text inputs to the model).
84 | """
85 |
86 | if isinstance(text, dict): # {key: value} case
87 | return len(next(iter(text.values())))
88 | elif not hasattr(text, '__len__'): # Object has no len() method
89 | return 1
90 | elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
91 | return len(text)
92 | else:
93 | return sum([len(t) for t in text]) # Sum of length of individual strings
94 |
95 | def encode_sentence_bert(self, tokenizer, sentences: Union[str, List[str], List[int]],
96 | batch_size: int = 32,
97 | show_progress_bar: bool = None,
98 | output_value: str = 'dhr_embeddings',
99 | convert_to_numpy: bool = True,
100 | convert_to_tensor: bool = False,
101 | device: str = None,
102 | normalize_embeddings: bool = False,
103 | maxlen: int = 512,
104 | is_q: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
105 | """
106 | Computes sentence embeddings
107 | :param sentences: the sentences to embed
108 | :param batch_size: the batch size used for the computation
109 |         :param show_progress_bar: Output a progress bar when encoding sentences
110 |         :param output_value: Default 'sentence_embeddings', to get sentence embeddings. Can be set to 'token_embeddings' to get wordpiece token embeddings.
111 |         :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
112 |         :param convert_to_tensor: If true, a single stacked tensor is returned. Overrides any setting from convert_to_numpy.
113 |         :param device: Which torch.device to use for the computation
114 |         :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) can be used instead of cosine similarity.
115 | :return:
116 | By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
117 | """
118 | if self.model_args.model == 'dense':
119 | output_value = 'sentence_embeddings'
120 | elif self.model_args.model == 'agg':
121 | output_value = 'sentence_embeddings'
122 | else:
123 | output_value = 'dhr_embeddings'
124 |
125 |
126 |
127 | self.eval()
128 | if show_progress_bar is None:
129 | show_progress_bar = True
130 |
131 | if convert_to_tensor:
132 | convert_to_numpy = False
133 |
134 | if output_value == 'token_embeddings':
135 | convert_to_tensor = False
136 | convert_to_numpy = False
137 |
138 | input_was_string = False
139 | if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
140 | # Cast an individual sentence to a list with length 1
141 | sentences = [sentences]
142 | input_was_string = True
143 |
144 | if device is None:
145 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
146 |
147 | self.to(device)
148 |
149 | all_embeddings = []
150 | all_semantic_embeddings = []
151 | all_lexical_embeddings = []
152 | length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
153 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
154 |
155 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
156 | sentences_batch = sentences_sorted[start_index:start_index + batch_size]
157 |             # tokenize the batch: pad to the longest sequence in the batch
158 |             # and truncate to maxlen
159 | features = tokenizer(sentences_batch,
160 | add_special_tokens=True,
161 | padding="longest", # pad to max sequence length in batch
162 | truncation="only_first", # truncates to self.max_length
163 | max_length=maxlen,
164 | return_attention_mask=True,
165 | return_tensors="pt")
166 |             # move the tokenized batch to the target device
167 | features = batch_to_device(features, device)
168 |
169 | with torch.no_grad():
170 | out_features = self.forward(features, is_q)
171 | if output_value == 'dhr_embeddings':
172 | lexical_embeddings = out_features[0].detach()
173 | try:
174 | semantic_embeddings = out_features[1].detach()
175 | semantic_dim = semantic_embeddings.shape[1]
176 |                     except Exception:  # no semantic reps returned (lexical-only model)
177 | semantic_dim = 0
178 | if convert_to_numpy:
179 | lexical_embeddings = lexical_embeddings.cpu()
180 | try:
181 | semantic_embeddings = semantic_embeddings.cpu()
182 |                         except Exception:  # semantic reps absent
183 | semantic_dim = 0
184 |
185 | embeddings = torch.zeros((lexical_embeddings.shape[0], lexical_embeddings.shape[1] + semantic_dim))
186 | embeddings[:,:lexical_embeddings.shape[1]] = lexical_embeddings
187 | if semantic_dim != 0:
188 | embeddings[:,lexical_embeddings.shape[1]:] = semantic_embeddings
189 |
190 | else:
191 | if output_value == 'token_embeddings':
192 | embeddings = []
193 | for token_emb, attention in zip(out_features[output_value], out_features['attention_mask']):
194 | last_mask_id = len(attention) - 1
195 | while last_mask_id > 0 and attention[last_mask_id].item() == 0:
196 | last_mask_id -= 1
197 | embeddings.append(token_emb[0:last_mask_id + 1])
198 | elif output_value == 'sentence_embeddings':
199 |                         # forward() already returns the pooled sentence embeddings
200 |                         embeddings = out_features
201 | embeddings = embeddings.detach()
202 | if normalize_embeddings:
203 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
204 | # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
205 | if convert_to_numpy:
206 | embeddings = embeddings.cpu()
207 |
208 | all_embeddings.extend(embeddings)
209 |
210 |
211 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
212 | if convert_to_tensor:
213 | all_embeddings = torch.stack(all_embeddings)
214 | elif convert_to_numpy:
215 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
216 | if input_was_string:
217 | all_embeddings = all_embeddings[0]
218 | return all_embeddings
219 |
220 |
--------------------------------------------------------------------------------
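The `'dhr_embeddings'` branch of `encode_sentence_bert` packs each vector as the densified lexical block followed by the semantic ([CLS]) block. A standalone sketch of that layout (the dimensions here are illustrative, not the model's actual ones):

```python
import torch

lexical = torch.randn(2, 768)    # densified lexical representations (illustrative dim)
semantic = torch.randn(2, 128)   # semantic [CLS] representations (illustrative dim)

emb = torch.zeros(2, lexical.shape[1] + semantic.shape[1])
emb[:, :lexical.shape[1]] = lexical    # lexical block first
emb[:, lexical.shape[1]:] = semantic   # semantic block appended
assert emb.shape == (2, 896)           # retrieval then uses one dot product over both blocks
```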
/tevatron/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from transformers import PreTrainedTokenizer
3 | from .preprocessor import TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor, EvalPreProcessor
4 | from ..arguments import DataArguments
5 |
6 | DEFAULT_PROCESSORS = [TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor, EvalPreProcessor]
7 | PROCESSOR_INFO = {
8 | 'Tevatron/wikipedia-nq': DEFAULT_PROCESSORS,
9 | 'Tevatron/wikipedia-trivia': DEFAULT_PROCESSORS,
10 | 'Tevatron/wikipedia-curated': DEFAULT_PROCESSORS,
11 | 'Tevatron/wikipedia-wq': DEFAULT_PROCESSORS,
12 | 'Tevatron/wikipedia-squad': DEFAULT_PROCESSORS,
13 | 'Tevatron/scifact': DEFAULT_PROCESSORS,
14 | 'Tevatron/msmarco-passage': DEFAULT_PROCESSORS,
15 | 'json': [None, None, None, None]
16 | }
17 |
18 |
19 | class HFTrainDataset:
20 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str):
21 | data_files = data_args.train_path
22 | if data_files:
23 | data_files = {data_args.dataset_split: data_files}
24 |
25 | self.dataset = load_dataset(data_args.dataset_name,
26 | data_args.dataset_language,
27 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split]
28 |
29 | if data_args.query_cluster_path is not None:
30 | data_files = {data_args.dataset_split: data_args.query_cluster_path}
31 | self.qidx_cluster = load_dataset(data_args.dataset_name,
32 | data_args.dataset_language,
33 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split]
34 | else:
35 | self.qidx_cluster = None
36 |
37 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][0] if data_args.dataset_name in PROCESSOR_INFO\
38 | else DEFAULT_PROCESSORS[0]
39 | self.tokenizer = tokenizer
40 | self.q_max_len = data_args.q_max_len
41 | self.p_max_len = data_args.p_max_len
42 | self.proc_num = data_args.dataset_proc_num
43 | self.neg_num = data_args.train_n_passages - 1
44 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator)
45 |
46 | def process(self, shard_num=1, shard_idx=0):
47 | self.dataset = self.dataset.shard(shard_num, shard_idx)
48 | if self.preprocessor is not None:
49 | self.dataset = self.dataset.map(
50 | self.preprocessor(self.tokenizer, self.q_max_len, self.p_max_len, self.separator),
51 | batched=False,
52 | num_proc=self.proc_num,
53 | remove_columns=self.dataset.column_names,
54 | desc="Running tokenizer on train dataset",
55 | )
56 | return self.dataset, self.qidx_cluster
57 |
58 |
59 | class HFQueryDataset:
60 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str):
61 | data_files = data_args.encode_in_path
62 | if data_files:
63 | data_files = {data_args.dataset_split: data_files}
64 | self.dataset = load_dataset(data_args.dataset_name,
65 | data_args.dataset_language,
66 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split]
67 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][1] if data_args.dataset_name in PROCESSOR_INFO \
68 | else DEFAULT_PROCESSORS[1]
69 | self.tokenizer = tokenizer
70 | self.q_max_len = data_args.q_max_len
71 | self.proc_num = data_args.dataset_proc_num
72 |
73 | def process(self, shard_num=1, shard_idx=0):
74 | self.dataset = self.dataset.shard(shard_num, shard_idx)
75 | if self.preprocessor is not None:
76 | self.dataset = self.dataset.map(
77 | self.preprocessor(self.tokenizer, self.q_max_len),
78 | batched=False,
79 | num_proc=self.proc_num,
80 | remove_columns=self.dataset.column_names,
81 | desc="Running tokenization",
82 | )
83 | return self.dataset
84 |
85 |
86 | class HFCorpusDataset:
87 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str):
88 | if data_args.encode_in_path is not None:
89 | data_files = data_args.encode_in_path
90 | if data_args.corpus_path is not None:
91 | data_files = data_args.corpus_path
92 | if data_files:
93 | data_files = {data_args.dataset_split: data_files}
94 | self.dataset = load_dataset(data_args.dataset_name,
95 | data_args.dataset_language,
96 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split]
97 | script_prefix = data_args.dataset_name
98 | if script_prefix.endswith('-corpus'):
99 | script_prefix = script_prefix[:-7]
100 | self.preprocessor = PROCESSOR_INFO[script_prefix][2] \
101 | if script_prefix in PROCESSOR_INFO else DEFAULT_PROCESSORS[2]
102 | self.tokenizer = tokenizer
103 | self.p_max_len = data_args.p_max_len
104 | self.proc_num = data_args.dataset_proc_num
105 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator)
106 |
107 | def process(self, shard_num=1, shard_idx=0):
108 | self.dataset = self.dataset.shard(shard_num, shard_idx)
109 | if self.preprocessor is not None:
110 | self.dataset = self.dataset.map(
111 | self.preprocessor(self.tokenizer, self.p_max_len, self.separator),
112 | batched=False,
113 | num_proc=self.proc_num,
114 | remove_columns=self.dataset.column_names,
115 | desc="Running tokenization",
116 | )
117 | return self.dataset
118 |
119 | class HFEvalDataset:
120 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str):
121 | data_files = data_args.encode_in_path
122 | if data_files:
123 | data_files = {data_args.dataset_split: data_files}
124 | self.dataset = load_dataset(data_args.dataset_name,
125 | data_args.dataset_language,
126 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split]
127 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][3] if data_args.dataset_name in PROCESSOR_INFO \
128 | else DEFAULT_PROCESSORS[3]
129 | self.tokenizer = tokenizer
130 | self.q_max_len = data_args.q_max_len
131 | self.p_max_len = data_args.p_max_len
132 | self.proc_num = data_args.dataset_proc_num
133 |
134 | def process(self, shard_num=1, shard_idx=0):
135 | self.dataset = self.dataset.shard(shard_num, shard_idx)
136 | if self.preprocessor is not None:
137 | self.dataset = self.dataset.map(
138 | self.preprocessor(self.tokenizer, self.q_max_len, self.p_max_len),
139 | batched=False,
140 | num_proc=self.proc_num,
141 | remove_columns=self.dataset.column_names,
142 | desc="Running tokenization",
143 | )
144 | return self.dataset
--------------------------------------------------------------------------------
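A hedged sketch of wiring `HFTrainDataset` to a local JSONL file. The field names follow the attribute accesses above (`train_path`, `dataset_name`, `dataset_split`, ...); any fields not set here are assumed to take their defaults from `tevatron/arguments.py`, and `train.jsonl` is a placeholder path.

```python
from transformers import AutoTokenizer
from tevatron.arguments import DataArguments
from tevatron.datasets import HFTrainDataset

tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
data_args = DataArguments(dataset_name="json", train_path="train.jsonl")  # placeholder path

train_set = HFTrainDataset(tokenizer=tok, data_args=data_args, cache_dir="./cache")
# For 'json', PROCESSOR_INFO maps to None, so process() only shards and skips re-tokenization.
dataset, qidx_cluster = train_set.process(shard_num=1, shard_idx=0)
```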
/tevatron/datasets/preprocessor.py:
--------------------------------------------------------------------------------
1 | class TrainPreProcessor:
2 | def __init__(self, tokenizer, query_max_length=32, text_max_length=256, separator=' '):
3 | self.tokenizer = tokenizer
4 | self.query_max_length = query_max_length
5 | self.text_max_length = text_max_length
6 | self.separator = separator
7 |
8 | def __call__(self, example):
9 | query = self.tokenizer.encode(example['query'],
10 | add_special_tokens=False,
11 | max_length=self.query_max_length,
12 | truncation=True)
13 | positives = []
14 | for pos in example['positive_passages']:
15 | text = pos['title'] + self.separator + pos['text'] if 'title' in pos else pos['text']
16 | positives.append(self.tokenizer.encode(text,
17 | add_special_tokens=False,
18 | max_length=self.text_max_length,
19 | truncation=True))
20 | negatives = []
21 | for neg in example['negative_passages']:
22 | text = neg['title'] + self.separator + neg['text'] if 'title' in neg else neg['text']
23 | negatives.append(self.tokenizer.encode(text,
24 | add_special_tokens=False,
25 | max_length=self.text_max_length,
26 | truncation=True))
27 | return {'query': query, 'positives': positives, 'negatives': negatives}
28 |
29 |
30 | class QueryPreProcessor:
31 | def __init__(self, tokenizer, query_max_length=32):
32 | self.tokenizer = tokenizer
33 | self.query_max_length = query_max_length
34 |
35 | def __call__(self, example):
36 | query_id = example['query_id']
37 | query = self.tokenizer.encode(example['query'],
38 | add_special_tokens=False,
39 | max_length=self.query_max_length,
40 | truncation=True)
41 | return {'text_id': query_id, 'text': query}
42 |
43 |
44 | class CorpusPreProcessor:
45 | def __init__(self, tokenizer, text_max_length=256, separator=' '):
46 | self.tokenizer = tokenizer
47 | self.text_max_length = text_max_length
48 | self.separator = separator
49 |
50 | def __call__(self, example):
51 | docid = example['docid']
52 | text = example['title'] + self.separator + example['text'] if 'title' in example else example['text']
53 | text = self.tokenizer.encode(text,
54 | add_special_tokens=False,
55 | max_length=self.text_max_length,
56 | truncation=True)
57 | return {'text_id': docid, 'text': text}
58 |
59 | class EvalPreProcessor:
60 | def __init__(self, tokenizer, qry_max_length=32, psg_max_length=256, separator=' '):
61 | self.tokenizer = tokenizer
62 | self.qry_max_length = qry_max_length
63 | self.psg_max_length = psg_max_length
64 | self.separator = separator
65 |
66 | def __call__(self, example):
67 | docid = example['docid']
68 | qry_text = example['qry_text']
69 | qry_text = self.tokenizer.encode(qry_text,
70 | add_special_tokens=False,
71 | max_length=self.qry_max_length,
72 | truncation=True)
73 | psg_text = example['title'] + self.separator + example['psg_text'] if 'title' in example else example['psg_text']
74 | psg_text = self.tokenizer.encode(psg_text,
75 | add_special_tokens=False,
76 | max_length=self.psg_max_length,
77 | truncation=True)
78 | return {'text_id': docid, 'qry_text': qry_text, 'psg_text': psg_text}
79 |
--------------------------------------------------------------------------------
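A toy run of `TrainPreProcessor` (a sketch; the exact token ids depend on the tokenizer):

```python
from transformers import AutoTokenizer
from tevatron.datasets import TrainPreProcessor

tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
proc = TrainPreProcessor(tok, query_max_length=32, text_max_length=256)

example = {
    "query": "who wrote hamlet",
    "positive_passages": [{"title": "Hamlet", "text": "Hamlet is a tragedy by William Shakespeare."}],
    "negative_passages": [{"text": "The Globe Theatre opened in 1599."}],
}
out = proc(example)
# {'query': [...], 'positives': [[...]], 'negatives': [[...]]}
# token-id lists without special tokens; titles are joined to texts by the separator
```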
/tevatron/driver/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/driver/__init__.py
--------------------------------------------------------------------------------
/tevatron/driver/encode.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import sys
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | import torch
11 |
12 | from torch.utils.data import DataLoader
13 | from transformers import AutoConfig, AutoTokenizer
14 | from transformers import (
15 | HfArgumentParser,
16 | )
17 |
18 | from tevatron.arguments import ModelArguments, DataArguments, \
19 | DenseTrainingArguments as TrainingArguments
20 | from tevatron.data import EncodeDataset, EncodeCollator
21 | from tevatron.datasets import HFQueryDataset, HFCorpusDataset
22 | from tevatron.DHR.utils import densify
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def main():
28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
31 | else:
32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
33 | model_args: ModelArguments
34 | data_args: DataArguments
35 | training_args: TrainingArguments
36 |
37 | if training_args.local_rank > 0 or training_args.n_gpu > 1:
38 | raise NotImplementedError('Multi-GPU encoding is not supported.')
39 |
40 | # Setup logging
41 | logging.basicConfig(
42 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
43 | datefmt="%m/%d/%Y %H:%M:%S",
44 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
45 | )
46 |
47 | num_labels = 1
48 | config = AutoConfig.from_pretrained(
49 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
50 | num_labels=num_labels,
51 | output_hidden_states=True,
52 | cache_dir=model_args.cache_dir,
53 | )
54 | tokenizer = AutoTokenizer.from_pretrained(
55 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
56 | cache_dir=model_args.cache_dir,
57 | use_fast=False,
58 | )
59 |
60 | if (model_args.model).lower() == 'dhr':
61 | from tevatron.DHR.modeling import DHRModelForInference
62 | from tevatron.DHR.modeling import DHROutput as Output
63 | logger.info("Encoding model DHR")
64 | model = DHRModelForInference.build(
65 | model_args=model_args,
66 | config=config,
67 | cache_dir=model_args.cache_dir,
68 | )
69 | elif (model_args.model).lower() == 'dlr':
70 | from tevatron.DHR.modeling import DHRModelForInference
71 | from tevatron.DHR.modeling import DHROutput as Output
72 | logger.info("Encoding model DLR")
73 | model_args.combine_cls = False
74 | model = DHRModelForInference.build(
75 | model_args=model_args,
76 | config=config,
77 | cache_dir=model_args.cache_dir,
78 | )
79 | elif (model_args.model).lower() == 'agg':
80 | from tevatron.Aggretriever.modeling import DenseModelForInference
81 | from tevatron.Aggretriever.modeling import DenseOutput as Output
82 | logger.info("Encoding model Dense (AGG)")
83 | model = DenseModelForInference.build(
84 | model_args=model_args,
85 | config=config,
86 | cache_dir=model_args.cache_dir,
87 | )
88 | elif (model_args.model).lower() == 'dense':
89 | from tevatron.Dense.modeling import DenseModelForInference
90 | from tevatron.Dense.modeling import DenseOutput as Output
91 |         logger.info("Encoding model Dense (CLS)")
92 | model = DenseModelForInference.build(
93 | model_args=model_args,
94 | config=config,
95 | cache_dir=model_args.cache_dir,
96 | )
97 | else:
98 | raise ValueError('input model is not supported')
99 |
100 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len
101 | if data_args.encode_is_qry:
102 | encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args,
103 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
104 | else:
105 | encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args,
106 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
107 | encode_dataset = EncodeDataset(encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index),
108 | tokenizer, max_len=text_max_length)
109 |
110 | encode_loader = DataLoader(
111 | encode_dataset,
112 | batch_size=training_args.per_device_eval_batch_size,
113 | collate_fn=EncodeCollator(
114 | tokenizer,
115 | max_length=text_max_length,
116 | padding='max_length'
117 | ),
118 | shuffle=False,
119 | drop_last=False,
120 | num_workers=training_args.dataloader_num_workers,
121 | )
122 |
123 |
124 |
125 | def initialize_reps(data_num, dim, dtype):
126 | return np.zeros((data_num, dim), dtype=dtype)
127 |
128 |
129 | offset = 0
130 | lookup_indices = []
131 | model = model.to(training_args.device)
132 | model.eval()
133 |
134 | data_num = len(encode_dataset)
135 | value_encoded, index_encoded = None, None
136 |
137 | for (batch_ids, batch) in tqdm(encode_loader):
138 | batch_size = len(batch_ids)
139 | lookup_indices.extend(batch_ids)
140 | with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext():
141 | with torch.no_grad():
142 | for k, v in batch.items():
143 | batch[k] = v.to(training_args.device)
144 |
145 | if data_args.encode_is_qry:
146 |
147 | model_output: Output = model(query=batch)
148 |
149 | if (model_args.model).lower() == 'dense' or (model_args.model).lower() == 'agg':
150 | reps = model_output.q_reps.cpu().detach().numpy()
151 | if value_encoded is None:
152 | value_encoded = initialize_reps(data_num, reps.shape[1], np.float16)
153 | value_encoded[offset: (offset + batch_size), :] = reps
154 | else:
155 | dlr_value_reps, dlr_index_reps = densify(model_output.q_lexical_reps, model_args.dlr_out_dim)
156 | dlr_value_reps = dlr_value_reps.cpu().detach().numpy()
157 | dlr_index_reps = dlr_index_reps.cpu().detach().numpy().astype(np.uint8)
158 | cls_reps = model_output.q_semantic_reps.cpu().detach().numpy()
159 |
160 | if value_encoded is None:
161 | if cls_reps is None:
162 | cls_dim = 0
163 | else:
164 | cls_dim = cls_reps.shape[1]
165 | value_encoded = initialize_reps(data_num, dlr_value_reps.shape[1] + cls_dim, np.float16)
166 | index_encoded = initialize_reps(data_num, dlr_index_reps.shape[1], np.uint8)
167 | value_encoded[offset: (offset + batch_size), :model_args.dlr_out_dim] = dlr_value_reps
168 | index_encoded[offset: (offset + batch_size), :model_args.dlr_out_dim] = dlr_index_reps
169 | if cls_reps is not None:
170 | value_encoded[offset: (offset + batch_size), model_args.dlr_out_dim:] = cls_reps
171 |
172 | else:
173 | model_output: Output = model(passage=batch)
174 | if (model_args.model).lower() == 'dense' or (model_args.model).lower() == 'agg':
175 | reps = model_output.p_reps.cpu().detach().numpy()
176 | if value_encoded is None:
177 | value_encoded = initialize_reps(data_num, reps.shape[1], np.float16)
178 | value_encoded[offset: (offset + batch_size), :] = reps
179 | else:
180 | dlr_value_reps, dlr_index_reps = densify(model_output.p_lexical_reps, model_args.dlr_out_dim)
181 | dlr_value_reps = dlr_value_reps.cpu().detach().numpy()
182 | dlr_index_reps = dlr_index_reps.cpu().detach().numpy().astype(np.uint8)
183 | cls_reps = model_output.p_semantic_reps.cpu().detach().numpy()
184 |
185 | if value_encoded is None:
186 | if cls_reps is None:
187 | cls_dim = 0
188 | else:
189 | cls_dim = cls_reps.shape[1]
190 | value_encoded = initialize_reps(data_num, dlr_value_reps.shape[1] + cls_dim, np.float16)
191 | index_encoded = initialize_reps(data_num, dlr_index_reps.shape[1], np.uint8)
192 | value_encoded[offset: (offset + batch_size), :model_args.dlr_out_dim] = dlr_value_reps
193 | index_encoded[offset: (offset + batch_size), :model_args.dlr_out_dim] = dlr_index_reps
194 | if cls_reps is not None:
195 | value_encoded[offset: (offset + batch_size), model_args.dlr_out_dim:] = cls_reps
196 |
197 | offset += batch_size
198 |
199 |     output_dir = os.path.dirname(data_args.encoded_save_path)
200 |     if not os.path.exists(output_dir):
201 |         logger.info(f'{output_dir} does not exist; creating it')
202 |         os.makedirs(output_dir, exist_ok=True)
203 | with open(data_args.encoded_save_path, 'wb') as f:
204 | pickle.dump([value_encoded, index_encoded, lookup_indices], f, protocol=4)
205 |
206 |
207 | if __name__ == "__main__":
208 | main()
209 |
--------------------------------------------------------------------------------
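Reading back what `encode.py` writes (a minimal sketch; the path is illustrative). The pickle holds `[value_encoded, index_encoded, lookup_indices]`, where `index_encoded` stays `None` for the `dense`/`agg` models and is a `uint8` matrix for `dhr`/`dlr`:

```python
import pickle

with open("corpus_emb.pkl", "rb") as f:  # illustrative path for --encoded_save_path
    value_encoded, index_encoded, lookup_indices = pickle.load(f)

print(value_encoded.shape, value_encoded.dtype)      # (N, dim), float16
if index_encoded is not None:
    print(index_encoded.shape, index_encoded.dtype)  # (N, dlr_out_dim), uint8
print(len(lookup_indices))                           # N text ids, aligned row by row
```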
/tevatron/driver/eval.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import sys
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | import torch
11 |
12 | from torch.utils.data import DataLoader
13 | from transformers import AutoConfig, AutoTokenizer
14 | from transformers import (
15 | HfArgumentParser,
16 | )
17 |
18 | from tevatron.arguments import ModelArguments, DataArguments, \
19 | DenseTrainingArguments as TrainingArguments
20 | from tevatron.data import EvalDataset, EvalCollator
21 | from tevatron.datasets import HFEvalDataset
22 | from tevatron.utils import metrics
23 | METRICS_MAP = ['MAP', 'RPrec', 'NDCG', 'MRR', 'MRR@10']
24 | # from tevatron.densification.utils import densify
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | def main():
30 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
31 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
32 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
33 | else:
34 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
35 | model_args: ModelArguments
36 | data_args: DataArguments
37 | training_args: TrainingArguments
38 |
39 | if training_args.local_rank > 0 or training_args.n_gpu > 1:
40 | raise NotImplementedError('Multi-GPU encoding is not supported.')
41 |
42 | # Setup logging
43 | logging.basicConfig(
44 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
45 | datefmt="%m/%d/%Y %H:%M:%S",
46 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
47 | )
48 |
49 | num_labels = 1
50 | config = AutoConfig.from_pretrained(
51 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
52 | num_labels=num_labels,
53 | output_hidden_states=True,
54 | cache_dir=model_args.cache_dir,
55 | )
56 | tokenizer = AutoTokenizer.from_pretrained(
57 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
58 | cache_dir=model_args.cache_dir,
59 | use_fast=False,
60 | )
61 |
62 | if (model_args.model).lower() == 'colbert':
63 | from tevatron.ColBERT.modeling import ColBERTForInference
64 | from tevatron.ColBERT.modeling import ColBERTOutput as Output
65 | logger.info("Evaluating model ColBERT")
66 | model = ColBERTForInference.build(
67 | model_args=model_args,
68 | config=config,
69 | cache_dir=model_args.cache_dir,
70 | )
71 | elif (model_args.model).lower() == 'dhr':
72 | from tevatron.DHR.modeling import DHRModelForInference
73 | from tevatron.DHR.modeling import DHROutput as Output
74 | logger.info("Evaluating model DHR")
75 | model = DHRModelForInference.build(
76 | model_args=model_args,
77 | config=config,
78 | cache_dir=model_args.cache_dir,
79 | )
80 | elif (model_args.model).lower() == 'dlr':
81 | from tevatron.DHR.modeling import DHRModelForInference
82 | from tevatron.DHR.modeling import DHROutput as Output
83 |         logger.info("Evaluating model DLR")
84 | model_args.combine_cls = False
85 | model = DHRModelForInference.build(
86 | model_args=model_args,
87 | config=config,
88 | cache_dir=model_args.cache_dir,
89 | )
90 | elif (model_args.model).lower() == 'agg':
91 | from tevatron.Aggretriever.modeling import DenseModelForInference
92 | from tevatron.Aggretriever.modeling import DenseOutput as Output
93 | logger.info("Evaluating model Dense (AGG)")
94 |         model = DenseModelForInference.build(
95 | model_args=model_args,
96 | config=config,
97 | cache_dir=model_args.cache_dir,
98 | )
99 | elif (model_args.model).lower() == 'dense':
100 | from tevatron.Dense.modeling import DenseModelForInference
101 | from tevatron.Dense.modeling import DenseOutput as Output
102 | logger.info("Evaluating model Dense (CLS)")
103 | model = DenseModelForInference.build(
104 | model_args=model_args,
105 | config=config,
106 | cache_dir=model_args.cache_dir,
107 | )
108 | else:
109 | raise ValueError('input model is not supported')
110 |
111 | eval_dataset = HFEvalDataset(tokenizer=tokenizer, data_args=data_args,
112 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
113 | eval_dataset = EvalDataset(data_args, eval_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index),
114 | tokenizer)
115 |
116 | eval_loader = DataLoader(
117 | eval_dataset,
118 | batch_size=training_args.per_device_eval_batch_size,
119 | collate_fn=EvalCollator(
120 | tokenizer,
121 | max_p_len=data_args.p_max_len,
122 | max_q_len=data_args.q_max_len,
123 | padding='max_length'
124 | ),
125 | shuffle=False,
126 | drop_last=False,
127 | num_workers=training_args.dataloader_num_workers,
128 | )
129 |
130 | model = model.to(training_args.device)
131 | model.eval()
132 |
133 | num_candidates_per_qry = 1000
134 |     if num_candidates_per_qry % training_args.per_device_eval_batch_size != 0:
135 | raise ValueError('Batch size should be a factor of {}'.format(num_candidates_per_qry))
136 | all_metrics = np.zeros(len(METRICS_MAP))
137 | num_examples = 0
138 | qids = []
139 |     candidate_psg_ids = []
140 | scores = []
141 | labels = []
142 |     for (batch_qry_ids, batch_qry_features, batch_psg_ids, batch_psg_features, rels) in tqdm(eval_loader):
143 |         if len(set(batch_qry_ids)) != 1:
144 |             raise ValueError('There is more than one query in the eval batch!')
145 |         with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext():
146 |             with torch.no_grad():
147 |                 for k, v in batch_qry_features.items():
148 |                     batch_qry_features[k] = v.to(training_args.device)
149 |                 for k, v in batch_psg_features.items():
150 |                     batch_psg_features[k] = v.to(training_args.device)
151 |                 model_output: Output = model(query=batch_qry_features, passage=batch_psg_features)
152 |
153 |         qids += batch_qry_ids
154 |         candidate_psg_ids += batch_psg_ids
155 |         scores += model_output.scores.cpu().numpy().tolist()
156 |         labels += rels
157 |         if len(candidate_psg_ids) == num_candidates_per_qry:
158 |             if len(set(qids)) != 1:
159 |                 raise ValueError('There is more than one query in the candidate set!')
160 |             gt = set(list(np.where(np.array(labels) > 0)[0]))
161 |
162 |             predict_doc = np.array(scores).argsort()[::-1]
163 |             all_metrics += metrics.metrics(gt=gt, pred=predict_doc, metrics_map=METRICS_MAP)
164 |             num_examples += 1
165 |             qids = []
166 |             candidate_psg_ids = []
167 |             scores = []
168 |             labels = []
169 |             if num_examples % 10 == 0:
170 |                 logging.warning("Read {} examples, metrics so far:".format(num_examples))
171 |                 logging.warning(" ".join(METRICS_MAP))
172 |                 logging.warning(all_metrics / num_examples)
173 |             if num_examples == 200:  # evaluate only the first 200 queries
174 |                 break
175 | # Write results
176 |
177 |
178 |     output_dir = os.path.dirname(data_args.encoded_save_path)  # computed but unused; result writing is not implemented here
179 |
180 |
181 |
182 | if __name__ == "__main__":
183 | main()
184 |
--------------------------------------------------------------------------------
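A standalone sketch of the per-query metric accumulation in `eval.py`, with a fake four-candidate ranking. As the loop above does, this assumes `tevatron.utils.metrics.metrics` returns one value per entry of `metrics_map`:

```python
import numpy as np
from tevatron.utils import metrics

METRICS_MAP = ['MAP', 'RPrec', 'NDCG', 'MRR', 'MRR@10']
labels = np.array([0, 0, 1, 0])           # relevance labels for 4 candidate passages
scores = np.array([0.1, 0.9, 0.8, 0.2])   # model scores for those passages

gt = set(np.where(labels > 0)[0])         # indices of relevant passages: {2}
pred = scores.argsort()[::-1]             # ranking by descending score: [1, 2, 3, 0]
per_query = metrics.metrics(gt=gt, pred=pred, metrics_map=METRICS_MAP)
# eval.py sums these per-query vectors and reports all_metrics / num_examples
```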
/tevatron/driver/jax_encode.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import sys
5 |
6 | import datasets
7 | import jax
8 | import numpy as np
9 | from flax.training.common_utils import shard
10 | from jax import pmap
11 | from tevatron.arguments import DataArguments
12 | from tevatron.arguments import DenseTrainingArguments as TrainingArguments
13 | from tevatron.arguments import ModelArguments
14 | from tevatron.data import EncodeCollator, EncodeDataset
15 | from tevatron.datasets import HFQueryDataset, HFCorpusDataset
16 | from torch.utils.data import DataLoader
17 | from tqdm import tqdm
18 | from flax.training.train_state import TrainState
19 | from flax import jax_utils
20 | import optax
21 | from transformers import (AutoConfig, AutoTokenizer, FlaxAutoModel,
22 | HfArgumentParser, TensorType)
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def main():
28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
31 | else:
32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
33 | model_args: ModelArguments
34 | data_args: DataArguments
35 | training_args: TrainingArguments
36 |
37 | # Setup logging
38 | logging.basicConfig(
39 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
40 | datefmt="%m/%d/%Y %H:%M:%S",
41 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
42 | )
43 |
44 | num_labels = 1
45 | config = AutoConfig.from_pretrained(
46 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
47 | num_labels=num_labels,
48 | cache_dir=model_args.cache_dir,
49 | )
50 | tokenizer = AutoTokenizer.from_pretrained(
51 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
52 | cache_dir=model_args.cache_dir,
53 | use_fast=False,
54 | )
55 |
56 | model = FlaxAutoModel.from_pretrained(model_args.model_name_or_path, config=config, from_pt=False)
57 |
58 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len
59 | if data_args.encode_is_qry:
60 | encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args,
61 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
62 | else:
63 | encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args,
64 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
65 | encode_dataset = EncodeDataset(encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index),
66 | tokenizer, max_len=text_max_length)
67 |
68 | # prepare padding batch (for last nonfull batch)
69 | dataset_size = len(encode_dataset)
70 | padding_prefix = "padding_"
71 | total_batch_size = len(jax.devices()) * training_args.per_device_eval_batch_size
72 |     features = list(encode_dataset.encode_data.features.keys())  # expected: ['text_id', 'text']
73 |     padding_batch = {"text_id": [], "text": []}
74 | for i in range(total_batch_size - (dataset_size % total_batch_size)):
75 | padding_batch["text_id"].append(f"{padding_prefix}{i}")
76 | padding_batch["text"].append([0])
77 | padding_batch = datasets.Dataset.from_dict(padding_batch)
78 | encode_dataset.encode_data = datasets.concatenate_datasets([encode_dataset.encode_data, padding_batch])
79 |
80 | encode_loader = DataLoader(
81 | encode_dataset,
82 | batch_size=training_args.per_device_eval_batch_size * len(jax.devices()),
83 | collate_fn=EncodeCollator(
84 | tokenizer,
85 | max_length=text_max_length,
86 | padding='max_length',
87 | pad_to_multiple_of=16,
88 | return_tensors=TensorType.NUMPY,
89 | ),
90 | shuffle=False,
91 | drop_last=False,
92 | num_workers=training_args.dataloader_num_workers,
93 | )
94 |
95 | # craft a fake state for now to replicate on devices
96 | adamw = optax.adamw(0.0001)
97 | state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
98 |
99 | def encode_step(batch, state):
100 | embedding = state.apply_fn(**batch, params=state.params, train=False)[0]
101 | return embedding[:, 0]
102 |
103 | p_encode_step = pmap(encode_step)
104 | state = jax_utils.replicate(state)
105 |
106 | encoded = []
107 | lookup_indices = []
108 |
109 | for (batch_ids, batch) in tqdm(encode_loader):
110 | lookup_indices.extend(batch_ids)
111 | batch_embeddings = p_encode_step(shard(batch.data), state)
112 | encoded.extend(np.concatenate(batch_embeddings, axis=0))
113 | with open(data_args.encoded_save_path, 'wb') as f:
114 | pickle.dump((encoded[:dataset_size], lookup_indices[:dataset_size]), f)
115 |
116 |
117 | if __name__ == "__main__":
118 | main()
119 |
--------------------------------------------------------------------------------
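The padding arithmetic used in `jax_encode.py`, as a standalone check: the corpus is padded so its length becomes a multiple of the `pmap` batch size, and the extra rows are sliced off again before pickling (the sizes below are illustrative):

```python
dataset_size = 8_841_823              # e.g. the MS MARCO passage corpus
total_batch_size = 8 * 64             # len(jax.devices()) * per_device_eval_batch_size
n_padding = total_batch_size - (dataset_size % total_batch_size)
assert (dataset_size + n_padding) % total_batch_size == 0
```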
/tevatron/driver/jax_train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from functools import partial
5 |
6 | import datasets
7 | import jax
8 | import jax.numpy as jnp
9 | import optax
10 | from flax import jax_utils, traverse_util
11 | from flax.jax_utils import prefetch_to_device
12 | from flax.training.common_utils import get_metrics, shard
13 | from torch.utils.data import DataLoader, IterableDataset
14 | from tqdm import tqdm
15 | from transformers import AutoConfig, AutoTokenizer, FlaxAutoModel
16 | from transformers import (
17 | HfArgumentParser,
18 | set_seed,
19 | )
20 |
21 | from tevatron.arguments import ModelArguments, DataArguments, DenseTrainingArguments
22 | from tevatron.tevax.training import TiedParams, RetrieverTrainState, retriever_train_step, grad_cache_train_step, \
23 | DualParams
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | def main():
29 | parser = HfArgumentParser((ModelArguments, DataArguments, DenseTrainingArguments))
30 |
31 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
32 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
33 | else:
34 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
35 | model_args: ModelArguments
36 | data_args: DataArguments
37 | training_args: DenseTrainingArguments
38 |
39 | if (
40 | os.path.exists(training_args.output_dir)
41 | and os.listdir(training_args.output_dir)
42 | and training_args.do_train
43 | and not training_args.overwrite_output_dir
44 | ):
45 | raise ValueError(
46 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
47 | )
48 |
49 | # Setup logging
50 | logging.basicConfig(
51 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
52 | datefmt="%m/%d/%Y %H:%M:%S",
53 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
54 | )
55 | logger.warning(
56 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
57 | training_args.local_rank,
58 | training_args.device,
59 | training_args.n_gpu,
60 | bool(training_args.local_rank != -1),
61 | training_args.fp16,
62 | )
63 | logger.info("Training/evaluation parameters %s", training_args)
64 | logger.info("MODEL parameters %s", model_args)
65 |
66 | set_seed(training_args.seed)
67 |
68 | config = AutoConfig.from_pretrained(
69 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
70 | cache_dir=model_args.cache_dir,
71 | )
72 | tokenizer = AutoTokenizer.from_pretrained(
73 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
74 | cache_dir=model_args.cache_dir,
75 | )
76 | try:
77 | model = FlaxAutoModel.from_pretrained(
78 | model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
79 | )
80 |     except Exception:  # no native Flax weights; fall back to converting from PyTorch
81 | model = FlaxAutoModel.from_pretrained(
82 | model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype),
83 | from_pt=True
84 | )
85 |
86 | if data_args.train_dir:
87 | data_files = {
88 | 'train': data_args.train_path
89 | }
90 | else:
91 | data_files = None
92 |
93 | train_dataset = \
94 | datasets.load_dataset(data_args.dataset_name, data_args.dataset_language, cache_dir=model_args.cache_dir,
95 | data_files=data_files)[data_args.dataset_split]
96 |
97 | def tokenize_train(example):
98 | tokenize = partial(tokenizer, return_attention_mask=False, return_token_type_ids=False, padding=False,
99 | truncation=True)
100 | query = example['query']
101 | pos_psgs = [p['title'] + " " + p['text'] for p in example['positive_passages']]
102 | neg_psgs = [p['title'] + " " + p['text'] for p in example['negative_passages']]
103 |
104 | example['query_input_ids'] = dict(tokenize(query, max_length=32))
105 | example['pos_psgs_input_ids'] = [dict(tokenize(x, max_length=data_args.p_max_len)) for x in pos_psgs]
106 | example['neg_psgs_input_ids'] = [dict(tokenize(x, max_length=data_args.p_max_len)) for x in neg_psgs]
107 |
108 | return example
109 |
110 | train_data = train_dataset.map(
111 | tokenize_train,
112 | batched=False,
113 | num_proc=data_args.dataset_proc_num,
114 | desc="Running tokenizer on train dataset",
115 | )
116 | train_data = train_data.filter(
117 | function=lambda data: len(data["pos_psgs_input_ids"]) >= 1 and \
118 | len(data["neg_psgs_input_ids"]) >= data_args.train_n_passages-1, num_proc=64
119 | )
120 |
121 | class TrainDataset:
122 | def __init__(self, train_data, group_size, tokenizer):
123 | self.group_size = group_size
124 | self.data = train_data
125 | self.tokenizer = tokenizer
126 |
127 | def __len__(self):
128 | return len(self.data)
129 |
130 | def get_example(self, i, epoch):
131 | example = self.data[i]
132 | q = example['query_input_ids']
133 |
134 | pp = example['pos_psgs_input_ids']
135 | p = pp[0]
136 |
137 | nn = example['neg_psgs_input_ids']
138 | off = epoch * (self.group_size - 1) % len(nn)
139 | nn = nn * 2
140 | nn = nn[off: off + self.group_size - 1]
141 |
142 | return q, [p] + nn
143 |
144 | def get_batch(self, indices, epoch):
145 | qq, dd = zip(*[self.get_example(i, epoch) for i in map(int, indices)])
146 | dd = sum(dd, [])
147 | return dict(tokenizer.pad(qq, max_length=32, padding='max_length', return_tensors='np')), dict(
148 | tokenizer.pad(dd, max_length=data_args.p_max_len, padding='max_length', return_tensors='np'))
149 |
150 | train_dataset = TrainDataset(train_data, data_args.train_n_passages, tokenizer)
151 |
152 | def create_learning_rate_fn(
153 | train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int,
154 | learning_rate: float
155 | ):
156 | """Returns a linear warmup, linear_decay learning rate function."""
157 | steps_per_epoch = train_ds_size // train_batch_size
158 | num_train_steps = steps_per_epoch * num_train_epochs
159 | warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
160 | decay_fn = optax.linear_schedule(
161 | init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
162 | )
163 | schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
164 | return schedule_fn
165 |
166 | def _decay_mask_fn(params):
167 | flat_params = traverse_util.flatten_dict(params)
168 | layer_norm_params = [
169 | (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
170 | ]
171 | flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
172 | return traverse_util.unflatten_dict(flat_mask)
173 |
174 | def decay_mask_fn(params):
175 | param_nodes, treedef = jax.tree_flatten(params, lambda v: isinstance(v, dict))
176 | masks = [_decay_mask_fn(param_node) for param_node in param_nodes]
177 | return jax.tree_unflatten(treedef, masks)
178 |
179 | num_epochs = int(training_args.num_train_epochs)
180 | train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
181 | steps_per_epoch = len(train_dataset) // train_batch_size
182 | total_train_steps = steps_per_epoch * num_epochs
183 |
184 | linear_decay_lr_schedule_fn = create_learning_rate_fn(
185 | len(train_dataset),
186 | train_batch_size,
187 | int(training_args.num_train_epochs),
188 | int(total_train_steps * 0.1),
189 | training_args.learning_rate,
190 | )
191 |
192 | adamw = optax.adamw(
193 | learning_rate=linear_decay_lr_schedule_fn,
194 | b1=training_args.adam_beta1,
195 | b2=training_args.adam_beta2,
196 | eps=training_args.adam_epsilon,
197 | weight_decay=training_args.weight_decay,
198 | mask=decay_mask_fn,
199 | )
200 |
201 | if model_args.untie_encoder:
202 | params = DualParams.create(model.params)
203 | else:
204 | params = TiedParams.create(model.params)
205 | state = RetrieverTrainState.create(apply_fn=model.__call__, params=params, tx=adamw)
206 |
207 | if training_args.grad_cache:
208 | q_n_subbatch = train_batch_size // training_args.gc_q_chunk_size
209 | p_n_subbatch = train_batch_size * data_args.train_n_passages // training_args.gc_p_chunk_size
210 | p_train_step = jax.pmap(
211 | partial(grad_cache_train_step, q_n_subbatch=q_n_subbatch, p_n_subbatch=p_n_subbatch),
212 | "device"
213 | )
214 | else:
215 | p_train_step = jax.pmap(
216 | retriever_train_step,
217 | "device"
218 | )
219 |
220 | state = jax_utils.replicate(state)
221 | rng = jax.random.PRNGKey(training_args.seed)
222 | dropout_rngs = jax.random.split(rng, jax.local_device_count())
223 |
224 | class IterableTrain(IterableDataset):
225 | def __init__(self, dataset, batch_idx, epoch):
226 |             super().__init__()
227 | self.dataset = dataset
228 | self.batch_idx = batch_idx
229 | self.epoch = epoch
230 |
231 | def __iter__(self):
232 | for idx in self.batch_idx:
233 | batch = self.dataset.get_batch(idx, self.epoch)
234 | batch = shard(batch)
235 | yield batch
236 |
237 | logger.info("***** Running training *****")
238 | logger.info(f" Num examples = {len(train_dataset)}")
239 | logger.info(f" Num Epochs = {num_epochs}")
240 | logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
241 | logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
242 | logger.info(f" Total optimization steps = {total_train_steps}")
243 |
244 | train_metrics = []
245 | for epoch in tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0):
246 | # ======================== Training ================================
247 | # Create sampling rng
248 | rng, input_rng = jax.random.split(rng)
249 |
250 | steps_per_epoch = len(train_dataset) // train_batch_size
251 |
252 | batch_idx = jax.random.permutation(input_rng, len(train_dataset))
253 | batch_idx = batch_idx[: steps_per_epoch * train_batch_size]
254 | batch_idx = batch_idx.reshape((steps_per_epoch, train_batch_size)).tolist()
255 |
256 | train_loader = prefetch_to_device(
257 | iter(DataLoader(
258 | IterableTrain(train_dataset, batch_idx, epoch),
259 | num_workers=16, prefetch_factor=256, batch_size=None, collate_fn=lambda v: v)
260 | ), 2)
261 |
262 | # train
263 | epochs = tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)
264 | for step in epochs:
265 | cur_step = epoch * (len(train_dataset) // train_batch_size) + step
266 | batch = next(train_loader)
267 |
268 | loss, state, dropout_rngs = p_train_step(state, *batch, dropout_rngs)
269 | train_metrics.append({'loss': loss})
270 |
271 | if cur_step % training_args.logging_steps == 0 and cur_step > 0:
272 | train_metrics = get_metrics(train_metrics)
273 | print(
274 | f"Step... ({cur_step} | Loss: {train_metrics['loss'].mean()},"
275 | f" Learning Rate: {linear_decay_lr_schedule_fn(cur_step)})",
276 | flush=True,
277 | )
278 | train_metrics = []
279 |
280 | epochs.write(
281 | f"Epoch... ({epoch + 1}/{num_epochs})"
282 | )
283 |
284 | params = jax_utils.unreplicate(state.params)
285 |
286 | if model_args.untie_encoder:
287 | os.makedirs(training_args.output_dir, exist_ok=True)
288 | model.save_pretrained(os.path.join(training_args.output_dir, 'query_encoder'), params=params.q_params)
289 | model.save_pretrained(os.path.join(training_args.output_dir, 'passage_encoder'), params=params.p_params)
290 | else:
291 | model.save_pretrained(training_args.output_dir, params=params.p_params)
292 | tokenizer.save_pretrained(training_args.output_dir)
293 |
294 |
295 | if __name__ == "__main__":
296 | main()
297 |
--------------------------------------------------------------------------------
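
Note: the `shard` call inside `IterableTrain.__iter__` above (imported near the top of `jax_train.py`, presumably from `flax.training.common_utils`) reshapes the leading batch axis across local devices for `pmap`. A minimal, self-contained sketch with an illustrative batch shape:

```
import jax
import numpy as np
from flax.training.common_utils import shard

# Toy batch; the leading axis must divide evenly by jax.local_device_count().
batch = {'input_ids': np.zeros((16, 128), dtype=np.int32)}
sharded = shard(batch)
# On N local devices the leading axis becomes (N, 16 // N, 128).
print(jax.tree_util.tree_map(lambda x: x.shape, sharded))
```
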
/tevatron/driver/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | from transformers import AutoConfig, AutoTokenizer
6 | from transformers import (
7 | HfArgumentParser,
8 | set_seed,
9 | )
10 |
11 | from tevatron.arguments import ModelArguments, DataArguments, ColBERTModelArguments, \
12 | DenseTrainingArguments as TrainingArguments
13 | from tevatron.data import TrainDataset, TrainTASBDataset, QPCollator
14 | from tevatron.trainer import DenseTrainer as Trainer, GCTrainer
15 | from tevatron.datasets import HFTrainDataset, HFCorpusDataset
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def main():
21 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
22 |
23 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
24 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
25 | else:
26 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
27 |
28 | model_args: ModelArguments
29 | data_args: DataArguments
30 | training_args: TrainingArguments
31 |
32 |
33 |
34 | if (
35 | os.path.exists(training_args.output_dir)
36 | and os.listdir(training_args.output_dir)
37 | and training_args.do_train
38 | and not training_args.overwrite_output_dir
39 | ):
40 | raise ValueError(
41 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
42 | )
43 |
44 | # Setup logging
45 | logging.basicConfig(
46 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
47 | datefmt="%m/%d/%Y %H:%M:%S",
48 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
49 | )
50 | logger.warning(
51 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
52 | training_args.local_rank,
53 | training_args.device,
54 | training_args.n_gpu,
55 | bool(training_args.local_rank != -1),
56 | training_args.fp16,
57 | )
58 | logger.info("Training/evaluation parameters %s", training_args)
59 | logger.info("MODEL parameters %s", model_args)
60 |
61 | set_seed(training_args.seed)
62 |
63 | num_labels = 1
64 | config = AutoConfig.from_pretrained(
65 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
66 | num_labels=num_labels,
67 | output_hidden_states=True,
68 | cache_dir=model_args.cache_dir,
69 | )
70 | tokenizer = AutoTokenizer.from_pretrained(
71 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
72 | cache_dir=model_args.cache_dir,
73 | use_fast=False,
74 | )
75 |
76 | teacher_model = None
77 | if model_args.tct:
78 | if model_args.teacher_model_name_or_path is None:
79 | raise ValueError(
80 |                 "the --tct option requires --teacher_model_name_or_path to be set"
81 | )
82 | # use default setting
83 | teacher_model_args = ColBERTModelArguments()
84 | teacher_model_args.model_name_or_path = model_args.teacher_model_name_or_path
85 | colbert_config = AutoConfig.from_pretrained(
86 | teacher_model_args.config_name if teacher_model_args.config_name else teacher_model_args.model_name_or_path,
87 | num_labels=num_labels,
88 | output_hidden_states=True,
89 | cache_dir=teacher_model_args.cache_dir,
90 | )
91 |
92 |         from tevatron.ColBERT.modeling import ColBERTForInference
93 |
94 | logger.info("Call model ColBERT as listwise teacher")
95 | teacher_model = ColBERTForInference.build(
96 | model_args=teacher_model_args,
97 | data_args=data_args,
98 | train_args=training_args,
99 | config=colbert_config,
100 | cache_dir=teacher_model_args.cache_dir,
101 | )
102 |
103 | if (model_args.model).lower() == 'colbert':
104 | from tevatron.ColBERT.modeling import ColBERT
105 | logger.info("Training model ColBERT")
106 | model = ColBERT.build(
107 | model_args,
108 | data_args,
109 | training_args,
110 | config=config,
111 | cache_dir=model_args.cache_dir,
112 | )
113 | elif (model_args.model).lower() == 'dhr':
114 | from tevatron.DHR.modeling import DHRModel
115 | logger.info("Training model DHR")
116 | model = DHRModel.build(
117 | model_args,
118 | data_args,
119 | training_args,
120 | teacher_model,
121 | config=config,
122 | cache_dir=model_args.cache_dir,
123 | )
124 | elif (model_args.model).lower() == 'dlr':
125 | from tevatron.DHR.modeling import DHRModel
126 | logger.info("Training model DLR")
127 | model_args.combine_cls = False
128 | model = DHRModel.build(
129 | model_args,
130 | data_args,
131 | training_args,
132 | teacher_model,
133 | config=config,
134 | cache_dir=model_args.cache_dir,
135 | )
136 | elif (model_args.model).lower() == 'agg':
137 | from tevatron.Aggretriever.modeling import DenseModel
138 | logger.info("Training model Dense (AGG)")
139 | model = DenseModel.build(
140 | model_args,
141 | data_args,
142 | training_args,
143 | config=config,
144 | cache_dir=model_args.cache_dir,
145 | )
146 | elif (model_args.model).lower() == 'dense':
147 | from tevatron.Dense.modeling import DenseModel
148 | logger.info("Training model Dense (CLS)")
149 | model = DenseModel.build(
150 | model_args,
151 | data_args,
152 | training_args,
153 | config=config,
154 | cache_dir=model_args.cache_dir,
155 | )
156 | else:
157 |         raise ValueError(f'model type {model_args.model} is not supported; expected one of: colbert, dhr, dlr, agg, dense')
158 |
159 |
160 | train_dataset = HFTrainDataset(tokenizer=tokenizer, data_args=data_args,
161 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
162 |
163 | corpus_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args,
164 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
165 |     ### TODO: add an argument for choosing between the plain and TASB training datasets
166 | # train_dataset = TrainDataset(data_args, train_dataset.process(), tokenizer)
167 | train_dataset = TrainTASBDataset(data_args, model_args.kd, train_dataset.process(), corpus_dataset.process(), tokenizer)
168 |
169 | trainer_cls = GCTrainer if training_args.grad_cache else Trainer
170 | trainer = trainer_cls(
171 | model=model,
172 | args=training_args,
173 | train_dataset=train_dataset,
174 | data_collator=QPCollator(
175 | tokenizer,
176 | max_p_len=data_args.p_max_len,
177 | max_q_len=data_args.q_max_len
178 | ),
179 | )
180 | train_dataset.trainer = trainer
181 |
182 | trainer.train() # TODO: resume training
183 | trainer.save_model()
184 | if trainer.is_world_process_zero():
185 | tokenizer.save_pretrained(training_args.output_dir)
186 |
187 |
188 | if __name__ == "__main__":
189 | main()
190 |
--------------------------------------------------------------------------------
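
For reference, a hedged sketch of how the argument parsing above resolves: `HfArgumentParser` routes one flat argument list across the three dataclasses by field name. The flag values are illustrative rather than recommended settings, and the snippet assumes this repo is installed as the `tevatron` package:

```
from transformers import HfArgumentParser

from tevatron.arguments import ModelArguments, DataArguments
from tevatron.arguments import DenseTrainingArguments as TrainingArguments

parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# Each flag fills whichever dataclass defines the matching field.
model_args, data_args, training_args = parser.parse_args_into_dataclasses([
    '--model_name_or_path', 'distilbert-base-uncased',
    '--output_dir', 'model_out',
    '--per_device_train_batch_size', '8',
])
print(model_args.model_name_or_path, training_args.per_device_train_batch_size)
```
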
/tevatron/faiss_retriever/__init__.py:
--------------------------------------------------------------------------------
1 | from .retriever import BaseFaissIPRetriever
2 |
--------------------------------------------------------------------------------
/tevatron/faiss_retriever/__main__.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import numpy as np
4 | import glob
5 | from argparse import ArgumentParser
6 | from itertools import chain
7 | from tqdm import tqdm
8 |
9 | from .retriever import BaseFaissIPRetriever
10 |
11 | import logging
12 | logger = logging.getLogger(__name__)
13 | logging.basicConfig(
14 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
15 | datefmt="%m/%d/%Y %H:%M:%S",
16 | level=logging.INFO,
17 | )
18 |
19 |
20 | def search_queries(retriever, q_reps, p_lookup, args):
21 | if args.batch_size > 0:
22 | all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size)
23 | else:
24 | all_scores, all_indices = retriever.search(q_reps, args.depth)
25 |
26 | psg_indices = [[str(p_lookup[x]) for x in q_dd] for q_dd in all_indices]
27 | psg_indices = np.array(psg_indices)
28 | return all_scores, psg_indices
29 |
30 |
31 | def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file):
32 | with open(ranking_save_file, 'w') as f:
33 | for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices):
34 | score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)]
35 | score_list = sorted(score_list, key=lambda x: x[0], reverse=True)
36 | for s, idx in score_list:
37 | f.write(f'{qid}\t{idx}\t{s}\n')
38 |
39 |
40 | def pickle_load(path):
41 | with open(path, 'rb') as f:
42 | obj = pickle.load(f)
43 | return obj
44 |
45 |
46 | def pickle_save(obj, path):
47 | with open(path, 'wb') as f:
48 | pickle.dump(obj, f)
49 |
50 |
51 | def main():
52 | parser = ArgumentParser()
53 | parser.add_argument('--query_reps', required=True)
54 | parser.add_argument('--passage_reps', required=True)
55 | parser.add_argument('--batch_size', type=int, default=128)
56 | parser.add_argument('--depth', type=int, default=1000)
57 | parser.add_argument('--save_ranking_to', required=True)
58 | parser.add_argument('--save_text', action='store_true')
59 |
60 | args = parser.parse_args()
61 |
62 | index_files = glob.glob(args.passage_reps)
63 | logger.info(f'Pattern match found {len(index_files)} files; loading them into index.')
64 |
65 | p_reps_0, p_lookup_0 = pickle_load(index_files[0])
66 | retriever = BaseFaissIPRetriever(p_reps_0)
67 |
68 | shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:]))
69 | if len(index_files) > 1:
70 | shards = tqdm(shards, desc='Loading shards into index', total=len(index_files))
71 | look_up = []
72 | for p_reps, p_lookup in shards:
73 | retriever.add(p_reps)
74 | look_up += p_lookup
75 |
76 | q_reps, q_lookup = pickle_load(args.query_reps)
77 |
78 |
79 | logger.info('Index Search Start')
80 | all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args)
81 | logger.info('Index Search Finished')
82 |
83 | if args.save_text:
84 | write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to)
85 | else:
86 | pickle_save((all_scores, psg_indices), args.save_ranking_to)
87 |
88 |
89 | if __name__ == '__main__':
90 | main()
91 |
--------------------------------------------------------------------------------
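
The retriever above expects each `--passage_reps` shard (and the `--query_reps` file) to unpickle to a `(reps, lookup)` pair: a float32 matrix of encoded vectors plus a parallel list of ids. A minimal sketch that writes one such shard (the filename and ids are illustrative):

```
import pickle

import numpy as np

p_reps = np.random.rand(100, 128).astype('float32')  # 100 encoded passages
p_lookup = [str(i) for i in range(100)]              # parallel passage ids
with open('corpus_shard.0.pkl', 'wb') as f:          # hypothetical filename
    pickle.dump((p_reps, p_lookup), f)
```
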
/tevatron/faiss_retriever/reducer.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import faiss
3 | from argparse import ArgumentParser
4 | from tqdm import tqdm
5 | from typing import Iterable, Tuple
6 | from numpy import ndarray
7 | from .__main__ import pickle_load, write_ranking
8 |
9 |
10 | def combine_faiss_results(results: Iterable[Tuple[ndarray, ndarray]]):
11 | rh = None
12 | for scores, indices in results:
13 | if rh is None:
14 | print(f'Initializing Heap. Assuming {scores.shape[0]} queries.')
15 | rh = faiss.ResultHeap(scores.shape[0], scores.shape[1])
16 | rh.add_result(-scores, indices)
17 | rh.finalize()
18 | corpus_scores, corpus_indices = -rh.D, rh.I
19 |
20 | return corpus_scores, corpus_indices
21 |
22 |
23 | def main():
24 | parser = ArgumentParser()
25 | parser.add_argument('--score_dir', required=True)
26 | parser.add_argument('--query', required=True)
27 | parser.add_argument('--save_ranking_to', required=True)
28 | args = parser.parse_args()
29 |
30 | partitions = glob.glob(f'{args.score_dir}/*')
31 |
32 | corpus_scores, corpus_indices = combine_faiss_results(map(pickle_load, tqdm(partitions)))
33 |
34 | _, q_lookup = pickle_load(args.query)
35 | write_ranking(corpus_indices, corpus_scores, q_lookup, args.save_ranking_to)
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
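
A toy illustration of `combine_faiss_results`: it negates scores so that `faiss.ResultHeap` (which keeps the smallest values) retains the globally highest-scoring hits across shards. The shard contents below are made up:

```
import numpy as np

from tevatron.faiss_retriever.reducer import combine_faiss_results

# One query, top-2 hits per shard; dtypes match what faiss expects.
shard_a = (np.array([[0.9, 0.5]], dtype='float32'),
           np.array([[10, 11]], dtype='int64'))
shard_b = (np.array([[0.8, 0.7]], dtype='float32'),
           np.array([[20, 21]], dtype='int64'))
scores, indices = combine_faiss_results([shard_a, shard_b])
print(scores, indices)  # keeps the overall top-2: 0.9 (id 10), 0.8 (id 20)
```
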
/tevatron/faiss_retriever/retriever.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import faiss
3 |
4 | import logging
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class BaseFaissIPRetriever:
10 | def __init__(self, init_reps: np.ndarray):
11 | index = faiss.IndexFlatIP(init_reps.shape[1])
12 | self.index = index
13 |
14 | def add(self, p_reps: np.ndarray):
15 | self.index.add(p_reps)
16 |
17 | def search(self, q_reps: np.ndarray, k: int):
18 | return self.index.search(q_reps, k)
19 |
20 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int):
21 | num_query = q_reps.shape[0]
22 | all_scores = []
23 | all_indices = []
24 | for start_idx in range(0, num_query, batch_size):
25 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k)
26 | all_scores.append(nn_scores)
27 | all_indices.append(nn_indices)
28 | all_scores = np.concatenate(all_scores, axis=0)
29 | all_indices = np.concatenate(all_indices, axis=0)
30 |
31 | return all_scores, all_indices
32 |
33 |
34 | class FaissRetriever(BaseFaissIPRetriever):
35 |
36 | def __init__(self, init_reps: np.ndarray, factory_str: str):
37 | index = faiss.index_factory(init_reps.shape[1], factory_str)
38 | self.index = index
39 | self.index.verbose = True
40 | if not self.index.is_trained:
41 | self.index.train(init_reps)
42 |
--------------------------------------------------------------------------------
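
A small self-contained sketch of `BaseFaissIPRetriever` on random vectors; note that the constructor only sizes the index, so vectors still need to be `add`ed explicitly:

```
import numpy as np

from tevatron.faiss_retriever import BaseFaissIPRetriever

p_reps = np.random.rand(1000, 128).astype('float32')  # 1000 passages, dim 128
q_reps = np.random.rand(8, 128).astype('float32')     # 8 queries

retriever = BaseFaissIPRetriever(p_reps)  # builds an empty IndexFlatIP(128)
retriever.add(p_reps)
scores, indices = retriever.batch_search(q_reps, k=10, batch_size=4)
print(scores.shape, indices.shape)        # (8, 10) (8, 10)
```
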
/tevatron/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 | from torch import distributed as dist
5 |
6 |
7 | class SimpleContrastiveLoss:
8 | def __init__(self, n_target: int = 1):
9 | self.target_per_qry = n_target
10 |
11 | def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
12 | if target is None:
13 | assert x.size(0) * self.target_per_qry == y.size(0)
14 | target = torch.arange(
15 | 0, x.size(0) * self.target_per_qry, self.target_per_qry, device=x.device, dtype=torch.long)
16 | logits = torch.matmul(x, y.transpose(0, 1))
17 | return F.cross_entropy(logits, target, reduction=reduction)
18 |
19 |
20 | class DistributedContrastiveLoss(SimpleContrastiveLoss):
21 |     def __init__(self, n_target: int = 1, scale_loss: bool = True):
22 | assert dist.is_initialized(), "Distributed training has not been properly initialized."
23 | super().__init__(n_target=n_target)
24 |         self.world_size = dist.get_world_size()
25 | self.rank = dist.get_rank()
26 | self.scale_loss = scale_loss
27 |
28 | def __call__(self, x: Tensor, y: Tensor, **kwargs):
29 | dist_x = self.gather_tensor(x)
30 | dist_y = self.gather_tensor(y)
31 | loss = super().__call__(dist_x, dist_y, **kwargs)
32 | if self.scale_loss:
33 |             loss = loss * self.world_size
34 | return loss
35 |
36 | def gather_tensor(self, t):
37 |         gathered = [torch.empty_like(t) for _ in range(self.world_size)]
38 | dist.all_gather(gathered, t)
39 | gathered[self.rank] = t
40 | return torch.cat(gathered, dim=0)
--------------------------------------------------------------------------------
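
A minimal sketch of `SimpleContrastiveLoss` on toy tensors: with `n_target` passages per query, query `i`'s positive is assumed to sit at row `i * n_target` of the passage matrix. Shapes below are illustrative:

```
import torch

from tevatron.loss import SimpleContrastiveLoss

loss_fn = SimpleContrastiveLoss(n_target=2)  # 1 positive + 1 negative per query
q_reps = torch.randn(2, 4)                   # (num_queries, dim)
p_reps = torch.randn(4, 4)                   # (num_queries * n_target, dim)
loss = loss_fn(q_reps, p_reps)               # targets default to rows 0 and 2
print(loss.item())
```
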
/tevatron/preprocessor/__init__.py:
--------------------------------------------------------------------------------
1 | from .preprocessor_tsv import SimpleTrainPreProcessor as MarcoPassageTrainPreProcessor, \
2 | SimpleCollectionPreProcessor as MarcoPassageCollectionPreProcessor
3 |
--------------------------------------------------------------------------------
/tevatron/preprocessor/preprocessor_tsv.py:
--------------------------------------------------------------------------------
1 | import json
2 | import csv
3 | import datasets
4 | from transformers import PreTrainedTokenizer
5 | from dataclasses import dataclass
6 |
7 |
8 | @dataclass
9 | class SimpleTrainPreProcessor:
10 | query_file: str
11 | collection_file: str
12 | tokenizer: PreTrainedTokenizer
13 |
14 | max_length: int = 128
15 | columns = ['text_id', 'title', 'text']
16 | title_field = 'title'
17 | text_field = 'text'
18 |
19 | def __post_init__(self):
20 | self.queries = self.read_queries(self.query_file)
21 | self.collection = datasets.load_dataset(
22 | 'csv',
23 | data_files=self.collection_file,
24 | column_names=self.columns,
25 | delimiter='\t',
26 | )['train']
27 |
28 | @staticmethod
29 | def read_queries(queries):
30 | qmap = {}
31 | with open(queries) as f:
32 | for l in f:
33 | qid, qry = l.strip().split('\t')
34 | qmap[qid] = qry
35 | return qmap
36 |
37 | @staticmethod
38 | def read_qrel(relevance_file):
39 | qrel = {}
40 | with open(relevance_file, encoding='utf8') as f:
41 | tsvreader = csv.reader(f, delimiter="\t")
42 | for [topicid, _, docid, rel] in tsvreader:
43 | assert rel == "1"
44 | if topicid in qrel:
45 | qrel[topicid].append(docid)
46 | else:
47 | qrel[topicid] = [docid]
48 | return qrel
49 |
50 | def get_query(self, q):
51 | query_encoded = self.tokenizer.encode(
52 | self.queries[q],
53 | add_special_tokens=False,
54 | max_length=self.max_length,
55 | truncation=True
56 | )
57 | return query_encoded
58 |
59 | def get_passage(self, p):
60 | entry = self.collection[int(p)]
61 | title = entry[self.title_field]
62 | title = "" if title is None else title
63 | body = entry[self.text_field]
64 | content = title + self.tokenizer.sep_token + body
65 |
66 | passage_encoded = self.tokenizer.encode(
67 | content,
68 | add_special_tokens=False,
69 | max_length=self.max_length,
70 | truncation=True
71 | )
72 |
73 | return passage_encoded
74 |
75 | def process_one(self, train):
76 | q, pp, nn = train
77 | train_example = {
78 | 'query': self.get_query(q),
79 | 'positives': [self.get_passage(p) for p in pp],
80 | 'negatives': [self.get_passage(n) for n in nn],
81 | }
82 |
83 | return json.dumps(train_example)
84 |
85 |
86 | @dataclass
87 | class SimpleCollectionPreProcessor:
88 | tokenizer: PreTrainedTokenizer
89 | separator: str = '\t'
90 | max_length: int = 128
91 |
92 | def process_line(self, line: str):
93 | xx = line.strip().split(self.separator)
94 | text_id, text = xx[0], xx[1:]
95 | text_encoded = self.tokenizer.encode(
96 | self.tokenizer.sep_token.join(text),
97 | add_special_tokens=False,
98 | max_length=self.max_length,
99 | truncation=True
100 | )
101 | encoded = {
102 | 'text_id': text_id,
103 | 'text': text_encoded
104 | }
105 | return json.dumps(encoded)
106 |
--------------------------------------------------------------------------------
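
A minimal sketch of the collection preprocessor on a single TSV line (the passage text is made up; `transformers` must be installed):

```
from transformers import AutoTokenizer

from tevatron.preprocessor import MarcoPassageCollectionPreProcessor

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
proc = MarcoPassageCollectionPreProcessor(tokenizer=tokenizer, max_length=128)
line = '0\tThe presence of communication amid scientific minds was key.'
print(proc.process_line(line))  # '{"text_id": "0", "text": [...token ids...]}'
```
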
/tevatron/tevax/__init__.py:
--------------------------------------------------------------------------------
1 | from .training import TiedParams, DualParams, RetrieverTrainState, retriever_train_step
2 |
--------------------------------------------------------------------------------
/tevatron/tevax/loss.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import lax
3 | import optax
4 | import chex
5 |
6 |
7 | def _onehot(labels: chex.Array, num_classes: int) -> chex.Array:
8 | x = labels[..., None] == jnp.arange(num_classes).reshape((1,) * labels.ndim + (-1,))
9 | x = lax.select(x, jnp.ones(x.shape), jnp.zeros(x.shape))
10 | return x.astype(jnp.float32)
11 |
12 |
13 | def p_contrastive_loss(ss: chex.Array, tt: chex.Array, axis: str = 'device') -> chex.Array:
14 | per_shard_targets = tt.shape[0]
15 | per_sample_targets = int(tt.shape[0] / ss.shape[0])
16 | labels = jnp.arange(0, per_shard_targets, per_sample_targets) + per_shard_targets * lax.axis_index(axis)
17 |
18 | tt = lax.all_gather(tt, axis).reshape((-1, ss.shape[-1]))
19 | scores = jnp.dot(ss, jnp.transpose(tt))
20 |
21 | return optax.softmax_cross_entropy(scores, _onehot(labels, scores.shape[-1]))
22 |
--------------------------------------------------------------------------------
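
The label arithmetic in `p_contrastive_loss` can be hard to read, so here is a plain-numpy illustration of it outside `pmap`: 2 queries and 4 passages per shard (`per_sample_targets = 2`) across 2 pretend devices:

```
import numpy as np

per_shard_targets = 4   # tt.shape[0] on one device
per_sample_targets = 2  # passages per query (1 positive + 1 negative)
for shard_index in range(2):  # stand-in for lax.axis_index(axis)
    labels = (np.arange(0, per_shard_targets, per_sample_targets)
              + per_shard_targets * shard_index)
    print(shard_index, labels)  # shard 0 -> [0 2], shard 1 -> [4 6]
```

Each shard's queries thus point at their own positives inside the globally gathered passage matrix.
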
/tevatron/tevax/training.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Tuple, Any, Union
3 |
4 | import jax
5 | from jax import numpy as jnp
6 |
7 | from flax.training.train_state import TrainState
8 | from flax.core import FrozenDict
9 | from flax.struct import PyTreeNode
10 |
11 | from .loss import p_contrastive_loss
12 |
13 |
14 | class TiedParams(PyTreeNode):
15 | params: FrozenDict[str, Any]
16 |
17 | @property
18 | def q_params(self):
19 | return self.params
20 |
21 | @property
22 | def p_params(self):
23 | return self.params
24 |
25 | @classmethod
26 | def create(cls, params):
27 | return cls(params=params)
28 |
29 |
30 | class DualParams(PyTreeNode):
31 | params: Tuple[FrozenDict[str, Any], FrozenDict[str, Any]]
32 |
33 | @property
34 | def q_params(self):
35 | return self.params[0]
36 |
37 | @property
38 | def p_params(self):
39 | return self.params[1]
40 |
41 | @classmethod
42 | def create(cls, *ps):
43 | if len(ps) == 1:
44 | return cls(params=ps*2)
45 | else:
46 |             q_params, p_params = ps
47 |             return cls(params=[q_params, p_params])
48 |
49 |
50 | class RetrieverTrainState(TrainState):
51 | params: Union[TiedParams, DualParams]
52 |
53 |
54 | def retriever_train_step(state, queries, passages, dropout_rng, axis='device'):
55 | q_dropout_rng, p_dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, 3)
56 |
57 | def compute_loss(params):
58 | q_reps = state.apply_fn(**queries, params=params.q_params, dropout_rng=q_dropout_rng, train=True)[0][:, 0, :]
59 | p_reps = state.apply_fn(**passages, params=params.p_params, dropout_rng=p_dropout_rng, train=True)[0][:, 0, :]
60 | return jnp.mean(p_contrastive_loss(q_reps, p_reps, axis=axis))
61 |
62 | loss, grad = jax.value_and_grad(compute_loss)(state.params)
63 | loss, grad = jax.lax.pmean([loss, grad], axis)
64 |
65 | new_state = state.apply_gradients(grads=grad)
66 |
67 | return loss, new_state, new_dropout_rng
68 |
69 |
70 | def grad_cache_train_step(state, queries, passages, dropout_rng, axis='device', q_n_subbatch=1, p_n_subbatch=1):
71 | try:
72 | from grad_cache import cachex
73 | except ImportError:
74 |         raise ModuleNotFoundError('The GradCache package needs to be installed to run grad_cache_train_step')
75 |
76 | def encode_query(params, **kwargs):
77 | return state.apply_fn(**kwargs, params=params.q_params, train=True)[0][:, 0, :]
78 |
79 | def encode_passage(params, **kwargs):
80 | return state.apply_fn(**kwargs, params=params.p_params, train=True)[0][:, 0, :]
81 |
82 | queries, passages = cachex.tree_chunk(queries, q_n_subbatch), cachex.tree_chunk(passages, p_n_subbatch)
83 | q_rngs, p_rngs, new_rng = jax.random.split(dropout_rng, 3)
84 | q_rngs = jax.random.split(q_rngs, q_n_subbatch)
85 | p_rngs = jax.random.split(p_rngs, p_n_subbatch)
86 |
87 | q_reps = cachex.chunk_encode(partial(encode_query, state.params))(**queries, dropout_rng=q_rngs)
88 | p_reps = cachex.chunk_encode(partial(encode_passage, state.params))(**passages, dropout_rng=p_rngs)
89 |
90 | @cachex.unchunk_args(axis=0, argnums=(0, 1))
91 | def compute_loss(xx, yy):
92 | return jnp.mean(p_contrastive_loss(xx, yy, axis=axis))
93 |
94 | loss, (q_grads, p_grads) = jax.value_and_grad(compute_loss, argnums=(0, 1))(q_reps, p_reps)
95 |
96 | grads = jax.tree_map(lambda v: jnp.zeros_like(v), state.params)
97 | grads = cachex.cache_grad(encode_query)(state.params, grads, q_grads, **queries, dropout_rng=q_rngs)
98 | grads = cachex.cache_grad(encode_passage)(state.params, grads, p_grads, **passages, dropout_rng=p_rngs)
99 |
100 | loss, grads = jax.lax.pmean([loss, grads], axis)
101 | new_state = state.apply_gradients(grads=grads)
102 | return loss, new_state, new_rng
103 |
--------------------------------------------------------------------------------
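
A toy illustration of the two parameter containers (`DualParams.create` takes query params first, matching its `q_params`/`p_params` properties); it assumes this repo is installed:

```
from flax.core import freeze

from tevatron.tevax import TiedParams, DualParams

q_tree, p_tree = freeze({'w': 1.0}), freeze({'w': 2.0})
tied = TiedParams.create(q_tree)
dual = DualParams.create(q_tree, p_tree)
print(tied.q_params is tied.p_params)  # True: one tree serves both encoders
print(dual.q_params is dual.p_params)  # False: untied query/passage trees
```
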
/tevatron/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import repeat
3 | from typing import Dict, List, Tuple, Optional, Any, Union
4 |
5 | from transformers.trainer import Trainer
6 |
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import torch.distributed as dist
10 |
11 | from .loss import SimpleContrastiveLoss, DistributedContrastiveLoss
12 |
13 | import logging
14 | logger = logging.getLogger(__name__)
15 |
16 | try:
17 | from grad_cache import GradCache
18 | _grad_cache_available = True
19 | except ModuleNotFoundError:
20 | _grad_cache_available = False
21 |
22 |
23 | class DenseTrainer(Trainer):
24 | def __init__(self, *args, **kwargs):
25 | super(DenseTrainer, self).__init__(*args, **kwargs)
26 | self._dist_loss_scale_factor = dist.get_world_size() if self.args.negatives_x_device else 1
27 |
28 | def _save(self, output_dir: Optional[str] = None):
29 | output_dir = output_dir if output_dir is not None else self.args.output_dir
30 | os.makedirs(output_dir, exist_ok=True)
31 | logger.info("Saving model checkpoint to %s", output_dir)
32 | self.model.save(output_dir)
33 |
34 | def _prepare_inputs(
35 | self,
36 | inputs: Tuple[Dict[str, Union[torch.Tensor, Any]], ...]
37 | ) -> List[Dict[str, Union[torch.Tensor, Any]]]:
38 | prepared = []
39 | for x in inputs:
40 | if isinstance(x, torch.Tensor):
41 | prepared.append(x.to(self.args.device))
42 | else:
43 | prepared.append(super()._prepare_inputs(x))
44 | return prepared
45 |
46 | def get_train_dataloader(self) -> DataLoader:
47 | if self.train_dataset is None:
48 | raise ValueError("Trainer: training requires a train_dataset.")
49 | train_sampler = self._get_train_sampler()
50 |
51 | return DataLoader(
52 | self.train_dataset,
53 | batch_size=self.args.train_batch_size,
54 | sampler=train_sampler,
55 | collate_fn=self.data_collator,
56 | drop_last=True,
57 | num_workers=self.args.dataloader_num_workers,
58 | )
59 |
60 | def compute_loss(self, model, inputs):
61 | query, passage, teacher_scores = inputs
62 |
63 | return model(query=query, passage=passage, teacher_scores=teacher_scores).loss
64 |
65 | def training_step(self, *args):
66 | return super(DenseTrainer, self).training_step(*args) / self._dist_loss_scale_factor
67 |
68 |
69 | def split_dense_inputs(model_input: dict, chunk_size: int):
70 | assert len(model_input) == 1
71 | arg_key = list(model_input.keys())[0]
72 | arg_val = model_input[arg_key]
73 |
74 | keys = list(arg_val.keys())
75 | chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys]
76 | chunked_arg_val = [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]
77 |
78 | return [{arg_key: c} for c in chunked_arg_val]
79 |
80 |
81 | def get_dense_rep(x):
82 | if x.q_reps is None:
83 | return x.p_reps
84 | else:
85 | return x.q_reps
86 |
87 |
88 | class GCTrainer(DenseTrainer):
89 | def __init__(self, *args, **kwargs):
90 | logger.info('Initializing Gradient Cache Trainer')
91 | if not _grad_cache_available:
92 | raise ValueError(
93 | 'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.')
94 | super(GCTrainer, self).__init__(*args, **kwargs)
95 |
96 | loss_fn_cls = DistributedContrastiveLoss if self.args.negatives_x_device else SimpleContrastiveLoss
97 | loss_fn = loss_fn_cls(self.model.data_args.train_n_passages)
98 |
99 | self.gc = GradCache(
100 | models=[self.model, self.model],
101 | chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size],
102 | loss_fn=loss_fn,
103 | split_input_fn=split_dense_inputs,
104 | get_rep_fn=get_dense_rep,
105 | fp16=self.args.fp16,
106 | scaler=self.scaler
107 | )
108 |
109 | def training_step(self, model, inputs) -> torch.Tensor:
110 | model.train()
111 | queries, passages = self._prepare_inputs(inputs)
112 | queries, passages = {'query': queries}, {'passage': passages}
113 |
114 | _distributed = self.args.local_rank > -1
115 | self.gc.models = [model, model]
116 | loss = self.gc(queries, passages, no_sync_except_last=_distributed)
117 |
118 | return loss / self._dist_loss_scale_factor
119 |
--------------------------------------------------------------------------------
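
A toy illustration of `split_dense_inputs`, which `GCTrainer` uses to chunk the single-key `{'query': ...}` / `{'passage': ...}` mappings along the batch axis before gradient caching:

```
import torch

from tevatron.trainer import split_dense_inputs

model_input = {'query': {'input_ids': torch.zeros(8, 32, dtype=torch.long),
                         'attention_mask': torch.ones(8, 32, dtype=torch.long)}}
chunks = split_dense_inputs(model_input, chunk_size=4)
print(len(chunks))                            # 2
print(chunks[0]['query']['input_ids'].shape)  # torch.Size([4, 32])
```
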
/tevatron/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/utils/__init__.py
--------------------------------------------------------------------------------
/tevatron/utils/convert_from_dpr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 |
5 | from transformers import AutoConfig, AutoTokenizer
6 |
7 | def main():
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--dpr_model', required=True)
10 | parser.add_argument('--save_to', required=True)
11 | args = parser.parse_args()
12 |
13 | dpr_model_ckpt = torch.load(args.dpr_model, map_location='cpu')
14 | config_name = dpr_model_ckpt['encoder_params']['pretrained_model_cfg']
15 | dpr_model_dict = dpr_model_ckpt['model_dict']
16 |
17 | AutoConfig.from_pretrained(config_name).save_pretrained(args.save_to)
18 | AutoTokenizer.from_pretrained(config_name).save_pretrained(args.save_to)
19 |
20 | question_keys = [k for k in dpr_model_dict.keys() if k.startswith('question_model')]
21 | ctx_keys = [k for k in dpr_model_dict.keys() if k.startswith('ctx_model')]
22 |
23 | question_dict = dict([(k[len('question_model')+1:], dpr_model_dict[k]) for k in question_keys])
24 | ctx_dict = dict([(k[len('ctx_model')+1:], dpr_model_dict[k]) for k in ctx_keys])
25 |
26 | os.makedirs(os.path.join(args.save_to, 'query_model'), exist_ok=True)
27 | os.makedirs(os.path.join(args.save_to, 'passage_model'), exist_ok=True)
28 | torch.save(question_dict, os.path.join(args.save_to, 'query_model', 'pytorch_model.bin'))
29 | torch.save(ctx_dict, os.path.join(args.save_to, 'passage_model', 'pytorch_model.bin'))
30 |
31 |
32 | if __name__ == '__main__':
33 | main()
--------------------------------------------------------------------------------
/tevatron/utils/data_reader.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from tqdm import tqdm
3 | import os
4 | import json
5 | from typing import List, Tuple
6 | from collections import defaultdict
7 | logger = logging.getLogger(__name__)
8 |
9 | def create_dir(dir_: str):
10 | output_parent = '/'.join((dir_).split('/')[:-1])
11 | if not os.path.exists(output_parent):
12 | logger.info(f'Create {output_parent}')
13 | os.mkdir(output_parent)
14 | if not os.path.exists(dir_):
15 | logger.info(f'Create {dir_}')
16 | os.mkdir(dir_)
17 |
18 | def read_tsv(path: str):
19 | id2info = {}
20 | with open(path, 'r') as f:
21 | for line in tqdm(f, desc=f"read {path}"):
22 | idx, info = line.strip().split('\t')
23 | id2info[idx] = info
24 | return id2info
25 |
26 | def read_json(path: str,
27 | id_key: str = 'id',
28 | content_key: str = 'content',
29 | meta_keys: List[str] = None,
30 | sep: str = ' '):
31 | id2info = {}
32 | with open(path, 'r') as f:
33 | for line in tqdm(f, desc=f"read {path}"):
34 |             data = json.loads(line.strip())
35 | idx = data[id_key]
36 | info = data[content_key]
37 |             if meta_keys:
38 | info = [info]
39 | for meta_key in meta_keys:
40 | info.append(data[meta_key])
41 | info = sep.join(info)
42 | id2info[idx] = info
43 | return id2info
44 |
45 | def read_trec(path: str):
46 | qid2psg = defaultdict(list)
47 | with open(path, 'r') as f:
48 | for line in tqdm(f, desc=f"read {path}"):
49 | try:
50 | data = line.strip().split('\t')
51 | qid = data[0]
52 | psg = data[2]
53 |             except IndexError:  # fall back to space-delimited run files
54 | data = line.strip().split(' ')
55 | qid = data[0]
56 | psg = data[2]
57 | qid2psg[qid].append(psg)
58 |
59 |
60 | return qid2psg
61 |
62 | def read_qrel(path: str):
63 | qid_pid2qrel = defaultdict(int)
64 | with open(path, 'r') as f:
65 | for line in tqdm(f, desc=f"read {path}"):
66 |             qid, _, pid, rel = line.strip().split('\t')
67 | qid_pid2qrel[f'{qid}_{pid}'] = int(rel)
68 | return qid_pid2qrel
--------------------------------------------------------------------------------
/tevatron/utils/format/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/castorini/dhr/e236f3da1c14424c730cd22276554ab900bdece2/tevatron/utils/format/__init__.py
--------------------------------------------------------------------------------
/tevatron/utils/format/convert_result_to_trec.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | parser = ArgumentParser()
4 | parser.add_argument('--input', type=str, required=True)
5 | parser.add_argument('--output', type=str, required=True)
6 | args = parser.parse_args()
7 |
8 | with open(args.input) as f_in, open(args.output, 'w') as f_out:
9 | cur_qid = None
10 | rank = 0
11 | for line in f_in:
12 | qid, docid, score = line.split()
13 | if cur_qid != qid:
14 | cur_qid = qid
15 | rank = 0
16 | rank += 1
17 | f_out.write(f'{qid} Q0 {docid} {rank} {score} dense\n')
18 |
--------------------------------------------------------------------------------
/tevatron/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def average_precision(gt, pred):
5 | """
6 | Computes the average precision.
7 |
8 |     This function computes the average precision between a set of
9 |     ground-truth items and a ranked list of predicted items.
10 |
11 | Parameters
12 | ----------
13 | gt: set
14 | A set of ground-truth elements (order doesn't matter)
15 | pred: list
16 | A list of predicted elements (order does matter)
17 |
18 | Returns
19 | -------
20 | score: double
21 | The average precision over the input lists
22 | """
23 |
24 | if not gt:
25 | return 0.0
26 |
27 | score = 0.0
28 | num_hits = 0.0
29 | for i,p in enumerate(pred):
30 | if p in gt and p not in pred[:i]:
31 | num_hits += 1.0
32 | score += num_hits / (i + 1.0)
33 |
34 | return score / max(1.0, len(gt))
35 |
36 |
37 | def NDCG(gt, pred, use_graded_scores=False):
38 | score = 0.0
39 | for rank, item in enumerate(pred):
40 | if item in gt:
41 | if use_graded_scores:
42 | grade = 1.0 / (gt.index(item) + 1)
43 | else:
44 | grade = 1.0
45 | score += grade / np.log2(rank + 2)
46 |
47 | norm = 0.0
48 | for rank in range(len(gt)):
49 | if use_graded_scores:
50 | grade = 1.0 / (rank + 1)
51 | else:
52 | grade = 1.0
53 | norm += grade / np.log2(rank + 2)
54 | return score / max(0.3, norm)
55 |
56 |
57 | def metrics(gt, pred, metrics_map):
58 | '''
59 | Returns a numpy array containing metrics specified by metrics_map.
60 | gt: ground-truth items
61 | pred: predicted items
62 | '''
63 | out = np.zeros((len(metrics_map),), np.float32)
64 |
65 | if ('MAP' in metrics_map):
66 | avg_precision = average_precision(gt=gt, pred=pred)
67 | out[metrics_map.index('MAP')] = avg_precision
68 |
69 | if ('RPrec' in metrics_map):
70 | intersec = len(gt & set(pred[:len(gt)]))
71 | out[metrics_map.index('RPrec')] = intersec / max(1., float(len(gt)))
72 |
73 | if 'MRR' in metrics_map:
74 | score = 0.0
75 | for rank, item in enumerate(pred):
76 | if item in gt:
77 | score = 1.0 / (rank + 1.0)
78 | break
79 | out[metrics_map.index('MRR')] = score
80 |
81 | if 'MRR@10' in metrics_map:
82 | score = 0.0
83 | for rank, item in enumerate(pred[:10]):
84 | if item in gt:
85 | score = 1.0 / (rank + 1.0)
86 | break
87 | out[metrics_map.index('MRR@10')] = score
88 |
89 | if ('NDCG' in metrics_map):
90 | out[metrics_map.index('NDCG')] = NDCG(gt, pred)
91 |
92 | return out
93 |
94 |
--------------------------------------------------------------------------------
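
A small usage sketch of the metrics helper: `gt` is a set of relevant ids and `pred` a ranked list (the ids below are made up):

```
from tevatron.utils.metrics import metrics

gt = {'d1', 'd4'}
pred = ['d3', 'd1', 'd7', 'd4']
# The output array follows the order of metrics_map.
print(metrics(gt, pred, ['MAP', 'MRR@10', 'RPrec']))  # [0.5 0.5 0.5]
```
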
/tevatron/utils/tokenize_corpus.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)
3 | import argparse
4 | from tqdm import tqdm
5 | import os
6 | import json
7 | from multiprocessing import Pool
8 | from transformers import AutoTokenizer
9 | from .data_reader import create_dir
10 |
11 | DATA_ITEM = {'msmarco-passage': {'id':'id', 'contents': ['contents']},
12 | 'beir': {'id':'_id', 'contents': ['title', 'text']}}
13 |
14 | def tokenize_and_json_save(data_item, data_type, tokenizer, lines, jsonl_path, tokenize, encode):
15 | output = open(jsonl_path, 'w')
16 |     for line in tqdm(lines, total=len(lines), desc=f"write {jsonl_path}"):
17 | if data_type == 'tsv':
18 | docid, contents = line.strip().split('\t')
19 | elif (data_type =='json') or (data_type =='jsonl'):
20 | line = json.loads(line.strip())
21 | docid = line[data_item['id']]
22 |
23 | contents = []
24 | for content in data_item['contents']:
25 | contents.append(line[content])
26 | contents = ' '.join(contents)
27 | if tokenize:
28 | if encode:
29 | contents = tokenizer.encode(contents, add_special_tokens=False)
30 | # Fit the format of tevatron
31 | output_dict = {'text_id': docid, 'text': contents}
32 | else:
33 | contents = ' '.join(tokenizer.tokenize(contents))
34 | output_dict = {'id': docid, 'contents': contents}
35 | else:
36 | output_dict = {'id': docid, 'contents': contents}
37 | output.write(json.dumps(output_dict) + '\n')
38 | output.close()
39 |
40 | def main():
41 | parser = argparse.ArgumentParser(
42 | description='Transform corpus into wordpiece corpus')
43 | parser.add_argument('--corpus_path', required=True, help='TSV or json corpus file with format {docid}\t{document}.')
44 | parser.add_argument('--output_dir', required=True)
45 | parser.add_argument('--corpus_domain', required=False, default='msmarco-passage')
46 | parser.add_argument('--tokenizer', required=False, default='bert-base-uncased', help='tokenizer model name')
47 | parser.add_argument('--tokenize', action='store_true')
48 | parser.add_argument('--encode', action='store_true')
49 | parser.add_argument('--num_workers', type=int, required=False, default=None)
50 |     parser.add_argument('--max_line_per_file', type=int, required=False, default=300000, help='max number of corpus lines per output split; use a smaller value for longer maximum sequence lengths')
51 | args = parser.parse_args()
52 |
53 | if args.encode:
54 | if not args.tokenize:
55 | raise ValueError('if you want to encode, you must set tokenize option!')
56 |
57 | create_dir(args.output_dir)
58 |
59 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
60 |
61 | data_type = (args.corpus_path).split('.')[-1]
62 | if (data_type != 'tsv') and (data_type != 'json') and (data_type != 'jsonl'):
63 | raise ValueError('--corpus_path should be tsv, json or jsonl format')
64 |
65 | with open(args.corpus_path, 'r') as f:
66 | print("read {}".format(args.corpus_path))
67 | lines = f.readlines()
68 | total_num_docs = len(lines)
69 | print("total {} lines".format(total_num_docs))
70 |
71 | ## for debug
72 | # tokenize_and_json_save(DATA_ITEM[args.corpus_domain], data_type, tokenizer, lines, os.path.join(jsonl_dir, 'split.json'), args.tokenize )
73 | if args.num_workers is None:
74 | num_docs_per_worker = args.max_line_per_file
75 | args.num_workers = total_num_docs // num_docs_per_worker
76 | if (total_num_docs%num_docs_per_worker ) != 0:
77 | args.num_workers+=1
78 | else:
79 | num_docs_per_worker = total_num_docs//args.num_workers
80 | if (total_num_docs%args.num_workers) != 0:
81 | args.num_workers+=1
82 |
83 | logging.info(f'Run with {args.num_workers} workers on {total_num_docs} documents')
84 | pool = Pool(args.num_workers)
85 | for i in range(args.num_workers):
86 | f_out = os.path.join(args.output_dir, 'split%02d.json'%i)
87 | start = i*num_docs_per_worker
88 | if i==(args.num_workers-1):
89 | pool.apply_async(tokenize_and_json_save ,(DATA_ITEM[args.corpus_domain], data_type, tokenizer,\
90 | lines[start:], f_out, args.tokenize, args.encode))
91 | else:
92 | pool.apply_async(tokenize_and_json_save ,(DATA_ITEM[args.corpus_domain], data_type, tokenizer,\
93 | lines[start:(start+num_docs_per_worker)], f_out, args.tokenize, args.encode))
94 |
95 | pool.close()
96 | pool.join()
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
--------------------------------------------------------------------------------
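
A hedged single-process sketch of `tokenize_and_json_save`, bypassing the multiprocessing pool (the corpus lines and output path are illustrative; run with the repo importable as `tevatron`):

```
from transformers import AutoTokenizer

from tevatron.utils.tokenize_corpus import tokenize_and_json_save, DATA_ITEM

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
lines = ['0\tfirst passage text\n', '1\tsecond passage text\n']
tokenize_and_json_save(DATA_ITEM['msmarco-passage'], 'tsv', tokenizer, lines,
                       'split00.json', tokenize=True, encode=True)
# split00.json now holds one {"text_id", "text"} JSON object per line.
```
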
/tevatron/utils/tokenize_query.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)
3 | import argparse
4 | from tqdm import tqdm
5 | import os
6 | import json
7 | from collections import defaultdict
8 | from transformers import AutoTokenizer
9 | import sys
10 | from .data_reader import read_tsv, create_dir
11 |
12 | def main():
13 | parser = argparse.ArgumentParser(
14 | description='Tokenize query')
15 | parser.add_argument('--qry_file', required=True, help='format {qid}\t{qry}')
16 | parser.add_argument('--output_dir', required=True)
17 | parser.add_argument('--tokenizer', required=False, default='bert-base-uncased', help='tokenizer model name')
18 | args = parser.parse_args()
19 |
20 | create_dir(args.output_dir)
21 |
22 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
23 | qid2qry = read_tsv(args.qry_file)
24 |
25 | query_name = args.qry_file.split('/')[-1].replace('.tsv','.json')
26 | output_path = os.path.join(args.output_dir, query_name)
27 | output = open(output_path, 'w')
28 | with open(args.qry_file, 'r') as f:
29 | for line in tqdm(f, desc=f"tokenize query: {output_path}"):
30 | qid, qry = line.strip().split('\t')
31 | qry = tokenizer.encode(qry, add_special_tokens=False)
32 | output_dict = {"text_id": qid, "text": qry}
33 | output.write(json.dumps(output_dict) + '\n')
34 | output.close()
35 | if __name__ == "__main__":
36 | main()
--------------------------------------------------------------------------------