├── src
│   └── dense
│       ├── __init__.py
│       ├── driver
│       │   ├── __init__.py
│       │   ├── train.py
│       │   └── encode.py
│       ├── utils
│       │   ├── __init__.py
│       │   └── format
│       │       ├── __init__.py
│       │       └── convert_result_to_trec.py
│       ├── faiss_retriever
│       │   ├── __init__.py
│       │   ├── retriever.py
│       │   ├── reducer.py
│       │   └── __main__.py
│       ├── processor
│       │   ├── __init__.py
│       │   └── processors.py
│       ├── dataset
│       │   ├── __init__.py
│       │   └── processor.py
│       ├── loss.py
│       ├── arguments.py
│       ├── trainer.py
│       ├── data.py
│       └── modeling.py
├── requirements.txt
├── examples
│   ├── coCondenser-marco
│   │   ├── create_hn.sh
│   │   ├── get_data.sh
│   │   └── README.md
│   ├── msmarco-passage-ranking
│   │   ├── score_to_marco.py
│   │   ├── tokenize_queries.py
│   │   ├── get_data.sh
│   │   ├── tokenize_passages.py
│   │   ├── README.md
│   │   ├── build_train.py
│   │   └── build_train_hn.py
│   ├── README.md
│   ├── wikipedia-nq
│   │   ├── prepare_wiki_train.py
│   │   ├── README.md
│   │   └── run.py
│   ├── scifact
│   │   ├── README.md
│   │   └── run.py
│   └── run.py
├── setup.py
├── README.md
└── LICENSE
/src/dense/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dense/driver/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dense/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dense/utils/format/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dense/faiss_retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from .retriever import BaseFaissIPRetriever 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch<=1.8.0 2 | faiss-cpu>=1.6.5 3 | transformers==4.2.0 4 | datasets==1.1.3 -------------------------------------------------------------------------------- /src/dense/processor/__init__.py: -------------------------------------------------------------------------------- 1 | from .processors import SimpleTrainProcessor, SimpleCollectionProcessor 2 | 3 | MarcoPassageTrainProcessor = SimpleTrainProcessor 4 | -------------------------------------------------------------------------------- /examples/coCondenser-marco/create_hn.sh: -------------------------------------------------------------------------------- 1 | SCRIPT_DIR=$PWD/../msmarco-passage-ranking 2 | cd marco 3 | python $SCRIPT_DIR/build_train_hn.py --tokenizer_name bert-base-uncased --hn_file ../train.rank.txt --qrels qrels.train.tsv \ 4 | --queries train.query.txt --collection corpus.tsv --save_to bert/train-hn 5 | ln -s bert/train/* bert/train-hn 6 | cd - -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='dense', 5 | version='0.0.1', 6 | packages=find_packages("src"), 7 | package_dir={'': 'src'}, 8 | url='https://github.com/luyug/Dense', 9 | license='Apache 2.0', 10 | author='Luyu Gao', 11 | author_email='luyug@cs.cmu.edu', 12 | description='A toolkit for learning and running deep dense 
retrieval models.' 13 | ) 14 | -------------------------------------------------------------------------------- /src/dense/utils/format/convert_result_to_trec.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | parser = ArgumentParser() 4 | parser.add_argument('--input', type=str, required=True) 5 | parser.add_argument('--output', type=str, required=True) 6 | args = parser.parse_args() 7 | 8 | with open(args.input) as f_in, open(args.output, 'w') as f_out: 9 | cur_qid = None 10 | rank = 0 11 | for line in f_in: 12 | qid, docid, score = line.split() 13 | if cur_qid != qid: 14 | cur_qid = qid 15 | rank = 0 16 | rank += 1 17 | f_out.write(f'{qid} Q0 {docid} {rank} {score} dense\n') 18 | -------------------------------------------------------------------------------- /examples/msmarco-passage-ranking/score_to_marco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('score_file') 6 | args = parser.parse_args() 7 | 8 | with open(args.score_file) as f: 9 | lines = f.readlines() 10 | 11 | all_scores = defaultdict(list) 12 | 13 | for line in lines: 14 | if len(line.strip()) == 0: 15 | continue 16 | qid, did, score = line.strip().split() 17 | score = float(score) 18 | all_scores[qid].append((did, score)) 19 | 20 | qq = list(all_scores.keys()) 21 | 22 | with open(args.score_file + '.marco', 'w') as f: 23 | for qid in qq: 24 | score_list = sorted(all_scores[qid], key=lambda x: x[1], reverse=True) 25 | for rank, (did, score) in enumerate(score_list): 26 | f.write(f'{qid}\t{did}\t{rank+1}\n') 27 | 28 | -------------------------------------------------------------------------------- /src/dense/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import TrainProcessor, TestProcessor, CorpusProcessor 2 | 3 | PROCESSOR_INFO = { 4 | 'Tevatron/wikipedia-nq': { 5 | 'train': TrainProcessor, 6 | 'dev': TrainProcessor, 7 | 'test': TestProcessor, 8 | 'corpus': CorpusProcessor, 9 | }, 10 | 'Tevatron/wikipedia-trivia': { 11 | 'train': TrainProcessor, 12 | 'dev': TrainProcessor, 13 | 'test': TestProcessor, 14 | 'corpus': CorpusProcessor, 15 | }, 16 | 'Tevatron/msmarco-passage': { 17 | 'train': TrainProcessor, 18 | 'dev': TestProcessor, 19 | 'corpus': CorpusProcessor, 20 | }, 21 | 'Tevatron/scifact': { 22 | 'train': TrainProcessor, 23 | 'dev': TestProcessor, 24 | 'test': TestProcessor, 25 | 'corpus': CorpusProcessor, 26 | }, 27 | } 28 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | Here we provide examples for running Dense on various datasets/models. 3 | 4 | ## Research 5 | Researchers are recommended to start with the [run.py](run.py) under this directory. It includes the logic of `dense.driver.train` and `dense.driver.encode` for training and encoding. 6 | Adjustments can then be made to `dense.modeling`, `dense.trainer` and `dense.data`; either create sub-classes or make direct edits, as in the sketch below.
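A minimal sketch of the sub-classing route (the class name `MyDenseTrainer` and the extra regularizer are hypothetical illustrations; `compute_loss` does receive a `(query, passage)` pair, and the model output carries `.loss` plus the `q_reps`/`p_reps` representations, see `dense.trainer` and `dense.modeling`):
```python
from dense.trainer import DenseTrainer


class MyDenseTrainer(DenseTrainer):
    """Hypothetical trainer variant; pass it to run.py in place of DenseTrainer."""

    def compute_loss(self, model, inputs):
        # inputs is a (query, passage) pair of tokenized batches
        query, passage = inputs
        out = model(query=query, passage=passage)
        # start from the stock contrastive loss; a custom term computed from
        # out.q_reps / out.p_reps could be added here
        return out.loss
```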
7 | 8 | In particular, 9 | - better models can go into `dense.modeling` 10 | - better training techniques can go into `dense.trainer` 11 | - better data control can go into `dense.data` 12 | 13 | To change retriever behavior, check out its [main function](../src/dense/faiss_retriever/__main__.py), 14 | and also the entire `faiss_retriever` [submodule](../src/dense/faiss_retriever). 15 | ## Example Index 16 | - [MS-MARCO passage ranking](msmarco-passage-ranking) 17 | -------------------------------------------------------------------------------- /examples/msmarco-passage-ranking/tokenize_queries.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from transformers import AutoTokenizer 3 | import os 4 | from tqdm import tqdm 5 | from dense.processor import SimpleCollectionProcessor 6 | 7 | parser = ArgumentParser() 8 | parser.add_argument('--tokenizer_name', required=True) 9 | parser.add_argument('--truncate', type=int, default=32) 10 | parser.add_argument('--query_file', required=True) 11 | parser.add_argument('--save_to', required=True) 12 | args = parser.parse_args() 13 | 14 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) 15 | processor = SimpleCollectionProcessor(tokenizer=tokenizer, max_length=args.truncate) 16 | 17 | with open(args.query_file, 'r') as f: 18 | lines = f.readlines() 19 | 20 | os.makedirs(os.path.split(args.save_to)[0], exist_ok=True) 21 | with open(args.save_to, 'w') as jfile: 22 | for x in tqdm(lines): 23 | q = processor.process_line(x) 24 | jfile.write(q + '\n') 25 | -------------------------------------------------------------------------------- /src/dense/faiss_retriever/retriever.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import faiss 3 | 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class BaseFaissIPRetriever: 9 | def __init__(self, init_reps: np.ndarray): 10 | index = faiss.IndexFlatIP(init_reps.shape[1]) 11 | self.index = index 12 | 13 | def search(self, q_reps: np.ndarray, k: int): 14 | return self.index.search(q_reps, k) 15 | 16 | def add(self, p_reps: np.ndarray): 17 | self.index.add(p_reps) 18 | 19 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int): 20 | num_query = q_reps.shape[0] 21 | all_scores = [] 22 | all_indices = [] 23 | for start_idx in range(0, num_query, batch_size): 24 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k) 25 | all_scores.append(nn_scores) 26 | all_indices.append(nn_indices) 27 | all_scores = np.concatenate(all_scores, axis=0) 28 | all_indices = np.concatenate(all_indices, axis=0) 29 | 30 | return all_scores, all_indices -------------------------------------------------------------------------------- /examples/msmarco-passage-ranking/get_data.sh: -------------------------------------------------------------------------------- 1 | wget --no-check-certificate https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz 2 | tar -zxf marco.tar.gz 3 | rm -rf marco.tar.gz 4 | 5 | cd marco 6 | 7 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz 8 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv -O qrels.train.tsv 9 | gunzip qidpidtriples.train.full.2.tsv.gz 10 | join -t "$(echo -en '\t')" -e '' -a 1 -o 1.1 2.2 1.2 <(sort -k1,1 para.txt) <(sort -k1,1 para.title.txt) | sort -k1,1 -n > corpus.tsv 11 | awk -v RS='\r\n' '$1==last {printf 
",%s",$3; next} NR>1 {print "";} {last=$1; printf "%s\t%s",$1,$3;} END{print "";}' qidpidtriples.train.full.2.tsv > train.negatives.tsv 12 | 13 | TOKENIZER=bert-base-uncased 14 | TOKENIZER_ID=bert 15 | 16 | python ../build_train.py --tokenizer_name $TOKENIZER --negative_file train.negatives.tsv --qrels qrels.train.tsv \ 17 | --queries train.query.txt --collection corpus.tsv --save_to ${TOKENIZER_ID}/train 18 | python ../tokenize_queries.py --tokenizer_name $TOKENIZER --query_file dev.query.txt --save_to $TOKENIZER_ID/query/dev.query.json 19 | python ../tokenize_passages.py --tokenizer_name $TOKENIZER --file corpus.tsv --save_to $TOKENIZER_ID/corpus 20 | 21 | cd - -------------------------------------------------------------------------------- /examples/coCondenser-marco/get_data.sh: -------------------------------------------------------------------------------- 1 | SCRIPT_DIR=$PWD/../msmarco-passage-ranking/ 2 | 3 | wget --no-check-certificate https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz 4 | tar -zxf marco.tar.gz 5 | rm -rf marco.tar.gz 6 | cd marco 7 | 8 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz 9 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv -O qrels.train.tsv 10 | gunzip qidpidtriples.train.full.2.tsv.gz 11 | join -t "$(echo -en '\t')" -e '' -a 1 -o 1.1 2.2 1.2 <(sort -k1,1 para.txt) <(sort -k1,1 para.title.txt) | sort -k1,1 -n > corpus.tsv 12 | awk -v RS='\r\n' '$1==last {printf ",%s",$3; next} NR>1 {print "";} {last=$1; printf "%s\t%s",$1,$3;} END{print "";}' qidpidtriples.train.full.2.tsv > train.negatives.tsv 13 | 14 | TOKENIZER=bert-base-uncased 15 | TOKENIZER_ID=bert 16 | 17 | python $SCRIPT_DIR/build_train.py --tokenizer_name $TOKENIZER --negative_file train.negatives.tsv --qrels qrels.train.tsv \ 18 | --queries train.query.txt --collection corpus.tsv --save_to ${TOKENIZER_ID}/train 19 | python $SCRIPT_DIR/tokenize_queries.py --tokenizer_name $TOKENIZER --query_file dev.query.txt --save_to $TOKENIZER_ID/query/dev.query.json 20 | python $SCRIPT_DIR/tokenize_queries.py --tokenizer_name $TOKENIZER --query_file train.query.txt --save_to $TOKENIZER_ID/query/train.query.json 21 | python $SCRIPT_DIR/tokenize_passages.py --tokenizer_name $TOKENIZER --file corpus.tsv --save_to $TOKENIZER_ID/corpus 22 | 23 | cd - -------------------------------------------------------------------------------- /examples/msmarco-passage-ranking/tokenize_passages.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from transformers import AutoTokenizer 3 | import os 4 | from tqdm import tqdm 5 | from multiprocessing import Pool 6 | from dense.processor import SimpleCollectionProcessor 7 | 8 | parser = ArgumentParser() 9 | parser.add_argument('--tokenizer_name', required=True) 10 | parser.add_argument('--truncate', type=int, default=128) 11 | parser.add_argument('--file', required=True) 12 | parser.add_argument('--save_to', required=True) 13 | parser.add_argument('--n_splits', type=int, default=10) 14 | 15 | args = parser.parse_args() 16 | 17 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) 18 | processor = SimpleCollectionProcessor(tokenizer=tokenizer, max_length=args.truncate) 19 | 20 | with open(args.file, 'r') as f: 21 | lines = f.readlines() 22 | 23 | n_lines = len(lines) 24 | if n_lines % args.n_splits == 0: 25 | split_size = int(n_lines / args.n_splits) 26 | else: 27 | split_size = int(n_lines / args.n_splits) 
+ 1 28 | 29 | 30 | os.makedirs(args.save_to, exist_ok=True) 31 | with Pool() as p: 32 | for i in range(args.n_splits): 33 | with open(os.path.join(args.save_to, f'split{i:02d}.json'), 'w') as f: 34 | pbar = tqdm(lines[i*split_size: (i+1)*split_size]) 35 | pbar.set_description(f'split - {i:02d}') 36 | for jitem in p.imap(processor.process_line, pbar, chunksize=500): 37 | f.write(jitem + '\n') 38 | 39 | 40 | -------------------------------------------------------------------------------- /src/dense/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from torch import distributed as dist 5 | 6 | 7 | class SimpleContrastiveLoss: 8 | def __init__(self, n_target: int = 1): 9 | self.target_per_qry = n_target 10 | 11 | def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'): 12 | if target is None: 13 | assert x.size(0) * self.target_per_qry == y.size(0) 14 | target = torch.arange( 15 | 0, x.size(0) * self.target_per_qry, self.target_per_qry, device=x.device, dtype=torch.long) 16 | logits = torch.matmul(x, y.transpose(0, 1)) 17 | return F.cross_entropy(logits, target, reduction=reduction) 18 | 19 | 20 | class DistributedContrastiveLoss(SimpleContrastiveLoss): 21 | def __init__(self, n_target: int = 1, scale_loss: bool = True): 22 | assert dist.is_initialized(), "Distributed training has not been properly initialized." 23 | super().__init__(n_target=n_target) 24 | self.world_size = dist.get_world_size() 25 | self.rank = dist.get_rank() 26 | self.scale_loss = scale_loss 27 | 28 | def __call__(self, x: Tensor, y: Tensor, **kwargs): 29 | dist_x = self.gather_tensor(x) 30 | dist_y = self.gather_tensor(y) 31 | loss = super().__call__(dist_x, dist_y, **kwargs) 32 | if self.scale_loss: 33 | loss = loss * self.world_size 34 | return loss 35 | 36 | def gather_tensor(self, t): 37 | gathered = [torch.empty_like(t) for _ in range(self.world_size)] 38 | dist.all_gather(gathered, t) 39 | gathered[self.rank] = t 40 | return torch.cat(gathered, dim=0) -------------------------------------------------------------------------------- /src/dense/faiss_retriever/reducer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | import faiss 4 | from argparse import ArgumentParser 5 | from tqdm import tqdm 6 | from typing import List, Iterable, Tuple 7 | from numpy import ndarray 8 | 9 | 10 | def combine_faiss_results(results: Iterable[Tuple[ndarray, ndarray]]): 11 | rh = None 12 | for scores, indices in results: 13 | if rh is None: 14 | print(f'Initializing Heap. 
Assuming {scores.shape[0]} queries.') 15 | rh = faiss.ResultHeap(scores.shape[0], scores.shape[1]) 16 | rh.add_result(-scores, indices) 17 | rh.finalize() 18 | corpus_scores, corpus_indices = -rh.D, rh.I 19 | 20 | return corpus_scores, corpus_indices 21 | 22 | 23 | def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file): 24 | with open(ranking_save_file, 'w') as f: 25 | for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices): 26 | score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)] 27 | score_list = sorted(score_list, key=lambda x: x[0], reverse=True) 28 | for s, idx in score_list: 29 | f.write(f'{qid}\t{idx}\t{s}\n') 30 | 31 | 32 | def main(): 33 | parser = ArgumentParser() 34 | parser.add_argument('--score_dir', required=True) 35 | parser.add_argument('--query', required=True) 36 | parser.add_argument('--save_ranking_to', required=True) 37 | args = parser.parse_args() 38 | 39 | partitions = glob.glob(f'{args.score_dir}/*') 40 | 41 | corpus_scores, corpus_indices = combine_faiss_results(map(torch.load, tqdm(partitions))) 42 | 43 | _, q_lookup = torch.load(args.query) 44 | write_ranking(corpus_indices, corpus_scores, q_lookup, args.save_ranking_to) 45 | 46 | 47 | if __name__ == '__main__': 48 | main() -------------------------------------------------------------------------------- /examples/wikipedia-nq/prepare_wiki_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | from transformers import AutoTokenizer 6 | from tqdm import tqdm 7 | 8 | parser = ArgumentParser() 9 | parser.add_argument('--input', type=str, required=True) 10 | parser.add_argument('--output', type=str, required=True) 11 | parser.add_argument('--tokenizer', type=str, required=False, default='bert-base-uncased') 12 | parser.add_argument('--minimum-negatives', type=int, required=False, default=8) 13 | args = parser.parse_args() 14 | 15 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True) 16 | 17 | data = json.load(open(args.input)) 18 | 19 | if not os.path.exists(args.output): 20 | os.makedirs(args.output) 21 | with open(os.path.join(args.output, 'train_data.json'), 'w') as f: 22 | for idx, item in enumerate(tqdm(data)): 23 | group = {} 24 | query = tokenizer.encode(item['question'], add_special_tokens=False, max_length=256, truncation=True) 25 | group['query'] = query 26 | positives = item['positive_ctxs'] 27 | negatives = item['hard_negative_ctxs'] 28 | group['positives'] = [] 29 | group['negatives'] = [] 30 | for pos in positives: 31 | text = pos['title'] + " " + pos['text'] 32 | text = tokenizer.encode(text, add_special_tokens=False, max_length=256, truncation=True) 33 | group['positives'].append(text) 34 | for neg in negatives: 35 | text = neg['title'] + " " + neg['text'] 36 | text = tokenizer.encode(text, add_special_tokens=False, max_length=256, truncation=True) 37 | group['negatives'].append(text) 38 | if len(group['negatives']) >= args.minimum_negatives and len(group['positives']) >= 1: 39 | f.write(json.dumps(group) + '\n') 40 | -------------------------------------------------------------------------------- /examples/msmarco-passage-ranking/README.md: -------------------------------------------------------------------------------- 1 | # MS-MARCO Passage Ranking 2 | ## Get Data 3 | Run, 4 | ``` 5 | bash get_data.sh 6 | ``` 7 | This downloads the cleaned corpus, generates BM25 negatives, and tokenizes 
train/inference data using the BERT tokenizer. The process could take up to tens of minutes, depending on connection and hardware. 8 | 9 | ## Train a BERT Model 10 | Train a BERT (`bert-base-uncased`) retriever with mixed precision. 11 | ``` 12 | python -m dense.driver.train \ 13 | --output_dir ./retriever_model \ 14 | --model_name_or_path bert-base-uncased \ 15 | --save_steps 20000 \ 16 | --train_dir ./marco/bert/train \ 17 | --fp16 \ 18 | --per_device_train_batch_size 8 \ 19 | --learning_rate 5e-6 \ 20 | --num_train_epochs 2 \ 21 | --dataloader_num_workers 2 22 | ``` 23 | 24 | ## Encode the Corpus and Query 25 | ``` 26 | mkdir encoding 27 | for i in $(seq -f "%02g" 0 9) 28 | do 29 | python -m dense.driver.encode \ 30 | --output_dir ./retriever_model \ 31 | --model_name_or_path ./retriever_model \ 32 | --fp16 \ 33 | --per_device_eval_batch_size 128 \ 34 | --encode_in_path marco/bert/corpus/split${i}.json \ 35 | --encoded_save_path encoding/split${i}.pt 36 | done 37 | 38 | 39 | python -m dense.driver.encode \ 40 | --output_dir ./retriever_model \ 41 | --model_name_or_path ./retriever_model \ 42 | --fp16 \ 43 | --q_max_len 32 \ 44 | --encode_is_qry \ 45 | --per_device_eval_batch_size 128 \ 46 | --encode_in_path marco/bert/query/dev.query.json \ 47 | --encoded_save_path encoding/qry.pt 48 | ``` 49 | 50 | ## Search the Corpus 51 | ``` 52 | mkdir -p ranking/intermediate 53 | 54 | for i in $(seq -f "%02g" 0 9) 55 | do 56 | python -m dense.faiss_retriever \ 57 | --query_reps encoding/qry.pt \ 58 | --passage_reps encoding/split${i}.pt \ 59 | --depth 10 \ 60 | --save_ranking_to ranking/intermediate/split${i} 61 | done 62 | 63 | python -m dense.faiss_retriever.reducer \ 64 | --score_dir ranking/intermediate \ 65 | --query encoding/qry.pt \ 66 | --save_ranking_to ranking/rank.txt 67 | ``` 68 | Finally, format the retrieval result, 69 | ``` 70 | python score_to_marco.py ranking/rank.txt 71 | ``` 72 | -------------------------------------------------------------------------------- /examples/msmarco-passage-ranking/build_train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from transformers import AutoTokenizer 3 | import os 4 | import random 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | from multiprocessing import Pool 8 | from dense.processor import MarcoPassageTrainProcessor as TrainProcessor 9 | 10 | random.seed(datetime.now()) 11 | parser = ArgumentParser() 12 | parser.add_argument('--tokenizer_name', required=True) 13 | parser.add_argument('--negative_file', required=True) 14 | parser.add_argument('--qrels', required=True) 15 | parser.add_argument('--queries', required=True) 16 | parser.add_argument('--collection', required=True) 17 | parser.add_argument('--save_to', required=True) 18 | 19 | parser.add_argument('--truncate', type=int, default=128) 20 | parser.add_argument('--n_sample', type=int, default=30) 21 | parser.add_argument('--mp_chunk_size', type=int, default=500) 22 | parser.add_argument('--shard_size', type=int, default=45000) 23 | 24 | args = parser.parse_args() 25 | 26 | 27 | qrel = TrainProcessor.read_qrel(args.qrels) 28 | 29 | def read_line(l): 30 | q, nn = l.strip().split('\t') 31 | nn = nn.split(',') 32 | random.shuffle(nn) 33 | return q, qrel[q], nn[:args.n_sample] 34 | 35 | 36 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) 37 | processor = TrainProcessor( 38 | query_file=args.queries, 39 | collection_file=args.collection, 40 | tokenizer=tokenizer, 41 | 
max_length=args.truncate, 42 | ) 43 | 44 | counter = 0 45 | shard_id = 0 46 | f = None 47 | os.makedirs(args.save_to, exist_ok=True) 48 | 49 | with open(args.negative_file) as nf: 50 | pbar = tqdm(map(read_line, nf)) 51 | with Pool() as p: 52 | for x in p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size): 53 | counter += 1 54 | if f is None: 55 | f = open(os.path.join(args.save_to, f'split{shard_id:02d}.json'), 'w') 56 | pbar.set_description(f'split - {shard_id:02d}') 57 | f.write(x + '\n') 58 | 59 | if counter == args.shard_size: 60 | f.close() 61 | f = None 62 | shard_id += 1 63 | counter = 0 64 | 65 | if f is not None: 66 | f.close() -------------------------------------------------------------------------------- /src/dense/faiss_retriever/__main__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import glob 4 | from argparse import ArgumentParser 5 | from itertools import chain 6 | from tqdm import tqdm 7 | 8 | from .retriever import BaseFaissIPRetriever 9 | from .reducer import write_ranking 10 | 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | logging.basicConfig( 14 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 15 | datefmt="%m/%d/%Y %H:%M:%S", 16 | level=logging.INFO, 17 | ) 18 | 19 | 20 | def search_queries(retriever, q_reps, p_lookup, args): 21 | if args.batch_size > 0: 22 | all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size) 23 | else: 24 | all_scores, all_indices = retriever.search(q_reps, args.depth) 25 | 26 | psg_indices = [[int(p_lookup[x]) for x in q_dd] for q_dd in all_indices] 27 | psg_indices = np.array(psg_indices) 28 | return all_scores, psg_indices 29 | 30 | 31 | def main(): 32 | parser = ArgumentParser() 33 | parser.add_argument('--query_reps', required=True) 34 | parser.add_argument('--passage_reps', required=True) 35 | parser.add_argument('--batch_size', type=int, default=128) 36 | parser.add_argument('--depth', type=int, default=1000) 37 | parser.add_argument('--save_ranking_to', required=True) 38 | parser.add_argument('--save_text', action='store_true') 39 | 40 | args = parser.parse_args() 41 | 42 | index_files = glob.glob(args.passage_reps) 43 | logger.info(f'Pattern match found {len(index_files)} files; loading them into index.') 44 | 45 | p_reps_0, p_lookup_0 = torch.load(index_files[0]) 46 | retriever = BaseFaissIPRetriever(p_reps_0.float().numpy()) 47 | 48 | shards = chain([(p_reps_0, p_lookup_0)], map(torch.load, index_files[1:])) 49 | if len(index_files) > 1: 50 | shards = tqdm(shards, desc='Loading shards into index', total=len(index_files)) 51 | look_up = [] 52 | for p_reps, p_lookup in shards: 53 | retriever.add(p_reps.float().numpy()) 54 | look_up += p_lookup 55 | 56 | q_reps, q_lookup = torch.load(args.query_reps) 57 | q_reps = q_reps.float().numpy() 58 | 59 | logger.info('Index Search Start') 60 | all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args) 61 | logger.info('Index Search Finished') 62 | 63 | if args.save_text: 64 | write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to) 65 | else: 66 | torch.save((all_scores, psg_indices), args.save_ranking_to) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() -------------------------------------------------------------------------------- /examples/scifact/README.md: -------------------------------------------------------------------------------- 1 | # SciFact 2 | 3 | We use the SciFact example to show how to 
train a dense retriever the "research" way using `run.py`. 4 | 5 | > Note: Different from the original [SciFact](https://github.com/allenai/scifact) task, which focuses on fact verification, we focus on the retrieval stage, 6 | and consider a document relevant to a claim if it appears in `cited_doc_ids`. 7 | 8 | ## Dataset Preparation 9 | The SciFact dataset is self-contained in our toolkit, based on Huggingface datasets. 10 | The dataset name is `Tevatron/scifact`; see below for details. 11 | 12 | ## Train 13 | ```bash 14 | CUDA_VISIBLE_DEVICES=0 python run.py \ 15 | --output_dir scifact_model_e80_64x2 \ 16 | --model_name_or_path bert-base-uncased \ 17 | --do_train \ 18 | --save_steps 20000 \ 19 | --dataset_name Tevatron/scifact \ 20 | --fp16 \ 21 | --per_device_train_batch_size 64 \ 22 | --train_n_passages 2 \ 23 | --learning_rate 1e-5 \ 24 | --q_max_len 64 \ 25 | --p_max_len 512 \ 26 | --num_train_epochs 80 \ 27 | --grad_cache \ 28 | --gc_p_chunk_size 8 \ 29 | --logging_steps 10 \ 30 | --overwrite_output_dir 31 | ``` 32 | 33 | ## Encode Corpus 34 | ```bash 35 | CUDA_VISIBLE_DEVICES=0 python run.py \ 36 | --do_encode \ 37 | --output_dir=temp_out \ 38 | --model_name_or_path scifact_model_e80_64x2 \ 39 | --fp16 \ 40 | --per_device_eval_batch_size 156 \ 41 | --dataset_name Tevatron/scifact/corpus \ 42 | --p_max_len 512 \ 43 | --encoded_save_path docs_emb/docs.pt 44 | ``` 45 | 46 | ## Encode Query 47 | ```bash 48 | CUDA_VISIBLE_DEVICES=0 python run.py \ 49 | --do_encode \ 50 | --output_dir=temp_out \ 51 | --model_name_or_path scifact_model_e80_64x2 \ 52 | --fp16 \ 53 | --per_device_eval_batch_size 156 \ 54 | --dataset_name Tevatron/scifact/dev \ 55 | --encode_is_qry \ 56 | --q_max_len 64 \ 57 | --encoded_save_path queries_emb/queries.pt 58 | ``` 59 | 60 | ## Search 61 | ```bash 62 | python -m dense.faiss_retriever \ 63 | --query_reps queries_emb/queries.pt \ 64 | --passage_reps docs_emb/docs.pt \ 65 | --depth 20 \ 66 | --batch_size -1 \ 67 | --save_text \ 68 | --save_ranking_to run.scifact.dev.txt 69 | ``` 70 | 71 | ## Evaluate 72 | Download qrels 73 | ```bash 74 | wget https://www.dropbox.com/s/lpq8mfynqzsuyy5/dev_qrels.txt 75 | ``` 76 | 77 | Evaluate 78 | ```bash 79 | python -m dense.utils.format.convert_result_to_trec --input run.scifact.dev.txt --output run.scifact.dev.trec 80 | python -m pyserini.eval.trec_eval -c -mrecip_rank -mndcg_cut.10 dev_qrels.txt run.scifact.dev.trec 81 | ``` 82 | 83 | The following results can be reproduced by following the instructions above: 84 | ``` 85 | recip_rank all 0.7322 86 | ndcg_cut_10 all 0.7473 87 | ``` 88 | For comparison, the BM25 baseline gets `NDCG@10=0.665`. 89 | 90 | ### Condenser 91 | Replacing `bert-base-uncased` with the Condenser checkpoint `Luyu/condenser` (see details [here](https://github.com/luyug/Condenser)), 92 | we are able to get the results below: 93 | 94 | ```bash 95 | recip_rank all 0.7679 96 | ndcg_cut_10 all 0.7841 97 | ``` 98 | 99 | -------------------------------------------------------------------------------- /src/dense/dataset/processor.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedTokenizer 2 | 3 | 4 | class Processor: 5 | def __init__(self, tokenizer: PreTrainedTokenizer): 6 | self.tokenizer = tokenizer 7 | 8 | 9 | class TrainProcessor(Processor): 10 | def __init__(self, tokenizer, query_max_length=32, text_max_length=256): 11 | super().__init__(tokenizer) 12 | self.query_max_length = query_max_length 13 | self.text_max_length = text_max_length 14 | 15 | def __call__(self, 
example): 16 | query = self.tokenizer.encode(example['query'], 17 | add_special_tokens=False, 18 | max_length=self.query_max_length, 19 | truncation=True) 20 | positives = [] 21 | for pos in example['positive_passages']: 22 | text = pos['title'] + " " + pos['text'] if 'title' in pos else pos['text'] 23 | positives.append(self.tokenizer.encode(text, 24 | add_special_tokens=False, 25 | max_length=self.text_max_length, 26 | truncation=True)) 27 | negatives = [] 28 | for neg in example['negative_passages']: 29 | text = neg['title'] + " " + neg['text'] if 'title' in neg else neg['text'] 30 | negatives.append(self.tokenizer.encode(text, 31 | add_special_tokens=False, 32 | max_length=self.text_max_length, 33 | truncation=True)) 34 | return {'query': query, 'positives': positives, 'negatives': negatives} 35 | 36 | 37 | class TestProcessor(Processor): 38 | def __init__(self, tokenizer, query_max_length=32): 39 | super().__init__(tokenizer) 40 | self.query_max_length = query_max_length 41 | 42 | def __call__(self, example): 43 | query_id = example['query_id'] 44 | query = self.tokenizer.encode(example['query'], 45 | add_special_tokens=False, 46 | max_length=self.query_max_length, 47 | truncation=True) 48 | return {'text_id': query_id, 'text': query} 49 | 50 | 51 | class CorpusProcessor(Processor): 52 | def __init__(self, tokenizer, text_max_length=256): 53 | super().__init__(tokenizer) 54 | self.text_max_length = text_max_length 55 | 56 | def __call__(self, example): 57 | docid = example['docid'] 58 | text = example['title'] + " " + example['text'] if 'title' in example else example['text'] 59 | text = self.tokenizer.encode(text, 60 | add_special_tokens=False, 61 | max_length=self.text_max_length, 62 | truncation=True) 63 | return {'text_id': docid, 'text': text} 64 | -------------------------------------------------------------------------------- /examples/msmarco-passage-ranking/build_train_hn.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from transformers import AutoTokenizer 3 | import os 4 | import random 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | from multiprocessing import Pool 8 | from dense.processor import MarcoPassageTrainProcessor as TrainProcessor 9 | 10 | 11 | def load_ranking(rank_file, relevance, n_sample, depth): 12 | with open(rank_file) as rf: 13 | lines = iter(rf) 14 | q_0, p_0, _ = next(lines).strip().split() 15 | 16 | curr_q = q_0 17 | negatives = [] if p_0 in relevance[q_0] else [p_0] 18 | 19 | while True: 20 | try: 21 | q, p, _ = next(lines).strip().split() 22 | if q != curr_q: 23 | negatives = negatives[:depth] 24 | random.shuffle(negatives) 25 | yield curr_q, relevance[curr_q], negatives[:n_sample] 26 | curr_q = q 27 | negatives = [] if p in relevance[q] else [p] 28 | else: 29 | if p not in relevance[q]: 30 | negatives.append(p) 31 | except StopIteration: 32 | negatives = negatives[:depth] 33 | random.shuffle(negatives) 34 | yield curr_q, relevance[curr_q], negatives[:n_sample] 35 | return 36 | 37 | 38 | random.seed(datetime.now()) 39 | parser = ArgumentParser() 40 | parser.add_argument('--tokenizer_name', required=True) 41 | parser.add_argument('--hn_file', required=True) 42 | parser.add_argument('--qrels', required=True) 43 | parser.add_argument('--queries', required=True) 44 | parser.add_argument('--collection', required=True) 45 | parser.add_argument('--save_to', required=True) 46 | 47 | parser.add_argument('--truncate', type=int, default=128) 48 | 
parser.add_argument('--n_sample', type=int, default=30) 49 | parser.add_argument('--depth', type=int, default=200) 50 | parser.add_argument('--mp_chunk_size', type=int, default=500) 51 | parser.add_argument('--shard_size', type=int, default=45000) 52 | 53 | args = parser.parse_args() 54 | 55 | qrel = TrainProcessor.read_qrel(args.qrels) 56 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) 57 | processor = TrainProcessor( 58 | query_file=args.queries, 59 | collection_file=args.collection, 60 | tokenizer=tokenizer, 61 | max_length=args.truncate, 62 | ) 63 | 64 | counter = 0 65 | shard_id = 0 66 | f = None 67 | os.makedirs(args.save_to, exist_ok=True) 68 | 69 | pbar = tqdm(load_ranking(args.hn_file, qrel, args.n_sample, args.depth)) 70 | with Pool() as p: 71 | for x in p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size): 72 | counter += 1 73 | if f is None: 74 | f = open(os.path.join(args.save_to, f'split{shard_id:02d}.hn.json'), 'w') 75 | pbar.set_description(f'split - {shard_id:02d}') 76 | f.write(x + '\n') 77 | 78 | if counter == args.shard_size: 79 | f.close() 80 | f = None 81 | shard_id += 1 82 | counter = 0 83 | 84 | if f is not None: 85 | f.close() -------------------------------------------------------------------------------- /src/dense/processor/processors.py: -------------------------------------------------------------------------------- 1 | import json 2 | import csv 3 | import datasets 4 | from transformers import PreTrainedTokenizer 5 | from dataclasses import dataclass 6 | 7 | 8 | @dataclass 9 | class SimpleTrainProcessor: 10 | query_file: str 11 | collection_file: str 12 | tokenizer: PreTrainedTokenizer 13 | 14 | max_length: int = 128 15 | columns = ['text_id', 'title', 'text'] 16 | title_field = 'title' 17 | text_field = 'text' 18 | 19 | def __post_init__(self): 20 | self.queries = self.read_queries(self.query_file) 21 | self.collection = datasets.load_dataset( 22 | 'csv', 23 | data_files=self.collection_file, 24 | column_names=self.columns, 25 | delimiter='\t', 26 | )['train'] 27 | 28 | @staticmethod 29 | def read_queries(queries): 30 | qmap = {} 31 | with open(queries) as f: 32 | for l in f: 33 | qid, qry = l.strip().split('\t') 34 | qmap[qid] = qry 35 | return qmap 36 | 37 | @staticmethod 38 | def read_qrel(relevance_file): 39 | qrel = {} 40 | with open(relevance_file, encoding='utf8') as f: 41 | tsvreader = csv.reader(f, delimiter="\t") 42 | for [topicid, _, docid, rel] in tsvreader: 43 | assert rel == "1" 44 | if topicid in qrel: 45 | qrel[topicid].append(docid) 46 | else: 47 | qrel[topicid] = [docid] 48 | return qrel 49 | 50 | def get_query(self, q): 51 | query_encoded = self.tokenizer.encode( 52 | self.queries[q], 53 | add_special_tokens=False, 54 | max_length=self.max_length, 55 | truncation=True 56 | ) 57 | return query_encoded 58 | 59 | def get_passage(self, p): 60 | entry = self.collection[int(p)] 61 | title = entry[self.title_field] 62 | title = "" if title is None else title 63 | body = entry[self.text_field] 64 | content = title + self.tokenizer.sep_token + body 65 | 66 | passage_encoded = self.tokenizer.encode( 67 | content, 68 | add_special_tokens=False, 69 | max_length=self.max_length, 70 | truncation=True 71 | ) 72 | 73 | return passage_encoded 74 | 75 | def process_one(self, train): 76 | q, pp, nn = train 77 | train_example = { 78 | 'query': self.get_query(q), 79 | 'positives': [self.get_passage(p) for p in pp], 80 | 'negatives': [self.get_passage(n) for n in nn], 81 | } 82 | 83 | return json.dumps(train_example) 
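# Usage sketch (the ids below are illustrative, not real MS-MARCO entries):
# `process_one` consumes a (qid, [positive pids], [negative pids]) triple and
# emits one JSON line in the training format described in the top-level README:
#
#   from transformers import AutoTokenizer
#   processor = SimpleTrainProcessor(
#       query_file='train.query.txt',
#       collection_file='corpus.tsv',
#       tokenizer=AutoTokenizer.from_pretrained('bert-base-uncased'),
#   )
#   json_line = processor.process_one(('3', ['7187155'], ['7546327', '7187160']))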
84 | 85 | 86 | @dataclass 87 | class SimpleCollectionProcessor: 88 | tokenizer: PreTrainedTokenizer 89 | separator: str = '\t' 90 | max_length: int = 128 91 | 92 | def process_line(self, line: str): 93 | xx = line.strip().split(self.separator) 94 | text_id, text = xx[0], xx[1:] 95 | text_encoded = self.tokenizer.encode( 96 | self.tokenizer.sep_token.join(text), 97 | add_special_tokens=False, 98 | max_length=self.max_length, 99 | truncation=True 100 | ) 101 | encoded = { 102 | 'text_id': text_id, 103 | 'text': text_encoded 104 | } 105 | return json.dumps(encoded) 106 | -------------------------------------------------------------------------------- /src/dense/driver/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | from transformers import AutoConfig, AutoTokenizer 6 | from transformers import ( 7 | HfArgumentParser, 8 | set_seed, 9 | ) 10 | 11 | from dense.arguments import ModelArguments, DataArguments, \ 12 | DenseTrainingArguments as TrainingArguments 13 | from dense.data import TrainDataset, QPCollator 14 | from dense.modeling import DenseModel 15 | from dense.trainer import DenseTrainer as Trainer, GCTrainer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def main(): 21 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 22 | 23 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 24 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 25 | else: 26 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 27 | model_args: ModelArguments 28 | data_args: DataArguments 29 | training_args: TrainingArguments 30 | 31 | if ( 32 | os.path.exists(training_args.output_dir) 33 | and os.listdir(training_args.output_dir) 34 | and training_args.do_train 35 | and not training_args.overwrite_output_dir 36 | ): 37 | raise ValueError( 38 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
39 | ) 40 | 41 | # Setup logging 42 | logging.basicConfig( 43 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 44 | datefmt="%m/%d/%Y %H:%M:%S", 45 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 46 | ) 47 | logger.warning( 48 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 49 | training_args.local_rank, 50 | training_args.device, 51 | training_args.n_gpu, 52 | bool(training_args.local_rank != -1), 53 | training_args.fp16, 54 | ) 55 | logger.info("Training/evaluation parameters %s", training_args) 56 | logger.info("MODEL parameters %s", model_args) 57 | 58 | set_seed(training_args.seed) 59 | 60 | num_labels = 1 61 | config = AutoConfig.from_pretrained( 62 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 63 | num_labels=num_labels, 64 | cache_dir=model_args.cache_dir, 65 | ) 66 | tokenizer = AutoTokenizer.from_pretrained( 67 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 68 | cache_dir=model_args.cache_dir, 69 | use_fast=False, 70 | ) 71 | model = DenseModel.build( 72 | model_args, 73 | data_args, 74 | training_args, 75 | config=config, 76 | cache_dir=model_args.cache_dir, 77 | ) 78 | 79 | train_dataset = TrainDataset( 80 | data_args, data_args.train_path, tokenizer, 81 | ) 82 | 83 | trainer_cls = GCTrainer if training_args.grad_cache else Trainer 84 | trainer = trainer_cls( 85 | model=model, 86 | args=training_args, 87 | train_dataset=train_dataset, 88 | data_collator=QPCollator( 89 | tokenizer, 90 | max_p_len=data_args.p_max_len, 91 | max_q_len=data_args.q_max_len 92 | ), 93 | ) 94 | train_dataset.trainer = trainer 95 | 96 | trainer.train( 97 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 98 | ) 99 | trainer.save_model() 100 | if trainer.is_world_process_zero(): 101 | tokenizer.save_pretrained(training_args.output_dir) 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /src/dense/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List, Union 4 | from transformers import TrainingArguments 5 | 6 | 7 | @dataclass 8 | class ModelArguments: 9 | model_name_or_path: str = field( 10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 11 | ) 12 | target_model_path: str = field( 13 | default=None, 14 | metadata={"help": "Path to pretrained reranker target model"} 15 | ) 16 | config_name: Optional[str] = field( 17 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 18 | ) 19 | tokenizer_name: Optional[str] = field( 20 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 21 | ) 22 | cache_dir: Optional[str] = field( 23 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 24 | ) 25 | 26 | # modeling 27 | untie_encoder: bool = field( 28 | default=False, 29 | metadata={"help": "no weight sharing between qry passage encoders"} 30 | ) 31 | 32 | # out projection 33 | add_pooler: bool = field(default=False) 34 | projection_in_dim: int = field(default=768) 35 | projection_out_dim: int = field(default=768) 36 | 37 | 38 | @dataclass 39 | class DataArguments: 40 | train_dir: str = 
field( 41 | default=None, metadata={"help": "Path to train directory"} 42 | ) 43 | dataset_name: str = field( 44 | default=None, metadata={"help": "huggingface dataset name"} 45 | ) 46 | dataset_proc_num: int = field( 47 | default=12, metadata={"help": "number of proc used in dataset preprocess"} 48 | ) 49 | train_n_passages: int = field(default=8) 50 | 51 | encode_in_path: List[str] = field(default=None, metadata={"help": "Path to data to encode"}) 52 | encoded_save_path: str = field(default=None, metadata={"help": "where to save the encode"}) 53 | encode_is_qry: bool = field(default=False) 54 | encode_num_shard: int = field(default=1) 55 | encode_shard_index: int = field(default=0) 56 | 57 | q_max_len: int = field( 58 | default=32, 59 | metadata={ 60 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer " 61 | "than this will be truncated, sequences shorter will be padded." 62 | }, 63 | ) 64 | p_max_len: int = field( 65 | default=128, 66 | metadata={ 67 | "help": "The maximum total input sequence length after tokenization for passage. Sequences longer " 68 | "than this will be truncated, sequences shorter will be padded." 69 | }, 70 | ) 71 | 72 | def __post_init__(self): 73 | if self.dataset_name is not None: 74 | info = self.dataset_name.split('/') 75 | self.dataset_split = info[-1] if len(info) == 3 else 'train' 76 | self.dataset_name = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info) 77 | if self.train_dir is not None: 78 | files = os.listdir(self.train_dir) 79 | self.train_path = [ 80 | os.path.join(self.train_dir, f) 81 | for f in files 82 | if f.endswith('tsv') or f.endswith('json') 83 | ] 84 | 85 | 86 | @dataclass 87 | class DenseTrainingArguments(TrainingArguments): 88 | warmup_ratio: float = field(default=0.1) 89 | negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"}) 90 | do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"}) 91 | 92 | grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"}) 93 | gc_q_chunk_size: int = field(default=4) 94 | gc_p_chunk_size: int = field(default=32) 95 | -------------------------------------------------------------------------------- /src/dense/driver/encode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from contextlib import nullcontext 5 | from tqdm import tqdm 6 | 7 | import torch 8 | 9 | from torch.utils.data import DataLoader 10 | from transformers import AutoConfig, AutoTokenizer 11 | from transformers import ( 12 | HfArgumentParser, 13 | ) 14 | 15 | from dense.arguments import ModelArguments, DataArguments, \ 16 | DenseTrainingArguments as TrainingArguments 17 | from dense.data import EncodeDataset, EncodeCollator 18 | from dense.modeling import DenseOutput, DenseModelForInference 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def main(): 24 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 25 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 26 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 27 | else: 28 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 29 | model_args: ModelArguments 30 | data_args: DataArguments 31 | training_args: TrainingArguments 32 | 33 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 34 | raise 
NotImplementedError('Multi-GPU encoding is not supported.') 35 | 36 | # Setup logging 37 | logging.basicConfig( 38 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 39 | datefmt="%m/%d/%Y %H:%M:%S", 40 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 41 | ) 42 | 43 | num_labels = 1 44 | config = AutoConfig.from_pretrained( 45 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 46 | num_labels=num_labels, 47 | cache_dir=model_args.cache_dir, 48 | ) 49 | tokenizer = AutoTokenizer.from_pretrained( 50 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 51 | cache_dir=model_args.cache_dir, 52 | use_fast=False, 53 | ) 54 | 55 | model = DenseModelForInference.build( 56 | model_name_or_path=model_args.model_name_or_path, 57 | config=config, 58 | cache_dir=model_args.cache_dir, 59 | ) 60 | 61 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len 62 | 63 | encode_dataset = EncodeDataset(data_args.encode_in_path, tokenizer, max_len=text_max_length) 64 | encode_loader = DataLoader( 65 | encode_dataset, 66 | batch_size=training_args.per_device_eval_batch_size, 67 | collate_fn=EncodeCollator( 68 | tokenizer, 69 | max_length=text_max_length, 70 | padding='max_length' 71 | ), 72 | shuffle=False, 73 | drop_last=False, 74 | num_workers=training_args.dataloader_num_workers, 75 | ) 76 | encoded = [] 77 | lookup_indices = [] 78 | model = model.to(training_args.device) 79 | model.eval() 80 | 81 | for (batch_ids, batch) in tqdm(encode_loader): 82 | lookup_indices.extend(batch_ids) 83 | with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext(): 84 | with torch.no_grad(): 85 | for k, v in batch.items(): 86 | batch[k] = v.to(training_args.device) 87 | if data_args.encode_is_qry: 88 | model_output: DenseOutput = model(query=batch) 89 | encoded.append(model_output.q_reps.cpu()) 90 | else: 91 | model_output: DenseOutput = model(passage=batch) 92 | encoded.append(model_output.p_reps.cpu()) 93 | 94 | encoded = torch.cat(encoded) 95 | torch.save((encoded, lookup_indices), data_args.encoded_save_path) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /src/dense/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import repeat 3 | from typing import Dict, List, Tuple, Optional, Any, Union 4 | 5 | from transformers.trainer import Trainer 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torch.distributed as dist 10 | 11 | from .loss import SimpleContrastiveLoss, DistributedContrastiveLoss 12 | 13 | import logging 14 | logger = logging.getLogger(__name__) 15 | 16 | try: 17 | from grad_cache import GradCache 18 | _grad_cache_available = True 19 | except ModuleNotFoundError: 20 | _grad_cache_available = False 21 | 22 | 23 | class DenseTrainer(Trainer): 24 | def __init__(self, *args, **kwargs): 25 | super(DenseTrainer, self).__init__(*args, **kwargs) 26 | self._dist_loss_scale_factor = dist.get_world_size() if self.args.negatives_x_device else 1 27 | 28 | def _save(self, output_dir: Optional[str] = None): 29 | output_dir = output_dir if output_dir is not None else self.args.output_dir 30 | os.makedirs(output_dir, exist_ok=True) 31 | logger.info("Saving model checkpoint to %s", output_dir) 32 | self.model.save(output_dir) 33 | 34 | def _prepare_inputs( 35 | self, 36 | 
inputs: Tuple[Dict[str, Union[torch.Tensor, Any]], ...] 37 | ) -> List[Dict[str, Union[torch.Tensor, Any]]]: 38 | prepared = [] 39 | for x in inputs: 40 | if isinstance(x, torch.Tensor): 41 | prepared.append(x.to(self.args.device)) 42 | else: 43 | prepared.append(super()._prepare_inputs(x)) 44 | return prepared 45 | 46 | def get_train_dataloader(self) -> DataLoader: 47 | if self.train_dataset is None: 48 | raise ValueError("Trainer: training requires a train_dataset.") 49 | train_sampler = self._get_train_sampler() 50 | 51 | return DataLoader( 52 | self.train_dataset, 53 | batch_size=self.args.train_batch_size, 54 | sampler=train_sampler, 55 | collate_fn=self.data_collator, 56 | drop_last=True, 57 | num_workers=self.args.dataloader_num_workers, 58 | ) 59 | 60 | def compute_loss(self, model, inputs): 61 | query, passage = inputs 62 | return model(query=query, passage=passage).loss 63 | 64 | def training_step(self, *args): 65 | return super(DenseTrainer, self).training_step(*args) / self._dist_loss_scale_factor 66 | 67 | 68 | def split_dense_inputs(model_input: dict, chunk_size: int): 69 | assert len(model_input) == 1 70 | arg_key = list(model_input.keys())[0] 71 | arg_val = model_input[arg_key] 72 | 73 | keys = list(arg_val.keys()) 74 | chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys] 75 | chunked_arg_val = [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))] 76 | 77 | return [{arg_key: c} for c in chunked_arg_val] 78 | 79 | 80 | def get_dense_rep(x): 81 | if x.q_reps is None: 82 | return x.p_reps 83 | else: 84 | return x.q_reps 85 | 86 | 87 | class GCTrainer(DenseTrainer): 88 | def __init__(self, *args, **kwargs): 89 | logger.info('Initializing Gradient Cache Trainer') 90 | if not _grad_cache_available: 91 | raise ValueError( 92 | 'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.') 93 | super(GCTrainer, self).__init__(*args, **kwargs) 94 | 95 | loss_fn_cls = DistributedContrastiveLoss if self.args.negatives_x_device else SimpleContrastiveLoss 96 | loss_fn = loss_fn_cls(self.model.data_args.train_n_passages) 97 | 98 | self.gc = GradCache( 99 | models=[self.model, self.model], 100 | chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size], 101 | loss_fn=loss_fn, 102 | split_input_fn=split_dense_inputs, 103 | get_rep_fn=get_dense_rep, 104 | fp16=self.args.fp16, 105 | scaler=self.scaler 106 | ) 107 | 108 | def training_step(self, model, inputs) -> torch.Tensor: 109 | model.train() 110 | queries, passages = self._prepare_inputs(inputs) 111 | queries, passages = {'query': queries}, {'passage': passages} 112 | 113 | _distributed = self.args.local_rank > -1 114 | self.gc.models = [model, model] 115 | loss = self.gc(queries, passages, no_sync_except_last=_distributed) 116 | 117 | return loss / self._dist_loss_scale_factor 118 | -------------------------------------------------------------------------------- /examples/wikipedia-nq/README.md: -------------------------------------------------------------------------------- 1 | # Wikipedia Natural Questions & TriviaQA 2 | 3 | ## NQ 4 | We will use NQ as an example to show the process. 5 | 6 | ### 1. Prepare train data 7 | We use the train data provided by [DPR repo](https://github.com/facebookresearch/DPR). 8 | 1. Download train data 9 | ```bash 10 | $ wget https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz 11 | $ gzip -d biencoder-nq-train.json.gz 12 | ``` 13 | 2. 
Convert the train data format and tokenize 14 | ```bash 15 | $ python prepare_wiki_train.py --input biencoder-nq-train.json \ 16 | --output nq-train \ 17 | --tokenizer bert-base-uncased 18 | ``` 19 | 20 | ### 2. Train 21 | ```bash 22 | TRAIN_DIR=nq-train 23 | OUTDIR=model-nq 24 | 25 | python -m torch.distributed.launch --nproc_per_node=4 -m dense.driver.train \ 26 | --output_dir $OUTDIR \ 27 | --model_name_or_path bert-base-uncased \ 28 | --do_train \ 29 | --save_steps 20000 \ 30 | --train_dir $TRAIN_DIR \ 31 | --fp16 \ 32 | --per_device_train_batch_size 32 \ 33 | --train_n_passages 2 \ 34 | --learning_rate 1e-5 \ 35 | --q_max_len 32 \ 36 | --p_max_len 156 \ 37 | --num_train_epochs 40 \ 38 | --negatives_x_device 39 | ``` 40 | 41 | ### Encode 42 | Download the Wikipedia corpus 43 | ```bash 44 | wget https://www.dropbox.com/s/8ocbt0qpykszgeu/wikipedia-corpus.tar.gz 45 | tar -xvf wikipedia-corpus.tar.gz 46 | ``` 47 | 48 | Encode Corpus 49 | ```bash 50 | ENCODE_DIR=embeddings-nq 51 | OUTDIR=temp 52 | MODEL_DIR=model-nq 53 | CORPUS_DIR=wikipedia-corpus 54 | mkdir $ENCODE_DIR 55 | for s in $(seq -f "%02g" 0 21) 56 | do 57 | python -m dense.driver.encode \ 58 | --output_dir=$OUTDIR \ 59 | --model_name_or_path $MODEL_DIR \ 60 | --fp16 \ 61 | --per_device_eval_batch_size 156 \ 62 | --encode_in_path $CORPUS_DIR/docs$s.json \ 63 | --encoded_save_path $ENCODE_DIR/$s.pt 64 | done 65 | ``` 66 | 67 | Download queries 68 | ```bash 69 | wget https://www.dropbox.com/s/x4abrhszjssq6gl/nq-test-queries.json 70 | wget https://www.dropbox.com/s/b64e07jzlji8zhl/trivia-test-queries.json 71 | ``` 72 | 73 | Encode Query 74 | ```bash 75 | ENCODE_QRY_DIR=embeddings-nq-queries 76 | OUTDIR=temp 77 | MODEL_DIR=model-nq 78 | QUERY=nq-test-queries.json 79 | mkdir $ENCODE_QRY_DIR 80 | python -m dense.driver.encode \ 81 | --output_dir=$OUTDIR \ 82 | --model_name_or_path $MODEL_DIR \ 83 | --fp16 \ 84 | --per_device_eval_batch_size 156 \ 85 | --encode_in_path $QUERY \ 86 | --encoded_save_path $ENCODE_QRY_DIR/query.pt 87 | ``` 88 | 89 | 90 | ### Search 91 | ```bash 92 | ENCODE_QRY_DIR=embeddings-nq-queries 93 | ENCODE_DIR=embeddings-nq 94 | DEPTH=100 95 | RUN=run.nq.test.txt 96 | python -m dense.faiss_retriever \ 97 | --query_reps $ENCODE_QRY_DIR/query.pt \ 98 | --passage_reps $ENCODE_DIR/'*.pt' \ 99 | --depth $DEPTH \ 100 | --batch_size -1 \ 101 | --save_text \ 102 | --save_ranking_to $RUN 103 | ``` 104 | 105 | ### Evaluation 106 | Convert the result to TREC format 107 | ```bash 108 | RUN=run.nq.test.txt 109 | TREC_RUN=run.nq.test.trec 110 | python -m dense.utils.format.convert_result_to_trec --input $RUN --output $TREC_RUN 111 | ``` 112 | 113 | Evaluate with Pyserini for now: `pip install pyserini`. 114 | Recover query and passage contents: 115 | ```bash 116 | TREC_RUN=run.nq.test.trec 117 | JSON_RUN=run.nq.test.json 118 | $ python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run --topics dpr-nq-test \ 119 | --index wikipedia-dpr \ 120 | --input $TREC_RUN \ 121 | --output $JSON_RUN 122 | ``` 123 | ```bash 124 | $ python -m pyserini.eval.evaluate_dpr_retrieval --retrieval $JSON_RUN --topk 20 100 125 | Top20 accuracy: 0.8002770083102493 126 | Top100 accuracy: 0.871191135734072 127 | ``` 128 | 129 | ## TriviaQA 130 | To train a dense retriever for TriviaQA, simply replace all occurrences of `nq` in the above commands with `trivia`. 131 | The retrieval accuracy we get using our repo is: 132 | ```bash 133 | Top20 accuracy: 0.8112790594890834 134 | Top100 accuracy: 0.8629010872447627 135 | ``` 136 | 137 | ## Un-tie model 138 | In an un-tied model, the query 
140 | 
141 | ## Summary
142 | Following the process above, you should obtain top-k retrieval accuracies close to the numbers below:
143 | 
144 | | Dataset/Model  | Top20 | Top100 |
145 | |----------------|-------|--------|
146 | | NQ             | 0.81  | 0.86   |
147 | | NQ-untie       | 0.80  | 0.87   |
148 | | TriviaQA       | 0.81  | 0.86   |
149 | | TriviaQA-untie | 0.81  | 0.86   |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dense
2 | Dense is a simple and efficient toolkit for training and running dense retrievers with deep language models. The toolkit has a modularized design for easy research; a set of command line tools is also provided for fast development and testing. A set of easy-to-use interfaces to Huggingface🤗's state-of-the-art pre-trained transformers ensures Dense's superior performance.
3 | 
4 | *Dense is currently in its initial development stage. We will be actively adding new features, and API changes may happen.*
5 | 
6 | ## Features
7 | - Command line interface for dense retriever training/encoding and dense index search.
8 | - Flexible and extendable Pytorch retriever models.
9 | - Highly efficient Trainer, a subclass of the Huggingface Trainer, that natively supports training performance features like mixed precision and distributed data parallel.
10 | - Fast and memory-efficient train/inference data access based on memory mapping with Apache Arrow through Huggingface datasets.
11 | 
12 | ## Installation
13 | First install the dependencies. The current code base has been tested with,
14 | ```
15 | pytorch==1.8.0
16 | faiss-cpu==1.6.5
17 | transformers==4.2.0
18 | datasets==1.1.3
19 | ```
20 | Then clone this repo and run pip.
21 | ```
22 | git clone https://github.com/luyug/Dense
23 | cd Dense
24 | pip install .
25 | ```
26 | Or, typically for research, install as editable,
27 | ```
28 | pip install --editable .
29 | ```
30 | 
31 | ## Data Format
32 | Training: each line of the train file is a training instance,
33 | ```
34 | {'query': TEXT_TYPE, 'positives': List[TEXT_TYPE], 'negatives': List[TEXT_TYPE]}
35 | ...
36 | ```
37 | Inference/Encoding: each line of the encoding file is a piece of text to be encoded,
38 | ```
39 | {'text_id': "xxx", 'text': TEXT_TYPE}
40 | ...
41 | ```
42 | Here `TEXT_TYPE` can be either a raw string or pre-tokenized ids, i.e. `List[int]`. Using the latter lowers data processing latency during training, which reduces or eliminates GPU wait. **Note**: the current code requires the text_id of passages/contexts to be convertible to integer, e.g. integers or strings of integers.
43 | 
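For illustration, a raw-string training instance and a corpus entry could look like the following (the texts and the id are made-up examples; the files are parsed as JSON lines, so keys and string values use double quotes):
```
{"query": "what is dense retrieval", "positives": ["Dense retrieval matches queries and passages by the similarity of their learned vector representations."], "negatives": ["An inverted index is a classic sparse retrieval data structure."]}
```
```
{"text_id": "42", "text": "Dense retrieval matches queries and passages by the similarity of their learned vector representations."}
```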
44 | ## Training (Simple)
45 | To train a simple dense retriever, call the `dense.driver.train` module,
46 | ```
47 | python -m dense.driver.train \
48 |   --output_dir $OUTDIR \
49 |   --model_name_or_path bert-base-uncased \
50 |   --do_train \
51 |   --save_steps 20000 \
52 |   --train_dir $TRAIN_DIR \
53 |   --fp16 \
54 |   --per_device_train_batch_size 8 \
55 |   --learning_rate 5e-6 \
56 |   --num_train_epochs 2 \
57 |   --dataloader_num_workers 2
58 | ```
59 | Here we pick the `bert-base-uncased` BERT weights from the Huggingface Hub and turn on AMP with `--fp16` to speed up training. Several additional command flags configure the learned model, e.g. `--add_pooler`, which adds a linear projection. A full list of command line arguments can be found in `dense.arguments`.
60 | 
61 | ## Training (Research)
62 | Check out the [run.py](examples/run.py) in the examples directory for a fully configurable train/test loop. Typically you will do,
63 | ```
64 | from dense.modeling import DenseModel
65 | from dense.trainer import DenseTrainer as Trainer
66 | 
67 | ...
68 | model = DenseModel.build(
69 |     model_args,
70 |     data_args,
71 |     training_args,
72 |     config=config,
73 |     cache_dir=model_args.cache_dir,
74 | )
75 | trainer = Trainer(
76 |     model=model,
77 |     args=training_args,
78 |     train_dataset=train_dataset,
79 |     data_collator=collator,
80 | )
81 | ...
82 | trainer.train()
83 | ```
84 | 
85 | 
86 | ## Encoding
87 | To encode, call the `dense.driver.encode` module. For a large corpus, split it into shards and encode them in parallel; one way to produce the shards is sketched after the loop below.
88 | ```
89 | for s in shard1 shard2 shard3
90 | do
91 | python -m dense.driver.encode \
92 |   --output_dir=$OUTDIR \
93 |   --tokenizer_name $TOK \
94 |   --config_name $CONFIG \
95 |   --model_name_or_path $MODEL_DIR \
96 |   --fp16 \
97 |   --per_device_eval_batch_size 128 \
98 |   --encode_in_path $CORPUS_DIR/$s.json \
99 |   --encoded_save_path $ENCODE_DIR/$s.pt
100 | done
101 | ```
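The sharding itself can be done however you like. As a minimal sketch (assuming a single JSON-lines file `corpus.json`, a hypothetical name), a round-robin split with awk that produces shard names matching the loop above:
```bash
# Round-robin split of a JSONL corpus into 3 shards:
# writes $CORPUS_DIR/shard1.json ... $CORPUS_DIR/shard3.json
mkdir -p $CORPUS_DIR
awk -v dir="$CORPUS_DIR" '{ print > (dir "/shard" (NR % 3 + 1) ".json") }' corpus.json
```
Each shard can then be encoded by an independent process or machine.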
102 | ## Index Search
103 | Call the `dense.faiss_retriever` module,
104 | ```
105 | python -m dense.faiss_retriever \
106 |   --query_reps $ENCODE_QRY_DIR/qry.pt \
107 |   --passage_reps $ENCODE_DIR/'*.pt' \
108 |   --depth $DEPTH \
109 |   --batch_size -1 \
110 |   --save_text \
111 |   --save_ranking_to rank.tsv
112 | ```
113 | The encoded corpus or corpus shards are loaded based on glob pattern matching of the argument `--passage_reps`. The argument `--batch_size` controls the number of queries passed to the FAISS index in each search call; `-1` passes all queries in a single call. Larger batches typically run faster (due to better memory access patterns and hardware utilization). Setting the flag `--save_text` saves the ranking to a tsv file with each line being `qid pid score`.
114 | 
115 | Alternatively, parallelize search over the shards,
116 | ```
117 | for s in shard1 shard2 shard3
118 | do
119 | python -m dense.faiss_retriever \
120 |   --query_reps $ENCODE_QRY_DIR/qry.pt \
121 |   --passage_reps $ENCODE_DIR/$s.pt \
122 |   --depth $DEPTH \
123 |   --save_ranking_to $INTERMEDIATE_DIR/$s
124 | done
125 | ```
126 | Then combine the results using the reducer module,
127 | ```
128 | python -m dense.faiss_retriever.reducer \
129 |   --score_dir $INTERMEDIATE_DIR \
130 |   --query $ENCODE_QRY_DIR/qry.pt \
131 |   --save_ranking_to rank.txt
132 | ```
133 | 
--------------------------------------------------------------------------------
/examples/coCondenser-marco/README.md:
--------------------------------------------------------------------------------
1 | # coCondenser MS-MARCO Passage Retrieval
2 | ## coCondenser
3 | You can find details about coCondenser pre-training in its [paper](https://arxiv.org/abs/2108.05540) and [open source code](https://github.com/luyug/Condenser),
4 | ```
5 | @misc{gao2021unsupervised,
6 |       title={Unsupervised Corpus Aware Language Model Pre-training for Dense Passage Retrieval},
7 |       author={Luyu Gao and Jamie Callan},
8 |       year={2021},
9 |       eprint={2108.05540},
10 |       archivePrefix={arXiv},
11 |       primaryClass={cs.IR}
12 | }
13 | ```
14 | ## Get Data
15 | Run,
16 | ```
17 | bash get_data.sh
18 | ```
19 | This downloads the cleaned corpus hosted by the RocketQA team, generates BM25 negatives, and tokenizes train/inference data using the BERT tokenizer.
20 | The process could take up to tens of minutes depending on connection and hardware.
21 | 
22 | ## Inference with Fine-tuned Checkpoint
23 | You can obtain a fine-tuned retriever from the HF hub using the identifier `Luyu/co-condenser-marco-retriever`.
24 | ### Encode
25 | ```
26 | mkdir -p encoding/corpus
27 | mkdir -p encoding/query
28 | for i in $(seq -f "%02g" 0 9)
29 | do
30 | python -m dense.driver.encode \
31 |   --output_dir ./retriever_model \
32 |   --model_name_or_path Luyu/co-condenser-marco-retriever \
33 |   --fp16 \
34 |   --per_device_eval_batch_size 128 \
35 |   --encode_in_path marco/bert/corpus/split${i}.json \
36 |   --encoded_save_path encoding/corpus/split${i}.pt
37 | done
38 | 
39 | 
40 | python -m dense.driver.encode \
41 |   --output_dir ./retriever_model \
42 |   --model_name_or_path Luyu/co-condenser-marco-retriever \
43 |   --fp16 \
44 |   --q_max_len 32 \
45 |   --encode_is_qry \
46 |   --per_device_eval_batch_size 128 \
47 |   --encode_in_path marco/bert/query/dev.query.json \
48 |   --encoded_save_path encoding/query/qry.pt
49 | ```
50 | ### Index Search
51 | ```
52 | python -m dense.faiss_retriever \
53 |   --query_reps encoding/query/qry.pt \
54 |   --passage_reps encoding/corpus/'*.pt' \
55 |   --depth 10 \
56 |   --batch_size -1 \
57 |   --save_text \
58 |   --save_ranking_to rank.tsv
59 | ```
60 | And format the retrieval result,
61 | ```
62 | python ../msmarco-passage-ranking/score_to_marco.py rank.tsv
63 | ```
64 | ## Fine-tuning Stage 1
65 | Pick the pre-trained Condenser most suitable for the experiment from the [Condenser Repo](https://github.com/luyug/Condenser#pre-trained-models).
66 | Train,
67 | ```
68 | python -m dense.driver.train --do_train \
69 |   --output_dir ./retriever_model_s1 \
70 |   --model_name_or_path CONDENSER_MODEL_NAME \
71 |   --save_steps 20000 \
72 |   --train_dir ./marco/bert/train \
73 |   --fp16 \
74 |   --per_device_train_batch_size 8 \
75 |   --learning_rate 5e-6 \
76 |   --num_train_epochs 3 \
77 |   --dataloader_num_workers 2
78 | ```
79 | ## Mining Hard Negatives
80 | ### Encode
81 | Encode the corpus and train queries,
82 | ```
83 | mkdir -p encoding/corpus
84 | mkdir -p encoding/query
85 | for i in $(seq -f "%02g" 0 9)
86 | do
87 | python -m dense.driver.encode \
88 |   --output_dir ./retriever_model \
89 |   --model_name_or_path ./retriever_model_s1 \
90 |   --fp16 \
91 |   --per_device_eval_batch_size 128 \
92 |   --encode_in_path marco/bert/corpus/split${i}.json \
93 |   --encoded_save_path encoding/corpus/split${i}.pt
94 | done
95 | 
96 | python -m dense.driver.encode \
97 |   --output_dir ./retriever_model \
98 |   --model_name_or_path ./retriever_model_s1 \
99 |   --fp16 \
100 |   --q_max_len 32 \
101 |   --encode_is_qry \
102 |   --per_device_eval_batch_size 128 \
103 |   --encode_in_path marco/bert/query/train.query.json \
104 |   --encoded_save_path encoding/query/train.pt
105 | ```
106 | 
107 | ### Search
108 | ```
109 | python -m dense.faiss_retriever \
110 |   --query_reps encoding/query/train.pt \
111 |   --passage_reps encoding/corpus/'*.pt' \
112 |   --batch_size 5000 \
113 |   --save_text \
114 |   --save_ranking_to train.rank.txt
115 | ```
116 | 
117 | ### Build HN Train file
118 | ```
119 | bash create_hn.sh
120 | ```
121 | 
122 | ## Fine-tuning Stage 2
123 | ```
124 | python -m dense.driver.train --do_train \
125 |   --output_dir ./retriever_model_s2 \
126 |   --model_name_or_path CONDENSER_MODEL_NAME \
127 |   --save_steps 20000 \
128 |   --train_dir ./marco/bert/train-hn \
129 |   --fp16 \
130 |   --per_device_train_batch_size 8 \
131 |   --learning_rate 5e-6 \
132 |   --num_train_epochs 2 \
133 |   --dataloader_num_workers 2
134 | ```
135 | 
136 | ## Encode and Search
137 | Do encoding,
138 | ```
139 | mkdir -p encoding/corpus-s2
140 | mkdir -p encoding/query-s2
141 | for i in $(seq -f "%02g" 0 9)
142 | do
143 | python -m dense.driver.encode \
144 |   --output_dir ./retriever_model_s2 \
145 |   --model_name_or_path ./retriever_model_s2 \
146 |   --fp16 \
147 |   --per_device_eval_batch_size 128 \
148 |   --encode_in_path marco/bert/corpus/split${i}.json \
149 |   --encoded_save_path encoding/corpus-s2/split${i}.pt
150 | done
151 | 
152 | python -m dense.driver.encode \
153 |   --output_dir ./retriever_model_s2 \
154 |   --model_name_or_path ./retriever_model_s2 \
155 |   --fp16 \
156 |   --q_max_len 32 \
157 |   --encode_is_qry \
158 |   --per_device_eval_batch_size 128 \
159 |   --encode_in_path marco/bert/query/dev.query.json \
160 |   --encoded_save_path encoding/query-s2/qry.pt
161 | ```
162 | Run the retriever,
163 | ```
164 | python -m dense.faiss_retriever \
165 |   --query_reps encoding/query-s2/qry.pt \
166 |   --passage_reps encoding/corpus-s2/'*.pt' \
167 |   --depth 10 \
168 |   --batch_size -1 \
169 |   --save_text \
170 |   --save_ranking_to dev.rank.tsv
171 | ```
172 | And format the retrieval result,
173 | ```
174 | python ../msmarco-passage-ranking/score_to_marco.py dev.rank.tsv
175 | ```
176 | 
--------------------------------------------------------------------------------
/src/dense/data.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass
3 | from typing import Union, List
4 | 
5 | import datasets
6 | from torch.utils.data import Dataset
7 | from transformers import PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding
8 | 
9 | 
10 | from .arguments import DataArguments
11 | from .trainer import DenseTrainer
12 | 
13 | import logging
14 | logger = logging.getLogger(__name__)
15 | 
16 | 
17 | class TrainDataset(Dataset):
18 |     def __init__(
19 |             self,
20 |             data_args: DataArguments,
21 |             path_to_data: Union[List[str], datasets.Dataset],
22 |             tokenizer: PreTrainedTokenizer,
23 |             trainer: DenseTrainer = None,
24 |     ):
25 |         if isinstance(path_to_data, datasets.Dataset):
26 |             self.train_data = path_to_data
27 |         else:
28 |             self.train_data = datasets.load_dataset(
29 |                 'json',
30 |                 data_files=path_to_data,
31 |                 ignore_verifications=False,
32 |             )['train']
33 | 
34 |         self.tok = tokenizer
35 |         self.trainer = trainer
36 | 
37 |         self.data_args = data_args
38 |         self.total_len = len(self.train_data)
39 | 
40 |     def create_one_example(self, text_encoding: List[int], is_query=False):
41 |         item = self.tok.encode_plus(
42 |             text_encoding,
43 |             truncation='only_first',
44 |             max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len,
45 |             padding=False,
46 |             return_attention_mask=False,
47 |             return_token_type_ids=False,
48 |         )
49 |         return item
50 | 
51 |     def __len__(self):
52 |         return self.total_len
53 | 
54 |     def __getitem__(self, item) -> [BatchEncoding, List[BatchEncoding]]:
55 |         group = self.train_data[item]
56 |         epoch = int(self.trainer.state.epoch)
57 | 
58 |         _hashed_seed = hash(item + self.trainer.args.seed)
59 | 
60 |         qry = group['query']
61 |         encoded_query = self.create_one_example(qry, is_query=True)
62 | 
63 |         encoded_passages = []
64 |         group_positives = group['positives']
65 |         group_negatives = group['negatives']
66 | 
67 |         pos_psg = group_positives[(_hashed_seed + epoch) % len(group_positives)]
68 |         encoded_passages.append(self.create_one_example(pos_psg))
69 | 
70 |         negative_size = self.data_args.train_n_passages - 1
71 |         if len(group_negatives) < negative_size:
72 |             negs = random.choices(group_negatives, k=negative_size)
73 |         elif 
self.data_args.train_n_passages == 1: 74 | negs = [] 75 | else: 76 | _offset = epoch * negative_size % len(group_negatives) 77 | negs = [x for x in group_negatives] 78 | random.Random(_hashed_seed).shuffle(negs) 79 | negs = negs * 2 80 | negs = negs[_offset: _offset + negative_size] 81 | 82 | for neg_psg in negs: 83 | encoded_passages.append(self.create_one_example(neg_psg)) 84 | 85 | return encoded_query, encoded_passages 86 | 87 | 88 | class EncodeDataset(Dataset): 89 | input_keys = ['text_id', 'text'] 90 | 91 | def __init__(self, path_to_json: Union[List[str], datasets.Dataset], tokenizer: PreTrainedTokenizer, max_len=128): 92 | if isinstance(path_to_json, datasets.Dataset): 93 | self.encode_data = path_to_json 94 | else: 95 | self.encode_data = datasets.load_dataset( 96 | 'json', 97 | data_files=path_to_json, 98 | )['train'] 99 | self.tok = tokenizer 100 | self.max_len = max_len 101 | 102 | def __len__(self): 103 | return len(self.encode_data) 104 | 105 | def __getitem__(self, item) -> [str, BatchEncoding]: 106 | text_id, text = (self.encode_data[item][f] for f in self.input_keys) 107 | encoded_text = self.tok.encode_plus( 108 | text, 109 | max_length=self.max_len, 110 | truncation='only_first', 111 | padding=False, 112 | return_token_type_ids=False, 113 | ) 114 | return text_id, encoded_text 115 | 116 | 117 | @dataclass 118 | class QPCollator(DataCollatorWithPadding): 119 | """ 120 | Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg] 121 | and pass batch separately to the actual collator. 122 | Abstract out data detail for the model. 123 | """ 124 | max_q_len: int = 32 125 | max_p_len: int = 128 126 | 127 | def __call__(self, features): 128 | qq = [f[0] for f in features] 129 | dd = [f[1] for f in features] 130 | 131 | if isinstance(qq[0], list): 132 | qq = sum(qq, []) 133 | if isinstance(dd[0], list): 134 | dd = sum(dd, []) 135 | 136 | q_collated = self.tokenizer.pad( 137 | qq, 138 | padding='max_length', 139 | max_length=self.max_q_len, 140 | return_tensors="pt", 141 | ) 142 | d_collated = self.tokenizer.pad( 143 | dd, 144 | padding='max_length', 145 | max_length=self.max_p_len, 146 | return_tensors="pt", 147 | ) 148 | 149 | return q_collated, d_collated 150 | 151 | 152 | @dataclass 153 | class EncodeCollator(DataCollatorWithPadding): 154 | def __call__(self, features): 155 | text_ids = [x[0] for x in features] 156 | text_features = [x[1] for x in features] 157 | collated_features = super().__call__(text_features) 158 | return text_ids, collated_features -------------------------------------------------------------------------------- /examples/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from contextlib import nullcontext 5 | from tqdm import tqdm 6 | 7 | import torch 8 | 9 | from torch.utils.data import DataLoader 10 | from transformers import AutoConfig, AutoTokenizer 11 | from transformers import ( 12 | HfArgumentParser, 13 | set_seed, 14 | ) 15 | 16 | from dense.arguments import ModelArguments, DataArguments, \ 17 | DenseTrainingArguments as TrainingArguments 18 | from dense.data import TrainDataset, EncodeDataset, QPCollator, EncodeCollator 19 | from dense.modeling import DenseModel, DenseOutput 20 | from dense.trainer import DenseTrainer as Trainer 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def main(): 26 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 27 | 28 | if len(sys.argv) == 2 and 
sys.argv[1].endswith(".json"): 29 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 30 | else: 31 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 32 | model_args: ModelArguments 33 | data_args: DataArguments 34 | training_args: TrainingArguments 35 | 36 | if ( 37 | os.path.exists(training_args.output_dir) 38 | and os.listdir(training_args.output_dir) 39 | and training_args.do_train 40 | and not training_args.overwrite_output_dir 41 | ): 42 | raise ValueError( 43 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 44 | ) 45 | 46 | # Setup logging 47 | logging.basicConfig( 48 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 49 | datefmt="%m/%d/%Y %H:%M:%S", 50 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 51 | ) 52 | logger.warning( 53 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 54 | training_args.local_rank, 55 | training_args.device, 56 | training_args.n_gpu, 57 | bool(training_args.local_rank != -1), 58 | training_args.fp16, 59 | ) 60 | logger.info("Training/evaluation parameters %s", training_args) 61 | logger.info("MODEL parameters %s", model_args) 62 | 63 | set_seed(training_args.seed) 64 | 65 | num_labels = 1 66 | config = AutoConfig.from_pretrained( 67 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 68 | num_labels=num_labels, 69 | cache_dir=model_args.cache_dir, 70 | ) 71 | tokenizer = AutoTokenizer.from_pretrained( 72 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 73 | cache_dir=model_args.cache_dir, 74 | use_fast=False, 75 | ) 76 | 77 | model = DenseModel.build( 78 | model_args, 79 | data_args, 80 | training_args, 81 | config=config, 82 | cache_dir=model_args.cache_dir, 83 | ) 84 | 85 | if training_args.do_train: 86 | train_dataset = TrainDataset( 87 | data_args, data_args.train_path, tokenizer 88 | ) 89 | else: 90 | train_dataset = None 91 | 92 | trainer = Trainer( 93 | model=model, 94 | args=training_args, 95 | train_dataset=train_dataset, 96 | data_collator=QPCollator( 97 | tokenizer, 98 | max_p_len=data_args.p_max_len, 99 | max_q_len=data_args.q_max_len 100 | ), 101 | ) 102 | 103 | if train_dataset is not None: 104 | train_dataset.trainer = trainer 105 | 106 | # Training 107 | if training_args.do_train: 108 | trainer.train( 109 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 110 | ) 111 | trainer.save_model() 112 | if trainer.is_world_process_zero(): 113 | tokenizer.save_pretrained(training_args.output_dir) 114 | 115 | if training_args.do_encode: 116 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 117 | raise NotImplementedError('Parallel encoding is not supported.') 118 | 119 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len 120 | 121 | encode_dataset = EncodeDataset(data_args.encode_in_path, tokenizer, max_len=text_max_length) 122 | encode_loader = DataLoader( 123 | encode_dataset, 124 | batch_size=training_args.per_device_eval_batch_size, 125 | collate_fn=EncodeCollator( 126 | tokenizer, 127 | max_length=text_max_length, 128 | padding='max_length' 129 | ), 130 | shuffle=False, 131 | drop_last=False, 132 | num_workers=training_args.dataloader_num_workers, 133 | ) 134 | encoded = [] 135 | lookup_indices = [] 136 | model = 
model.to(training_args.device) 137 | model.eval() 138 | 139 | for (batch_ids, batch) in tqdm(encode_loader): 140 | lookup_indices.extend(batch_ids) 141 | with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext(): 142 | with torch.no_grad(): 143 | for k, v in batch.items(): 144 | batch[k] = v.to(training_args.device) 145 | if data_args.encode_is_qry: 146 | model_output: DenseOutput = model(query=batch) 147 | encoded.append(model_output.q_reps.cpu()) 148 | else: 149 | model_output: DenseOutput = model(passage=batch) 150 | encoded.append(model_output.p_reps.cpu()) 151 | 152 | encoded = torch.cat(encoded) 153 | torch.save((encoded, lookup_indices), data_args.encoded_save_path) 154 | 155 | 156 | if __name__ == "__main__": 157 | main() 158 | -------------------------------------------------------------------------------- /examples/wikipedia-nq/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from contextlib import nullcontext 5 | 6 | import datasets 7 | from tqdm import tqdm 8 | 9 | import torch 10 | 11 | from torch.utils.data import DataLoader 12 | from transformers import AutoConfig, AutoTokenizer 13 | from transformers import ( 14 | HfArgumentParser, 15 | set_seed, 16 | ) 17 | 18 | from dense.arguments import ModelArguments, DataArguments, \ 19 | DenseTrainingArguments as TrainingArguments 20 | from dense.data import TrainDataset, EncodeDataset, QPCollator, EncodeCollator 21 | from dense.modeling import DenseModel, DenseOutput 22 | from dense.trainer import DenseTrainer as Trainer 23 | from dense.dataset import PROCESSOR_INFO, TrainProcessor 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def main(): 29 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 30 | 31 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 32 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 33 | else: 34 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 35 | model_args: ModelArguments 36 | data_args: DataArguments 37 | training_args: TrainingArguments 38 | 39 | if ( 40 | os.path.exists(training_args.output_dir) 41 | and os.listdir(training_args.output_dir) 42 | and training_args.do_train 43 | and not training_args.overwrite_output_dir 44 | ): 45 | raise ValueError( 46 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
47 | ) 48 | 49 | # Setup logging 50 | logging.basicConfig( 51 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 52 | datefmt="%m/%d/%Y %H:%M:%S", 53 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 54 | ) 55 | logger.warning( 56 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 57 | training_args.local_rank, 58 | training_args.device, 59 | training_args.n_gpu, 60 | bool(training_args.local_rank != -1), 61 | training_args.fp16, 62 | ) 63 | logger.info("Training/evaluation parameters %s", training_args) 64 | logger.info("MODEL parameters %s", model_args) 65 | 66 | set_seed(training_args.seed) 67 | 68 | num_labels = 1 69 | config = AutoConfig.from_pretrained( 70 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 71 | num_labels=num_labels, 72 | cache_dir=model_args.cache_dir, 73 | ) 74 | tokenizer = AutoTokenizer.from_pretrained( 75 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 76 | cache_dir=model_args.cache_dir, 77 | use_fast=False, 78 | ) 79 | 80 | model = DenseModel.build( 81 | model_args, 82 | data_args, 83 | training_args, 84 | config=config, 85 | cache_dir=model_args.cache_dir, 86 | ) 87 | 88 | if training_args.do_train: 89 | if data_args.train_dir is not None: 90 | train_dataset = TrainDataset( 91 | data_args, data_args.train_path, tokenizer 92 | ) 93 | else: 94 | train_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_split)['train'] 95 | train_dataset = train_dataset.map( 96 | PROCESSOR_INFO[data_args.dataset_name][data_args.dataset_split](tokenizer, 97 | data_args.q_max_len, 98 | data_args.p_max_len), 99 | batched=False, 100 | num_proc=data_args.dataset_proc_num, 101 | remove_columns=train_dataset.column_names, 102 | desc="Running tokenizer on train dataset", 103 | ) 104 | train_dataset = TrainDataset(data_args, train_dataset, tokenizer) 105 | else: 106 | train_dataset = None 107 | 108 | trainer = Trainer( 109 | model=model, 110 | args=training_args, 111 | train_dataset=train_dataset, 112 | data_collator=QPCollator( 113 | tokenizer, 114 | max_p_len=data_args.p_max_len, 115 | max_q_len=data_args.q_max_len 116 | ), 117 | ) 118 | 119 | if train_dataset is not None: 120 | train_dataset.trainer = trainer 121 | 122 | # Training 123 | if training_args.do_train: 124 | trainer.train( 125 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 126 | ) 127 | trainer.save_model() 128 | if trainer.is_world_process_zero(): 129 | tokenizer.save_pretrained(training_args.output_dir) 130 | 131 | if training_args.do_encode: 132 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 133 | raise NotImplementedError('Parallel encoding is not supported.') 134 | 135 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len 136 | if data_args.encode_in_path: 137 | encode_dataset = EncodeDataset(data_args.encode_in_path, tokenizer, max_len=text_max_length) 138 | encode_dataset.encode_data = encode_dataset.encode_data\ 139 | .shard(data_args.encode_num_shard, data_args.encode_shard_index) 140 | else: 141 | encode_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_split)['train']\ 142 | .shard(data_args.encode_num_shard, data_args.encode_shard_index) 143 | encode_dataset = encode_dataset.map( 144 | PROCESSOR_INFO[data_args.dataset_name][data_args.dataset_split](tokenizer, text_max_length), 145 | batched=False, 
146 | num_proc=data_args.dataset_proc_num, 147 | remove_columns=encode_dataset.column_names, 148 | desc="Running tokenization", 149 | ) 150 | encode_dataset = EncodeDataset(encode_dataset, tokenizer, max_len=text_max_length) 151 | encode_loader = DataLoader( 152 | encode_dataset, 153 | batch_size=training_args.per_device_eval_batch_size, 154 | collate_fn=EncodeCollator( 155 | tokenizer, 156 | max_length=text_max_length, 157 | padding='max_length' 158 | ), 159 | shuffle=False, 160 | drop_last=False, 161 | num_workers=training_args.dataloader_num_workers, 162 | ) 163 | encoded = [] 164 | lookup_indices = [] 165 | model = model.to(training_args.device) 166 | model.eval() 167 | 168 | for (batch_ids, batch) in tqdm(encode_loader): 169 | lookup_indices.extend(batch_ids) 170 | with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext(): 171 | with torch.no_grad(): 172 | for k, v in batch.items(): 173 | batch[k] = v.to(training_args.device) 174 | if data_args.encode_is_qry: 175 | model_output: DenseOutput = model(query=batch) 176 | encoded.append(model_output.q_reps.cpu()) 177 | else: 178 | model_output: DenseOutput = model(passage=batch) 179 | encoded.append(model_output.p_reps.cpu()) 180 | 181 | encoded = torch.cat(encoded) 182 | torch.save((encoded, lookup_indices), data_args.encoded_save_path) 183 | 184 | 185 | if __name__ == "__main__": 186 | main() 187 | -------------------------------------------------------------------------------- /examples/scifact/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from contextlib import nullcontext 5 | 6 | import datasets 7 | from tqdm import tqdm 8 | 9 | import torch 10 | 11 | from torch.utils.data import DataLoader 12 | from transformers import AutoConfig, AutoTokenizer 13 | from transformers import ( 14 | HfArgumentParser, 15 | set_seed, 16 | ) 17 | 18 | from dense.arguments import ModelArguments, DataArguments, \ 19 | DenseTrainingArguments as TrainingArguments 20 | from dense.data import TrainDataset, EncodeDataset, QPCollator, EncodeCollator 21 | from dense.modeling import DenseModel, DenseOutput 22 | from dense.trainer import DenseTrainer as Trainer, GCTrainer 23 | from dense.dataset import PROCESSOR_INFO, TrainProcessor 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def main(): 29 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 30 | 31 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 32 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 33 | else: 34 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 35 | model_args: ModelArguments 36 | data_args: DataArguments 37 | training_args: TrainingArguments 38 | 39 | if ( 40 | os.path.exists(training_args.output_dir) 41 | and os.listdir(training_args.output_dir) 42 | and training_args.do_train 43 | and not training_args.overwrite_output_dir 44 | ): 45 | raise ValueError( 46 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
47 | ) 48 | 49 | # Setup logging 50 | logging.basicConfig( 51 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 52 | datefmt="%m/%d/%Y %H:%M:%S", 53 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 54 | ) 55 | logger.warning( 56 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 57 | training_args.local_rank, 58 | training_args.device, 59 | training_args.n_gpu, 60 | bool(training_args.local_rank != -1), 61 | training_args.fp16, 62 | ) 63 | logger.info("Training/evaluation parameters %s", training_args) 64 | logger.info("MODEL parameters %s", model_args) 65 | 66 | set_seed(training_args.seed) 67 | 68 | num_labels = 1 69 | config = AutoConfig.from_pretrained( 70 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 71 | num_labels=num_labels, 72 | cache_dir=model_args.cache_dir, 73 | ) 74 | tokenizer = AutoTokenizer.from_pretrained( 75 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 76 | cache_dir=model_args.cache_dir, 77 | use_fast=False, 78 | ) 79 | 80 | model = DenseModel.build( 81 | model_args, 82 | data_args, 83 | training_args, 84 | config=config, 85 | cache_dir=model_args.cache_dir, 86 | ) 87 | 88 | if training_args.do_train: 89 | if data_args.train_dir is not None: 90 | train_dataset = TrainDataset( 91 | data_args, data_args.train_path, tokenizer 92 | ) 93 | else: 94 | train_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_split)['train'] 95 | train_dataset = train_dataset.map( 96 | PROCESSOR_INFO[data_args.dataset_name][data_args.dataset_split](tokenizer, 97 | data_args.q_max_len, 98 | data_args.p_max_len), 99 | batched=False, 100 | num_proc=data_args.dataset_proc_num, 101 | remove_columns=train_dataset.column_names, 102 | desc="Running tokenizer on train dataset", 103 | ) 104 | train_dataset = TrainDataset(data_args, train_dataset, tokenizer) 105 | else: 106 | train_dataset = None 107 | 108 | trainer_cls = GCTrainer if training_args.grad_cache else Trainer 109 | trainer = trainer_cls( 110 | model=model, 111 | args=training_args, 112 | train_dataset=train_dataset, 113 | data_collator=QPCollator( 114 | tokenizer, 115 | max_p_len=data_args.p_max_len, 116 | max_q_len=data_args.q_max_len 117 | ), 118 | ) 119 | 120 | if train_dataset is not None: 121 | train_dataset.trainer = trainer 122 | 123 | # Training 124 | if training_args.do_train: 125 | trainer.train( 126 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 127 | ) 128 | trainer.save_model() 129 | if trainer.is_world_process_zero(): 130 | tokenizer.save_pretrained(training_args.output_dir) 131 | 132 | if training_args.do_encode: 133 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 134 | raise NotImplementedError('Parallel encoding is not supported.') 135 | 136 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len 137 | if data_args.encode_in_path: 138 | encode_dataset = EncodeDataset(data_args.encode_in_path, tokenizer, max_len=text_max_length) 139 | encode_dataset.encode_data = encode_dataset.encode_data\ 140 | .shard(data_args.encode_num_shard, data_args.encode_shard_index) 141 | else: 142 | encode_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_split)['train']\ 143 | .shard(data_args.encode_num_shard, data_args.encode_shard_index) 144 | encode_dataset = encode_dataset.map( 145 | 
PROCESSOR_INFO[data_args.dataset_name][data_args.dataset_split](tokenizer, text_max_length), 146 | batched=False, 147 | num_proc=data_args.dataset_proc_num, 148 | remove_columns=encode_dataset.column_names, 149 | desc="Running tokenization", 150 | ) 151 | encode_dataset = EncodeDataset(encode_dataset, tokenizer, max_len=text_max_length) 152 | encode_loader = DataLoader( 153 | encode_dataset, 154 | batch_size=training_args.per_device_eval_batch_size, 155 | collate_fn=EncodeCollator( 156 | tokenizer, 157 | max_length=text_max_length, 158 | padding='max_length' 159 | ), 160 | shuffle=False, 161 | drop_last=False, 162 | num_workers=training_args.dataloader_num_workers, 163 | ) 164 | encoded = [] 165 | lookup_indices = [] 166 | model = model.to(training_args.device) 167 | model.eval() 168 | 169 | for (batch_ids, batch) in tqdm(encode_loader): 170 | lookup_indices.extend(batch_ids) 171 | with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext(): 172 | with torch.no_grad(): 173 | for k, v in batch.items(): 174 | batch[k] = v.to(training_args.device) 175 | if data_args.encode_is_qry: 176 | model_output: DenseOutput = model(query=batch) 177 | encoded.append(model_output.q_reps.cpu()) 178 | else: 179 | model_output: DenseOutput = model(passage=batch) 180 | encoded.append(model_output.p_reps.cpu()) 181 | 182 | encoded = torch.cat(encoded) 183 | torch.save((encoded, lookup_indices), data_args.encoded_save_path) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/dense/modeling.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import copy 4 | from dataclasses import dataclass 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | import torch.distributed as dist 10 | 11 | from transformers import AutoModel, PreTrainedModel 12 | from transformers.modeling_outputs import ModelOutput 13 | 14 | 15 | from typing import Optional, Dict 16 | 17 | from .arguments import ModelArguments, DataArguments, \ 18 | DenseTrainingArguments as TrainingArguments 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | @dataclass 25 | class DenseOutput(ModelOutput): 26 | q_reps: Tensor = None 27 | p_reps: Tensor = None 28 | loss: Tensor = None 29 | scores: Tensor = None 30 | 31 | 32 | class LinearPooler(nn.Module): 33 | def __init__( 34 | self, 35 | input_dim: int = 768, 36 | output_dim: int = 768, 37 | tied=True 38 | ): 39 | super(LinearPooler, self).__init__() 40 | self.linear_q = nn.Linear(input_dim, output_dim) 41 | if tied: 42 | self.linear_p = self.linear_q 43 | else: 44 | self.linear_p = nn.Linear(input_dim, output_dim) 45 | 46 | self._config = {'input_dim': input_dim, 'output_dim': output_dim, 'tied': tied} 47 | 48 | def forward(self, q: Tensor = None, p: Tensor = None): 49 | if q is not None: 50 | return self.linear_q(q[:, 0]) 51 | elif p is not None: 52 | return self.linear_p(p[:, 0]) 53 | else: 54 | raise ValueError 55 | 56 | def load(self, ckpt_dir: str): 57 | if ckpt_dir is not None: 58 | _pooler_path = os.path.join(ckpt_dir, 'pooler.pt') 59 | if os.path.exists(_pooler_path): 60 | logger.info(f'Loading Pooler from {ckpt_dir}') 61 | state_dict = torch.load(os.path.join(ckpt_dir, 'pooler.pt'), map_location='cpu') 62 | self.load_state_dict(state_dict) 63 | return 64 | logger.info("Training Pooler from scratch") 65 | return 66 | 67 | def save_pooler(self, save_path): 68 | torch.save(self.state_dict(), os.path.join(save_path, 'pooler.pt')) 69 | with open(os.path.join(save_path, 'pooler_config.json'), 'w') as f: 70 | json.dump(self._config, f) 71 | 72 | 73 | class DenseModel(nn.Module): 74 | def __init__( 75 | self, 76 | lm_q: PreTrainedModel, 77 | lm_p: PreTrainedModel, 78 | pooler: nn.Module = None, 79 | model_args: ModelArguments = None, 80 | data_args: DataArguments = None, 81 | train_args: TrainingArguments = None, 82 | ): 83 | super().__init__() 84 | 85 | self.lm_q = lm_q 86 | self.lm_p = lm_p 87 | self.pooler = pooler 88 | 89 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') 90 | 91 | self.model_args = model_args 92 | self.train_args = train_args 93 | self.data_args = data_args 94 | 95 | if train_args.negatives_x_device: 96 | if not dist.is_initialized(): 97 | raise ValueError('Distributed training has not been initialized for representation all gather.') 98 | self.process_rank = dist.get_rank() 99 | self.world_size = dist.get_world_size() 100 | 101 | def forward( 102 | self, 103 | query: Dict[str, Tensor] = None, 104 | passage: Dict[str, Tensor] = None, 105 | ): 106 | 107 | q_hidden, q_reps = self.encode_query(query) 108 | p_hidden, p_reps = self.encode_passage(passage) 109 | 110 | if q_reps is None or p_reps is None: 111 | return DenseOutput( 112 | q_reps=q_reps, 113 | p_reps=p_reps 114 | ) 115 | 116 | if self.training: 117 | if self.train_args.negatives_x_device: 118 | q_reps = self.dist_gather_tensor(q_reps) 119 | p_reps = 
self.dist_gather_tensor(p_reps) 120 | 121 | effective_bsz = self.train_args.per_device_train_batch_size * self.world_size \ 122 | if self.train_args.negatives_x_device \ 123 | else self.train_args.per_device_train_batch_size 124 | 125 | scores = torch.matmul(q_reps, p_reps.transpose(0, 1)) 126 | scores = scores.view(effective_bsz, -1) 127 | 128 | target = torch.arange( 129 | scores.size(0), 130 | device=scores.device, 131 | dtype=torch.long 132 | ) 133 | target = target * self.data_args.train_n_passages 134 | loss = self.cross_entropy(scores, target) 135 | if self.train_args.negatives_x_device: 136 | loss = loss * self.world_size # counter average weight reduction 137 | return DenseOutput( 138 | loss=loss, 139 | scores=scores, 140 | q_reps=q_reps, 141 | p_reps=p_reps 142 | ) 143 | 144 | else: 145 | loss = None 146 | if query and passage: 147 | scores = (q_reps * p_reps).sum(1) 148 | else: 149 | scores = None 150 | 151 | return DenseOutput( 152 | loss=loss, 153 | scores=scores, 154 | q_reps=q_reps, 155 | p_reps=p_reps 156 | ) 157 | 158 | def encode_passage(self, psg): 159 | if psg is None: 160 | return None, None 161 | 162 | psg_out = self.lm_p(**psg, return_dict=True) 163 | p_hidden = psg_out.last_hidden_state 164 | if self.pooler is not None: 165 | p_reps = self.pooler(p=p_hidden) # D * d 166 | else: 167 | p_reps = p_hidden[:, 0] 168 | return p_hidden, p_reps 169 | 170 | def encode_query(self, qry): 171 | if qry is None: 172 | return None, None 173 | qry_out = self.lm_q(**qry, return_dict=True) 174 | q_hidden = qry_out.last_hidden_state 175 | if self.pooler is not None: 176 | q_reps = self.pooler(q=q_hidden) 177 | else: 178 | q_reps = q_hidden[:, 0] 179 | return q_hidden, q_reps 180 | 181 | @staticmethod 182 | def build_pooler(model_args): 183 | pooler = LinearPooler( 184 | model_args.projection_in_dim, 185 | model_args.projection_out_dim, 186 | tied=not model_args.untie_encoder 187 | ) 188 | pooler.load(model_args.model_name_or_path) 189 | return pooler 190 | 191 | @classmethod 192 | def build( 193 | cls, 194 | model_args: ModelArguments, 195 | data_args: DataArguments, 196 | train_args: TrainingArguments, 197 | **hf_kwargs, 198 | ): 199 | # load local 200 | if os.path.isdir(model_args.model_name_or_path): 201 | if model_args.untie_encoder: 202 | _qry_model_path = os.path.join(model_args.model_name_or_path, 'query_model') 203 | _psg_model_path = os.path.join(model_args.model_name_or_path, 'passage_model') 204 | if not os.path.exists(_qry_model_path): 205 | _qry_model_path = model_args.model_name_or_path 206 | _psg_model_path = model_args.model_name_or_path 207 | logger.info(f'loading query model weight from {_qry_model_path}') 208 | lm_q = AutoModel.from_pretrained( 209 | _qry_model_path, 210 | **hf_kwargs 211 | ) 212 | logger.info(f'loading passage model weight from {_psg_model_path}') 213 | lm_p = AutoModel.from_pretrained( 214 | _psg_model_path, 215 | **hf_kwargs 216 | ) 217 | else: 218 | lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs) 219 | lm_p = lm_q 220 | # load pre-trained 221 | else: 222 | lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs) 223 | lm_p = copy.deepcopy(lm_q) if model_args.untie_encoder else lm_q 224 | 225 | if model_args.add_pooler: 226 | pooler = cls.build_pooler(model_args) 227 | else: 228 | pooler = None 229 | 230 | model = cls( 231 | lm_q=lm_q, 232 | lm_p=lm_p, 233 | pooler=pooler, 234 | model_args=model_args, 235 | data_args=data_args, 236 | train_args=train_args 237 | ) 238 | return model 239 | 240 | def 
save(self, output_dir: str): 241 | if self.model_args.untie_encoder: 242 | os.makedirs(os.path.join(output_dir, 'query_model')) 243 | os.makedirs(os.path.join(output_dir, 'passage_model')) 244 | self.lm_q.save_pretrained(os.path.join(output_dir, 'query_model')) 245 | self.lm_p.save_pretrained(os.path.join(output_dir, 'passage_model')) 246 | else: 247 | self.lm_q.save_pretrained(output_dir) 248 | 249 | if self.model_args.add_pooler: 250 | self.pooler.save_pooler(output_dir) 251 | 252 | def dist_gather_tensor(self, t: Optional[torch.Tensor]): 253 | if t is None: 254 | return None 255 | t = t.contiguous() 256 | 257 | all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] 258 | dist.all_gather(all_tensors, t) 259 | 260 | all_tensors[self.process_rank] = t 261 | all_tensors = torch.cat(all_tensors, dim=0) 262 | 263 | return all_tensors 264 | 265 | 266 | class DenseModelForInference(DenseModel): 267 | POOLER_CLS = LinearPooler 268 | 269 | def __init__( 270 | self, 271 | lm_q: PreTrainedModel, 272 | lm_p: PreTrainedModel, 273 | pooler: nn.Module = None, 274 | **kwargs, 275 | ): 276 | nn.Module.__init__(self) 277 | self.lm_q = lm_q 278 | self.lm_p = lm_p 279 | self.pooler = pooler 280 | 281 | @torch.no_grad() 282 | def encode_passage(self, psg): 283 | return super(DenseModelForInference, self).encode_passage(psg) 284 | 285 | @torch.no_grad() 286 | def encode_query(self, qry): 287 | return super(DenseModelForInference, self).encode_query(qry) 288 | 289 | def forward( 290 | self, 291 | query: Dict[str, Tensor] = None, 292 | passage: Dict[str, Tensor] = None, 293 | ): 294 | q_hidden, q_reps = self.encode_query(query) 295 | p_hidden, p_reps = self.encode_passage(passage) 296 | return DenseOutput(q_reps=q_reps, p_reps=p_reps) 297 | 298 | @classmethod 299 | def build( 300 | cls, 301 | model_name_or_path: str = None, 302 | model_args: ModelArguments = None, 303 | data_args: DataArguments = None, 304 | train_args: TrainingArguments = None, 305 | **hf_kwargs, 306 | ): 307 | assert model_name_or_path is not None or model_args is not None 308 | if model_name_or_path is None: 309 | model_name_or_path = model_args.model_name_or_path 310 | 311 | # load local 312 | if os.path.isdir(model_name_or_path): 313 | _qry_model_path = os.path.join(model_name_or_path, 'query_model') 314 | _psg_model_path = os.path.join(model_name_or_path, 'passage_model') 315 | if os.path.exists(_qry_model_path): 316 | logger.info(f'found separate weight for query/passage encoders') 317 | logger.info(f'loading query model weight from {_qry_model_path}') 318 | lm_q = AutoModel.from_pretrained( 319 | _qry_model_path, 320 | **hf_kwargs 321 | ) 322 | logger.info(f'loading passage model weight from {_psg_model_path}') 323 | lm_p = AutoModel.from_pretrained( 324 | _psg_model_path, 325 | **hf_kwargs 326 | ) 327 | else: 328 | logger.info(f'try loading tied weight') 329 | logger.info(f'loading model weight from {model_name_or_path}') 330 | lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs) 331 | lm_p = lm_q 332 | else: 333 | logger.info(f'try loading tied weight') 334 | logger.info(f'loading model weight from {model_name_or_path}') 335 | lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs) 336 | lm_p = lm_q 337 | 338 | pooler_weights = os.path.join(model_name_or_path, 'pooler.pt') 339 | pooler_config = os.path.join(model_name_or_path, 'pooler_config.json') 340 | if os.path.exists(pooler_weights) and os.path.exists(pooler_config): 341 | logger.info(f'found pooler weight and configuration') 342 | with 
open(pooler_config) as f: 343 | pooler_config_dict = json.load(f) 344 | pooler = cls.POOLER_CLS(**pooler_config_dict) 345 | pooler.load(model_name_or_path) 346 | else: 347 | pooler = None 348 | 349 | model = cls( 350 | lm_q=lm_q, 351 | lm_p=lm_p, 352 | pooler=pooler 353 | ) 354 | return model --------------------------------------------------------------------------------