├── github.sh ├── framework.jpeg ├── ancetele ├── faiss_retriever │ ├── __init__.py │ ├── reducer.py │ ├── gpu_utils.py │ ├── retriever.py │ └── do_retrieval.py ├── losses │ ├── __init__.py │ └── contrastive_loss.py ├── __init__.py ├── grad_cache │ ├── cachex │ │ ├── __init__.py │ │ ├── tree_utils.py │ │ ├── training.py │ │ └── functional.py │ ├── __init__.py │ ├── context_managers.py │ ├── loss.py │ └── functional.py ├── trainers │ ├── __init__.py │ └── dense_trainer.py ├── networks │ └── __init__.py ├── dataloaders │ ├── loader_utils.py │ ├── __init__.py │ ├── dataset_utils.py │ ├── dense_dataset.py │ └── hf_dataset.py ├── encode.py ├── train.py └── arguments.py ├── scripts ├── convert_result_to_trec.py ├── score_to_marco.py └── ms_marco_eval.py ├── shells ├── tokenize_wikipedia_corpus.sh ├── tokenize_nq.sh ├── tokenize_triviaqa.sh ├── tokenize_msmarco.sh ├── train_ance-tele_nq.sh ├── epi-3-train-nq.sh ├── train_ance-tele_triviaqa.sh ├── epi-3-train-triviaqa.sh ├── epi-1-train-nq.sh ├── epi-2-train-nq.sh ├── epi-1-train-triviaqa.sh ├── epi-2-train-triviaqa.sh ├── epi-3-train-msmarco.sh ├── train_ance-tele_msmarco.sh ├── epi-2-train-msmarco.sh ├── epi-1-train-msmarco.sh ├── infer_msmarco.sh ├── infer_nq.sh ├── infer_triviaqa.sh ├── epi-1-mine-msmarco.sh ├── epi-2-mine-msmarco.sh ├── epi-3-mine-msmarco.sh ├── epi-1-mine-nq.sh ├── epi-1-mine-triviaqa.sh ├── epi-2-mine-nq.sh ├── epi-3-mine-nq.sh ├── epi-2-mine-triviaqa.sh └── epi-3-mine-triviaqa.sh ├── preprocess ├── tokenize_marco_queries.py ├── tokenize_nq_triviaqa_queries.py ├── tokenize_marco_passages.py ├── tokenize_wikipedia_passages.py ├── tokenize_marco_positives.py ├── combine_marco_negative.py ├── build_train_hn.py ├── combine_nq_triviaqa_negative.py ├── preprocessor.py └── build_train_em_hn.py └── LICENSE /github.sh: -------------------------------------------------------------------------------- 1 | git add . 2 | git commit -m "update" 3 | git push origin master 4 | -------------------------------------------------------------------------------- /framework.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMatch/ANCE-Tele/HEAD/framework.jpeg -------------------------------------------------------------------------------- /ancetele/faiss_retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from .retriever import BaseFaissIPRetriever 2 | -------------------------------------------------------------------------------- /ancetele/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .contrastive_loss import (SimpleContrastiveLoss, DistributedContrastiveLoss) -------------------------------------------------------------------------------- /ancetele/__init__.py: -------------------------------------------------------------------------------- 1 | from . import networks 2 | from . import losses 3 | from . import grad_cache 4 | from . 
import arguments -------------------------------------------------------------------------------- /ancetele/grad_cache/cachex/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import chunk_encode, cache_grad, unchunk_args 2 | from .tree_utils import tree_chunk, tree_unchunk 3 | 4 | -------------------------------------------------------------------------------- /ancetele/grad_cache/__init__.py: -------------------------------------------------------------------------------- 1 | # try: 2 | # from .grad_cache import GradCache 3 | # except ModuleNotFoundError: 4 | # pass 5 | 6 | 7 | from .grad_cache import GradCache -------------------------------------------------------------------------------- /ancetele/grad_cache/cachex/tree_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import jax 4 | 5 | 6 | def tree_chunk(tree: Any, n_chunk: int, axis: int = 0) -> Any: 7 | return jax.tree_map( 8 | lambda v: v.reshape(v.shape[:axis] + (n_chunk, -1) + v.shape[axis + 1:]), 9 | tree 10 | ) 11 | 12 | 13 | def tree_unchunk(tree: Any, axis: int = 0) -> Any: 14 | return jax.tree_map( 15 | lambda x: x.reshape(x.shape[:axis] + (-1,) + x.shape[axis + 2:]), 16 | tree 17 | ) 18 | -------------------------------------------------------------------------------- /scripts/convert_result_to_trec.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from argparse import ArgumentParser 3 | 4 | parser = ArgumentParser() 5 | parser.add_argument('--input') 6 | args = parser.parse_args() 7 | 8 | output_file = args.input + ".teIn" 9 | with open(args.input) as f_in, open(output_file, 'w') as f_out: 10 | cur_qid = None 11 | rank = 0 12 | for line in tqdm(f_in): 13 | qid, docid, score = line.split() 14 | if cur_qid != qid: 15 | cur_qid = qid 16 | rank = 0 17 | rank += 1 18 | f_out.write(f'{qid} Q0 {docid} {rank} {score} dense\n') -------------------------------------------------------------------------------- /shells/tokenize_wikipedia_corpus.sh: -------------------------------------------------------------------------------- 1 | ## ************************************************* 2 | ## Tokenize Wikipedia Corpus 3 | ## ************************************************* 4 | DATA_DIR=/home/sunsi/dataset/wikipedia-corpus-index 5 | TOKENIZER=bert-base-uncased 6 | TOKENIZER_ID=bert 7 | 8 | ## ********************************************** 9 | ## Corpus 10 | ## ********************************************** 11 | python ../preprocess/tokenize_wikipedia_passages.py \ 12 | --tokenizer_name ${TOKENIZER} \ 13 | --file ${DATA_DIR}/psgs_w100.tsv \ 14 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/corpus \ 15 | --n_splits 20 \ -------------------------------------------------------------------------------- /ancetele/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .dense_trainer import DenseTrainer, GCDenseTrainer 2 | 3 | 4 | def get_trainer( 5 | model, 6 | args, 7 | train_dataset, 8 | eval_dataset, 9 | data_collator, 10 | callbacks=None, 11 | ): 12 | if args.grad_cache: 13 | return GCDenseTrainer( 14 | model=model, 15 | args=args, 16 | train_dataset=train_dataset, 17 | data_collator=data_collator, 18 | callbacks=callbacks, 19 | ) 20 | 21 | else: 22 | return DenseTrainer( 23 | model=model, 24 | args=args, 25 | train_dataset=train_dataset, 26 | data_collator=data_collator, 27 | 
callbacks=callbacks, 28 | ) -------------------------------------------------------------------------------- /ancetele/grad_cache/context_managers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.checkpoint import get_device_states, set_device_states 3 | 4 | 5 | class RandContext: 6 | def __init__(self, *tensors): 7 | self.fwd_cpu_state = torch.get_rng_state() 8 | self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors) 9 | 10 | def __enter__(self): 11 | self._fork = torch.random.fork_rng( 12 | devices=self.fwd_gpu_devices, 13 | enabled=True 14 | ) 15 | self._fork.__enter__() 16 | torch.set_rng_state(self.fwd_cpu_state) 17 | set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states) 18 | 19 | def __exit__(self, exc_type, exc_val, exc_tb): 20 | self._fork.__exit__(exc_type, exc_val, exc_tb) 21 | self._fork = None -------------------------------------------------------------------------------- /scripts/score_to_marco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | from tqdm import tqdm 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('score_file') 7 | args = parser.parse_args() 8 | 9 | with open(args.score_file) as f: 10 | lines = f.readlines() 11 | 12 | all_scores = defaultdict(list) 13 | 14 | for line in lines: 15 | if len(line.strip()) == 0: 16 | continue 17 | qid, did, score = line.strip().split() 18 | score = float(score) 19 | all_scores[qid].append((did, score)) 20 | 21 | qq = list(all_scores.keys()) 22 | 23 | with open(args.score_file + '.marco', 'w') as f: 24 | for qid in tqdm(qq): 25 | score_list = sorted(all_scores[qid], key=lambda x: x[1], reverse=True) 26 | for rank, (did, score) in enumerate(score_list): 27 | f.write(f'{qid}\t{did}\t{rank+1}\n') 28 | 29 | -------------------------------------------------------------------------------- /shells/tokenize_nq.sh: -------------------------------------------------------------------------------- 1 | ## ************************************************* 2 | ## Tokenize NQ Dataset 3 | ## ************************************************* 4 | DATA_DIR=/home/sunsi/dataset/nq 5 | TOKENIZER=bert-base-uncased 6 | TOKENIZER_ID=bert 7 | 8 | ## ********************************************** 9 | ## train queries 10 | ## ********************************************** 11 | python ../preprocess/tokenize_nq_triviaqa_queries.py \ 12 | --tokenizer_name ${TOKENIZER} \ 13 | --query_file ${DATA_DIR}/nq-train-qrels.jsonl \ 14 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \ 15 | 16 | # ********************************************** 17 | # test queries 18 | # ********************************************** 19 | python ../preprocess/tokenize_nq_triviaqa_queries.py \ 20 | --tokenizer_name ${TOKENIZER} \ 21 | --query_file ${DATA_DIR}/nq-test.jsonl \ 22 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/query/test.query.json \ -------------------------------------------------------------------------------- /shells/tokenize_triviaqa.sh: -------------------------------------------------------------------------------- 1 | ## ************************************************* 2 | ## Tokenize TriviaQA Dataset 3 | ## ************************************************* 4 | DATA_DIR=/home/sunsi/dataset/triviaqa 5 | TOKENIZER=bert-base-uncased 6 | TOKENIZER_ID=bert 7 | 8 | ## ********************************************** 9 | ## train queries 10 | ## 
********************************************** 11 | python ../preprocess/tokenize_nq_triviaqa_queries.py \ 12 | --tokenizer_name ${TOKENIZER} \ 13 | --query_file ${DATA_DIR}/triviaqa-train-qrels.jsonl \ 14 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \ 15 | 16 | # ********************************************** 17 | # test queries 18 | # ********************************************** 19 | python ../preprocess/tokenize_nq_triviaqa_queries.py \ 20 | --tokenizer_name ${TOKENIZER} \ 21 | --query_file ${DATA_DIR}/triviaqa-test.jsonl \ 22 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/query/test.query.json \ -------------------------------------------------------------------------------- /ancetele/networks/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | from arguments import ModelArguments, DataArguments 4 | from arguments import DenseTrainingArguments as TrainingArguments 5 | from .DenseRetriever import (DenseModel, DenseModelForInference) 6 | from collections import OrderedDict 7 | 8 | def get_network( 9 | model_args: ModelArguments, 10 | data_args: DataArguments, 11 | training_args: TrainingArguments, 12 | config: OrderedDict, 13 | cache_dir: str, 14 | do_train: bool, 15 | ): 16 | if do_train: 17 | model = DenseModel.build( 18 | model_args, 19 | data_args, 20 | training_args, 21 | config=config, 22 | cache_dir=model_args.cache_dir, 23 | ) 24 | else: 25 | model = DenseModelForInference.build( 26 | model_name_or_path=model_args.model_name_or_path, 27 | data_args=data_args, 28 | config=config, 29 | cache_dir=model_args.cache_dir, 30 | ) 31 | 32 | return model 33 | 34 | -------------------------------------------------------------------------------- /preprocess/tokenize_marco_queries.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | from argparse import ArgumentParser 4 | from transformers import AutoTokenizer 5 | from preprocessor import SimpleCollectionPreProcessor 6 | 7 | if __name__ == "__main__": 8 | 9 | parser = ArgumentParser() 10 | parser.add_argument('--tokenizer_name', required=True) 11 | parser.add_argument('--truncate', type=int, default=32) 12 | parser.add_argument('--query_file', required=True) 13 | parser.add_argument('--save_to', required=True) 14 | args = parser.parse_args() 15 | 16 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) 17 | processor = SimpleCollectionPreProcessor(tokenizer=tokenizer, max_length=args.truncate) 18 | 19 | with open(args.query_file, 'r') as f: 20 | lines = f.readlines() 21 | 22 | os.makedirs(os.path.split(args.save_to)[0], exist_ok=True) 23 | with open(args.save_to, 'w') as jfile: 24 | for x in tqdm(lines): 25 | q = processor.process_line(x) 26 | jfile.write(q + '\n') 27 | -------------------------------------------------------------------------------- /shells/tokenize_msmarco.sh: -------------------------------------------------------------------------------- 1 | ## ************************************************* 2 | ## Tokenize MS MARCO Dataset 3 | ## ************************************************* 4 | DATA_DIR=/home/sunsi/dataset/msmarco 5 | TOKENIZER=bert-base-uncased 6 | TOKENIZER_ID=bert 7 | 8 | ## corpus 9 | python ../preprocess/tokenize_marco_passages.py \ 10 | --tokenizer_name ${TOKENIZER} \ 11 | --file ${DATA_DIR}/corpus.tsv \ 12 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/corpus \ 13 | 14 | ## train queries 15 | python 
../preprocess/tokenize_marco_queries.py \ 16 | --tokenizer_name ${TOKENIZER} \ 17 | --query_file ${DATA_DIR}/train.query.txt \ 18 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \ 19 | 20 | ## train positives 21 | python ../preprocess/tokenize_marco_positives.py \ 22 | --data_dir ${DATA_DIR} \ 23 | --tokenizer_name ${TOKENIZER} \ 24 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/query/train.positives.json \ 25 | 26 | ## dev queries 27 | python ../preprocess/tokenize_marco_queries.py \ 28 | --tokenizer_name ${TOKENIZER} \ 29 | --query_file ${DATA_DIR}/dev.query.txt \ 30 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/query/dev.query.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 OpenMatch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /ancetele/grad_cache/cachex/training.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from .functional import chunk_encode, cache_grad, unchunk_args 7 | 8 | 9 | def cache_train_step(loss_fn, state, ss, tt, axis='device'): 10 | def encode_with_params(params, **kwargs): 11 | return state.apply_fn(params=params, **kwargs) 12 | 13 | encode_fn = chunk_encode(partial(encode_with_params, state.params)) 14 | grad_fn = cache_grad(encode_with_params) 15 | 16 | s_reps = encode_fn(**ss) 17 | t_reps = encode_fn(**tt) 18 | 19 | @unchunk_args(axis=0, argnums=(0, 1)) 20 | def grad_cache_fn(xx, yy): 21 | return jnp.mean(loss_fn(xx, yy, axis=axis)) 22 | loss, (s_grads, t_grads) = jax.value_and_grad(grad_cache_fn, argnums=(0, 1))(s_reps, t_reps) 23 | 24 | grads = jax.tree_map(lambda v: jnp.zeros_like(v), state.params) 25 | grads = grad_fn(state.params, grads, s_grads, **ss) 26 | grads = grad_fn(state.params, grads, t_grads, **tt) 27 | 28 | loss, grads = jax.lax.pmean([loss, grads], axis) 29 | new_state = state.apply_gradients(grads=grads) 30 | return loss, new_state 31 | -------------------------------------------------------------------------------- /ancetele/faiss_retriever/reducer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import faiss 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | from typing import Iterable, Tuple 6 | from numpy import ndarray 7 | from .__main__ import pickle_load, write_ranking 8 | 9 | 10 | def combine_faiss_results(results: Iterable[Tuple[ndarray, ndarray]]): 11 | rh = None 12 | for scores, indices in results: 13 | if rh is None: 14 | print(f'Initializing Heap. 
Assuming {scores.shape[0]} queries.') 15 | rh = faiss.ResultHeap(scores.shape[0], scores.shape[1]) 16 | rh.add_result(-scores, indices) 17 | rh.finalize() 18 | corpus_scores, corpus_indices = -rh.D, rh.I 19 | 20 | return corpus_scores, corpus_indices 21 | 22 | 23 | def main(): 24 | parser = ArgumentParser() 25 | parser.add_argument('--score_dir', required=True) 26 | parser.add_argument('--query', required=True) 27 | parser.add_argument('--save_ranking_to', required=True) 28 | args = parser.parse_args() 29 | 30 | partitions = glob.glob(f'{args.score_dir}/*') 31 | 32 | corpus_scores, corpus_indices = combine_faiss_results(map(pickle_load, tqdm(partitions))) 33 | 34 | _, q_lookup = pickle_load(args.query) 35 | write_ranking(corpus_indices, corpus_scores, q_lookup, args.save_ranking_to) 36 | 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /ancetele/faiss_retriever/gpu_utils.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import numpy as np 3 | import pickle 4 | import os 5 | import glob 6 | 7 | 8 | def clean_faiss_gpu(): 9 | ngpu = faiss.get_num_gpus() 10 | tempmem = 0 11 | for i in range(ngpu): 12 | res = faiss.StandardGpuResources() 13 | if tempmem >= 0: 14 | res.setTempMemory(tempmem) 15 | 16 | 17 | 18 | def get_gpu_index(cpu_index): 19 | gpu_resources = [] 20 | ngpu = faiss.get_num_gpus() 21 | tempmem = -1 22 | for i in range(ngpu): 23 | res = faiss.StandardGpuResources() 24 | if tempmem >= 0: 25 | res.setTempMemory(tempmem) 26 | gpu_resources.append(res) 27 | 28 | def make_vres_vdev(i0=0, i1=-1): 29 | " return vectors of device ids and resources useful for gpu_multiple" 30 | vres = faiss.GpuResourcesVector() 31 | vdev = faiss.IntVector() 32 | if i1 == -1: 33 | i1 = ngpu 34 | for i in range(i0, i1): 35 | vdev.push_back(i) 36 | vres.push_back(gpu_resources[i]) 37 | return vres, vdev 38 | 39 | co = faiss.GpuMultipleClonerOptions() 40 | co.shard = True 41 | gpu_vector_resources, gpu_devices_vector = make_vres_vdev(0, ngpu) 42 | gpu_index = faiss.index_cpu_to_gpu_multiple(gpu_vector_resources, gpu_devices_vector, cpu_index, co) 43 | return gpu_index -------------------------------------------------------------------------------- /ancetele/faiss_retriever/retriever.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import faiss 3 | from tqdm import tqdm 4 | # from gpu_utils import get_gpu_index 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class BaseFaissIPRetriever: 10 | def __init__(self, init_reps: np.ndarray, use_gpu: bool): 11 | faiss.omp_set_num_threads(16) 12 | index = faiss.IndexFlatIP(init_reps.shape[1]) 13 | if use_gpu: 14 | from gpu_utils import get_gpu_index ## --- SS Modified --- 15 | index = get_gpu_index(index) 16 | logger.info('Gpu Index') 17 | self.index = index 18 | 19 | def search(self, q_reps: np.ndarray, k: int): 20 | return self.index.search(q_reps, k) 21 | 22 | def add(self, p_reps: np.ndarray): 23 | self.index.add(p_reps) 24 | 25 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int): 26 | num_query = q_reps.shape[0] 27 | all_scores = [] 28 | all_indices = [] 29 | for start_idx in tqdm(range(0, num_query, batch_size)): 30 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k) 31 | all_scores.append(nn_scores) 32 | all_indices.append(nn_indices) 33 | all_scores = np.concatenate(all_scores, axis=0) 34 
| all_indices = np.concatenate(all_indices, axis=0) 35 | 36 | return all_scores, all_indices -------------------------------------------------------------------------------- /ancetele/dataloaders/loader_utils.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | from torch.utils.data import Dataset 3 | from transformers import ( 4 | PreTrainedTokenizer, 5 | BatchEncoding, 6 | DataCollatorWithPadding 7 | ) 8 | from dataclasses import dataclass 9 | from typing import Dict, List, Tuple, Optional, Any, Union 10 | 11 | class EncodeDataset(Dataset): 12 | input_keys = ['text_id', 'text'] 13 | 14 | def __init__( 15 | self, 16 | dataset: datasets.Dataset, 17 | tokenizer: PreTrainedTokenizer, 18 | max_len=128 19 | ): 20 | self.encode_data = dataset 21 | self.tok = tokenizer 22 | self.max_len = max_len 23 | 24 | def __len__(self): 25 | return len(self.encode_data) 26 | 27 | def __getitem__(self, item) -> Tuple[str, BatchEncoding]: 28 | text_id, text = (self.encode_data[item][f] for f in self.input_keys) 29 | encoded_text = self.tok.encode_plus( 30 | text, 31 | max_length=self.max_len, 32 | truncation='only_first', 33 | padding=False, 34 | return_token_type_ids=False, 35 | ) 36 | return text_id, encoded_text 37 | 38 | 39 | @dataclass 40 | class EncodeCollator(DataCollatorWithPadding): 41 | def __call__(self, features): 42 | text_ids = [x[0] for x in features] 43 | text_features = [x[1] for x in features] 44 | collated_features = super().__call__(text_features) 45 | return text_ids, collated_features 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /preprocess/tokenize_nq_triviaqa_queries.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | from dataclasses import dataclass 5 | from argparse import ArgumentParser 6 | from transformers import AutoTokenizer 7 | from transformers import PreTrainedTokenizer 8 | 9 | 10 | @dataclass 11 | class QuestionPreProcessor: 12 | tokenizer: PreTrainedTokenizer 13 | separator: str = '\t' 14 | 15 | def process_line(self, line: str): 16 | xx = json.loads(line.strip("\n")) 17 | text_id, text = xx["qid"], xx["question"] 18 | text_encoded = self.tokenizer.encode( 19 | self.tokenizer.sep_token.join([text]), 20 | add_special_tokens=False, 21 | truncation=True 22 | ) 23 | encoded = { 24 | 'text_id': text_id, 25 | 'text': text_encoded 26 | } 27 | return json.dumps(encoded) 28 | 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--tokenizer_name', required=True) 32 | parser.add_argument('--query_file', required=True) 33 | parser.add_argument('--save_to', required=True) 34 | args = parser.parse_args() 35 | 36 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) 37 | processor = QuestionPreProcessor(tokenizer=tokenizer) 38 | 39 | with open(args.query_file, 'r') as f: 40 | lines = f.readlines() 41 | 42 | os.makedirs(os.path.split(args.save_to)[0], exist_ok=True) 43 | with open(args.save_to, 'w') as jfile: 44 | for x in tqdm(lines): 45 | q = processor.process_line(x) 46 | jfile.write(q + '\n') 47 | -------------------------------------------------------------------------------- /preprocess/tokenize_marco_passages.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | from multiprocessing import Pool 4 | from argparse import ArgumentParser 5 | from transformers import 
AutoTokenizer 6 | from preprocessor import SimpleCollectionPreProcessor 7 | 8 | if __name__ == "__main__": 9 | parser = ArgumentParser() 10 | parser.add_argument('--tokenizer_name', required=True) 11 | parser.add_argument('--truncate', type=int, default=128) 12 | parser.add_argument('--file', required=True) 13 | parser.add_argument('--save_to', required=True) 14 | parser.add_argument('--n_splits', type=int, default=10) 15 | 16 | args = parser.parse_args() 17 | 18 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) 19 | processor = SimpleCollectionPreProcessor(tokenizer=tokenizer, max_length=args.truncate) 20 | 21 | with open(args.file, 'r') as f: 22 | lines = f.readlines() 23 | 24 | n_lines = len(lines) 25 | if n_lines % args.n_splits == 0: 26 | split_size = int(n_lines / args.n_splits) 27 | else: 28 | split_size = int(n_lines / args.n_splits) + 1 29 | 30 | 31 | os.makedirs(args.save_to, exist_ok=True) 32 | with Pool() as p: 33 | for i in range(args.n_splits): 34 | with open(os.path.join(args.save_to, f'split{i:02d}.json'), 'w') as f: 35 | pbar = tqdm(lines[i*split_size: (i+1)*split_size]) 36 | pbar.set_description(f'split - {i:02d}') 37 | for jitem in p.imap(processor.process_line, pbar, chunksize=500): 38 | f.write(jitem + '\n') 39 | 40 | 41 | -------------------------------------------------------------------------------- /ancetele/grad_cache/cachex/functional.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Any 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from .tree_utils import tree_unchunk 8 | 9 | Array = Any 10 | 11 | 12 | def grad_with_cache(f, **grad_kwargs): 13 | def cache_f(params, cache, *args, **kwargs): 14 | return jnp.sum(f(params, *args, **kwargs) * cache) 15 | return jax.grad(cache_f, **grad_kwargs) 16 | 17 | 18 | def encode_scan_fn(f, carry, x): 19 | return carry, f(**x) 20 | 21 | 22 | def cache_grad_scan_fn(f, params, acc, x): 23 | cached_grad, kwargs = x 24 | 25 | def fwd_fn(w): 26 | return f(params=w, **kwargs) 27 | 28 | chunk_grad = grad_with_cache(fwd_fn)(params, cached_grad) 29 | acc = jax.tree_multimap(lambda u, v: u + v, acc, chunk_grad) 30 | return acc, None 31 | 32 | 33 | def chunk_encode(encode_fn): 34 | def f(**xx): 35 | _, hh = jax.lax.scan(partial(encode_scan_fn, encode_fn), 0, xx) 36 | return hh 37 | return f 38 | 39 | 40 | def cache_grad(encode_fn): 41 | def f(params, grad_accumulator, cached_grad, **xx): 42 | grads, _ = jax.lax.scan( 43 | partial(cache_grad_scan_fn, encode_fn, params), grad_accumulator, [cached_grad, xx] 44 | ) 45 | return grads 46 | return f 47 | 48 | 49 | def unchunk_args(axis: int = 0, argnums: Iterable[int] = ()): 50 | def decorator_unchunk(f): 51 | def g(*args, **kwargs): 52 | new_args = list(args) 53 | for i in argnums: 54 | new_args[i] = tree_unchunk(args[i], axis) 55 | return f(*new_args, **kwargs) 56 | 57 | return g 58 | 59 | return decorator_unchunk 60 | -------------------------------------------------------------------------------- /ancetele/losses/contrastive_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from torch import distributed as dist 5 | 6 | 7 | class SimpleContrastiveLoss: 8 | def __init__(self, n_target: int = 1): 9 | self.target_per_qry = n_target 10 | 11 | def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 
'mean'): 12 | if target is None: 13 | assert x.size(0) * self.target_per_qry == y.size(0) 14 | target = torch.arange( 15 | 0, x.size(0) * self.target_per_qry, self.target_per_qry, device=x.device, dtype=torch.long) 16 | logits = torch.matmul(x, y.transpose(0, 1)) 17 | return F.cross_entropy(logits, target, reduction=reduction) 18 | 19 | 20 | class DistributedContrastiveLoss(SimpleContrastiveLoss): 21 | def __init__(self, n_target: int = 0, scale_loss: bool = True): 22 | assert dist.is_initialized(), "Distributed training has not been properly initialized." 23 | super().__init__(n_target=n_target) 24 | self.world_size = dist.get_world_size() 25 | self.rank = dist.get_rank() 26 | self.scale_loss = scale_loss 27 | 28 | def __call__(self, x: Tensor, y: Tensor, **kwargs): 29 | dist_x = self.gather_tensor(x) 30 | dist_y = self.gather_tensor(y) 31 | loss = super().__call__(dist_x, dist_y, **kwargs) 32 | if self.scale_loss: 33 | loss = loss * self.world_size 34 | return loss 35 | 36 | def gather_tensor(self, t): 37 | gathered = [torch.empty_like(t) for _ in range(self.world_size)] 38 | dist.all_gather(gathered, t) 39 | gathered[self.rank] = t 40 | return torch.cat(gathered, dim=0) -------------------------------------------------------------------------------- /shells/train_ance-tele_nq.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/nq 2 | export OUTPUT_DIR=/home/sunsi/experiments/nq-results 3 | ## ************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-wiki 6 | export train_data=ance-tele_nq_tokenized-train-data 7 | ## OUTPUT 8 | export train_job_name=ance-tele_nq_NEW 9 | ## ************************************* 10 | ## TRAIN GPUs 11 | TOT_CUDA="0,1,2,3" 12 | CUDAs=(${TOT_CUDA//,/ }) 13 | CUDA_NUM=${#CUDAs[@]} 14 | PORT="1234" ## check that the port is not occupied 15 | ## ************************************* 16 | ## Length Setup 17 | export q_max_len=32 18 | export p_max_len=156 19 | ## ************************************* 20 | TOKENIZER=bert-base-uncased 21 | TOKENIZER_ID=bert 22 | ## ************************************* 23 | 24 | # ********************************************** 25 | # Dist Train 26 | # ********************************************** 27 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 28 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 29 | --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 30 | --fp16 \ 31 | --save_steps 2000 \ 32 | --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 33 | --per_device_train_batch_size 32 \ 34 | --train_n_passages 12 \ 35 | --learning_rate 5e-6 \ 36 | --num_train_epochs 40 \ 37 | --q_max_len ${q_max_len} \ 38 | --p_max_len ${p_max_len} \ 39 | --dataloader_num_workers 2 \ 40 | --untie_encoder \ 41 | --negatives_x_device \ 42 | --positive_passage_no_shuffle \ 43 | 44 | 45 | # # # ******************************************************************** 46 | # # # If your CUDA memory is not enough, please set the following arguments 47 | # # # ******************************************************************** 48 | # --grad_cache \ 49 | # --gc_q_chunk_size 16 \ 50 | # --gc_p_chunk_size 128 \ 51 | 52 | ## Split a batch of queries into chunks of gc_q_chunk_size 53 | ## Split a batch of passages into chunks of gc_p_chunk_size --------------------------------------------------------------------------------
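## Note on the loss above: SimpleContrastiveLoss scores every query against every passage in the batch and expects query i's positive at column i * n_target of the score matrix, so with train_n_passages passages per query (one positive followed by negatives) the arange targets line up automatically. A minimal sketch of that indexing follows; the shapes here (2 queries, 3 passages each, 8-dim embeddings) are toy values for illustration, not taken from the shells.

import torch
from ancetele.losses import SimpleContrastiveLoss

q_reps = torch.randn(2, 8)       # 2 queries, 8-dim embeddings (toy values)
p_reps = torch.randn(6, 8)       # 3 passages per query: 1 positive + 2 negatives
loss_fn = SimpleContrastiveLoss(n_target=3)
loss = loss_fn(q_reps, p_reps)   # targets = [0, 3]: column i * 3 holds query i's positive

Every other column of the 2x6 score matrix serves as an in-batch negative; DistributedContrastiveLoss just all-gathers both sides first so negatives are shared across GPUs, which is what the --negatives_x_device flag in the shells below switches on.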
/shells/epi-3-train-nq.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/nq 2 | export OUTPUT_DIR=/home/sunsi/experiments/nq-results 3 | ## ************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-wiki 6 | export train_data=epi-3-tele-neg.nq ## Mined Epi-3 Tele-Neg 7 | ## OUTPUT 8 | export train_job_name=epi-3.ance-tele.nq 9 | ## ************************************* 10 | ## TRAIN GPUs 11 | TOT_CUDA="0,1,2,3" 12 | CUDAs=(${TOT_CUDA//,/ }) 13 | CUDA_NUM=${#CUDAs[@]} 14 | PORT="1234" ## check that the port is not occupied 15 | ## ************************************* 16 | ## Length Setup 17 | export q_max_len=32 18 | export p_max_len=156 19 | ## ************************************* 20 | TOKENIZER=bert-base-uncased 21 | TOKENIZER_ID=bert 22 | ## ************************************* 23 | 24 | # ********************************************** 25 | # Dist Train 26 | # ********************************************** 27 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 28 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 29 | --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 30 | --fp16 \ 31 | --save_steps 2000 \ 32 | --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 33 | --per_device_train_batch_size 32 \ 34 | --train_n_passages 12 \ 35 | --learning_rate 5e-6 \ 36 | --num_train_epochs 40 \ 37 | --q_max_len ${q_max_len} \ 38 | --p_max_len ${p_max_len} \ 39 | --dataloader_num_workers 2 \ 40 | --untie_encoder \ 41 | --negatives_x_device \ 42 | --positive_passage_no_shuffle \ 43 | 44 | 45 | # # # ******************************************************************** 46 | # # # If your CUDA memory is not enough, please set the following arguments 47 | # # # ******************************************************************** 48 | # --grad_cache \ 49 | # --gc_q_chunk_size 16 \ 50 | # --gc_p_chunk_size 128 \ 51 | 52 | ## Split a batch of queries into chunks of gc_q_chunk_size 53 | ## Split a batch of passages into chunks of gc_p_chunk_size -------------------------------------------------------------------------------- /shells/train_ance-tele_triviaqa.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/triviaqa 2 | export OUTPUT_DIR=/home/sunsi/experiments/triviaqa-results 3 | ## ************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-wiki 6 | export train_data=ance-tele_triviaqa_tokenized-train-data 7 | ## OUTPUT 8 | export train_job_name=ance-tele_triviaqa_NEW 9 | ## ************************************* 10 | ## TRAIN GPUs 11 | TOT_CUDA="0,1,2,3" 12 | CUDAs=(${TOT_CUDA//,/ }) 13 | CUDA_NUM=${#CUDAs[@]} 14 | PORT="1234" ## check that the port is not occupied 15 | ## ************************************* 16 | ## Length Setup 17 | export q_max_len=32 18 | export p_max_len=156 19 | ## ************************************* 20 | TOKENIZER=bert-base-uncased 21 | TOKENIZER_ID=bert 22 | ## ************************************* 23 | 24 | # ********************************************** 25 | # Dist Train 26 | # ********************************************** 27 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 28 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 29 | 
--model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 30 | --fp16 \ 31 | --save_steps 2000 \ 32 | --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 33 | --per_device_train_batch_size 32 \ 34 | --train_n_passages 12 \ 35 | --learning_rate 5e-6 \ 36 | --num_train_epochs 40 \ 37 | --q_max_len ${q_max_len} \ 38 | --p_max_len ${p_max_len} \ 39 | --dataloader_num_workers 2 \ 40 | --untie_encoder \ 41 | --negatives_x_device \ 42 | --positive_passage_no_shuffle \ 43 | 44 | 45 | # # # ******************************************************************** 46 | # # # If your CUDA memory is not enough, please set the following arguments 47 | # # # ******************************************************************** 48 | # --grad_cache \ 49 | # --gc_q_chunk_size 16 \ 50 | # --gc_p_chunk_size 128 \ 51 | 52 | ## Split a batch of queries into chunks of gc_q_chunk_size 53 | ## Split a batch of passages into chunks of gc_p_chunk_size -------------------------------------------------------------------------------- /shells/epi-3-train-triviaqa.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/triviaqa 2 | export OUTPUT_DIR=/home/sunsi/experiments/triviaqa-results 3 | ## ************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-wiki 6 | export train_data=epi-3-tele-neg.triviaqa ## Mined Epi-3 Tele-Neg 7 | ## OUTPUT 8 | export train_job_name=epi-3.ance-tele.triviaqa 9 | ## ************************************* 10 | ## TRAIN GPUs 11 | TOT_CUDA="0,1,2,3" 12 | CUDAs=(${TOT_CUDA//,/ }) 13 | CUDA_NUM=${#CUDAs[@]} 14 | PORT="1234" ## check that the port is not occupied 15 | ## ************************************* 16 | ## Length Setup 17 | export q_max_len=32 18 | export p_max_len=156 19 | ## ************************************* 20 | TOKENIZER=bert-base-uncased 21 | TOKENIZER_ID=bert 22 | ## ************************************* 23 | 24 | # ********************************************** 25 | # Dist Train 26 | # ********************************************** 27 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 28 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 29 | --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 30 | --fp16 \ 31 | --save_steps 2000 \ 32 | --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 33 | --per_device_train_batch_size 32 \ 34 | --train_n_passages 12 \ 35 | --learning_rate 5e-6 \ 36 | --num_train_epochs 40 \ 37 | --q_max_len ${q_max_len} \ 38 | --p_max_len ${p_max_len} \ 39 | --dataloader_num_workers 2 \ 40 | --untie_encoder \ 41 | --negatives_x_device \ 42 | --positive_passage_no_shuffle \ 43 | 44 | 45 | # # # ******************************************************************** 46 | # # # If your CUDA memory is not enough, please set the following arguments 47 | # # # ******************************************************************** 48 | # --grad_cache \ 49 | # --gc_q_chunk_size 16 \ 50 | # --gc_p_chunk_size 128 \ 51 | 52 | ## Split a batch of queries into chunks of gc_q_chunk_size 53 | ## Split a batch of passages into chunks of gc_p_chunk_size -------------------------------------------------------------------------------- /shells/epi-1-train-nq.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/nq 2 | export OUTPUT_DIR=/home/sunsi/experiments/nq-results 3 | ## 
************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-wiki 6 | export train_data=epi-1-tele-neg.nq ## Mined Epi-1 Tele-Neg 7 | ## OUTPUT 8 | export train_job_name=epi-1.ance-tele.nq.checkp-2000 9 | ## ************************************* 10 | ## TRAIN GPUs 11 | TOT_CUDA="0,1,2,3" 12 | CUDAs=(${TOT_CUDA//,/ }) 13 | CUDA_NUM=${#CUDAs[@]} 14 | PORT="1234" ## check that the port is not occupied 15 | ## ************************************* 16 | ## Length Setup 17 | export q_max_len=32 18 | export p_max_len=156 19 | ## ************************************* 20 | TOKENIZER=bert-base-uncased 21 | TOKENIZER_ID=bert 22 | ## ************************************* 23 | 24 | # ********************************************** 25 | # Dist Train 26 | # ********************************************** 27 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 28 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 29 | --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 30 | --fp16 \ 31 | --save_strategy no \ 32 | --early_stop_step 2000 \ 33 | --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 34 | --per_device_train_batch_size 32 \ 35 | --train_n_passages 4 \ 36 | --learning_rate 5e-6 \ 37 | --num_train_epochs 40 \ 38 | --q_max_len ${q_max_len} \ 39 | --p_max_len ${p_max_len} \ 40 | --dataloader_num_workers 2 \ 41 | --untie_encoder \ 42 | --negatives_x_device \ 43 | --positive_passage_no_shuffle \ 44 | 45 | # # --train_n_passages 4 or 12 is OK; 4 is faster. 46 | 47 | # # # ******************************************************************** 48 | # # # If your CUDA memory is not enough, please set the following arguments 49 | # # # ******************************************************************** 50 | # --grad_cache \ 51 | # --gc_q_chunk_size 16 \ 52 | # --gc_p_chunk_size 128 \ 53 | 54 | ## Split a batch of queries into chunks of gc_q_chunk_size 55 | ## Split a batch of passages into chunks of gc_p_chunk_size -------------------------------------------------------------------------------- /shells/epi-2-train-nq.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/nq 2 | export OUTPUT_DIR=/home/sunsi/experiments/nq-results 3 | ## ************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-wiki 6 | export train_data=epi-2-tele-neg.nq ## Mined Epi-2 Tele-Neg 7 | ## OUTPUT 8 | export train_job_name=epi-2.ance-tele.nq.checkp-2000 9 | ## ************************************* 10 | ## TRAIN GPUs 11 | TOT_CUDA="0,1,2,3" 12 | CUDAs=(${TOT_CUDA//,/ }) 13 | CUDA_NUM=${#CUDAs[@]} 14 | PORT="1234" ## check that the port is not occupied 15 | ## ************************************* 16 | ## Length Setup 17 | export q_max_len=32 18 | export p_max_len=156 19 | ## ************************************* 20 | TOKENIZER=bert-base-uncased 21 | TOKENIZER_ID=bert 22 | ## ************************************* 23 | 24 | # ********************************************** 25 | # Dist Train 26 | # ********************************************** 27 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 28 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 29 | --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 30 | --fp16 \ 31 | --save_strategy no \ 32 | --early_stop_step 2000 \ 33 | 
--train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 34 | --per_device_train_batch_size 32 \ 35 | --train_n_passages 8 \ 36 | --learning_rate 5e-6 \ 37 | --num_train_epochs 40 \ 38 | --q_max_len ${q_max_len} \ 39 | --p_max_len ${p_max_len} \ 40 | --dataloader_num_workers 2 \ 41 | --untie_encoder \ 42 | --negatives_x_device \ 43 | --positive_passage_no_shuffle \ 44 | 45 | # # --train_n_passages 8 or 12 is OK; 8 is faster. 46 | 47 | # # # ******************************************************************** 48 | # # # If your CUDA memory is not enough, please set the following arguments 49 | # # # ******************************************************************** 50 | # --grad_cache \ 51 | # --gc_q_chunk_size 16 \ 52 | # --gc_p_chunk_size 128 \ 53 | 54 | ## Split a batch of queries into chunks of gc_q_chunk_size 55 | ## Split a batch of passages into chunks of gc_p_chunk_size -------------------------------------------------------------------------------- /shells/epi-1-train-triviaqa.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/triviaqa 2 | export OUTPUT_DIR=/home/sunsi/experiments/triviaqa-results 3 | ## ************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-wiki 6 | export train_data=epi-1-tele-neg.triviaqa ## Mined Epi-1 Tele-Neg 7 | ## OUTPUT 8 | export train_job_name=epi-1.ance-tele.triviaqa.checkp-2000 9 | ## ************************************* 10 | ## TRAIN GPUs 11 | TOT_CUDA="0,1,2,3" 12 | CUDAs=(${TOT_CUDA//,/ }) 13 | CUDA_NUM=${#CUDAs[@]} 14 | PORT="1234" ## check that the port is not occupied 15 | ## ************************************* 16 | ## Length Setup 17 | export q_max_len=32 18 | export p_max_len=156 19 | ## ************************************* 20 | TOKENIZER=bert-base-uncased 21 | TOKENIZER_ID=bert 22 | ## ************************************* 23 | 24 | # ********************************************** 25 | # Dist Train 26 | # ********************************************** 27 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 28 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 29 | --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 30 | --fp16 \ 31 | --save_strategy no \ 32 | --early_stop_step 2000 \ 33 | --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 34 | --per_device_train_batch_size 32 \ 35 | --train_n_passages 4 \ 36 | --learning_rate 5e-6 \ 37 | --num_train_epochs 40 \ 38 | --q_max_len ${q_max_len} \ 39 | --p_max_len ${p_max_len} \ 40 | --dataloader_num_workers 2 \ 41 | --untie_encoder \ 42 | --negatives_x_device \ 43 | --positive_passage_no_shuffle \ 44 | 45 | # # --train_n_passages 4 or 12 is OK; 4 is faster. 
46 | 47 | # # # ******************************************************************** 48 | # # # If your CUDA memory is not enough, please set the following arguments 49 | # # # ******************************************************************** 50 | # --grad_cache \ 51 | # --gc_q_chunk_size 16 \ 52 | # --gc_p_chunk_size 128 \ 53 | 54 | ## Split a batch of queries into chunks of gc_q_chunk_size 55 | ## Split a batch of passages into chunks of gc_p_chunk_size -------------------------------------------------------------------------------- /shells/epi-2-train-triviaqa.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/triviaqa 2 | export OUTPUT_DIR=/home/sunsi/experiments/triviaqa-results 3 | ## ************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-wiki 6 | export train_data=epi-2-tele-neg.triviaqa ## Mined Epi-2 Tele-Neg 7 | ## OUTPUT 8 | export train_job_name=epi-2.ance-tele.triviaqa.checkp-2000 9 | ## ************************************* 10 | ## TRAIN GPUs 11 | TOT_CUDA="0,1,2,3" 12 | CUDAs=(${TOT_CUDA//,/ }) 13 | CUDA_NUM=${#CUDAs[@]} 14 | PORT="1234" ## check that the port is not occupied 15 | ## ************************************* 16 | ## Length Setup 17 | export q_max_len=32 18 | export p_max_len=156 19 | ## ************************************* 20 | TOKENIZER=bert-base-uncased 21 | TOKENIZER_ID=bert 22 | ## ************************************* 23 | 24 | # ********************************************** 25 | # Dist Train 26 | # ********************************************** 27 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 28 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 29 | --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 30 | --fp16 \ 31 | --save_strategy no \ 32 | --early_stop_step 2000 \ 33 | --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 34 | --per_device_train_batch_size 32 \ 35 | --train_n_passages 8 \ 36 | --learning_rate 5e-6 \ 37 | --num_train_epochs 40 \ 38 | --q_max_len ${q_max_len} \ 39 | --p_max_len ${p_max_len} \ 40 | --dataloader_num_workers 2 \ 41 | --untie_encoder \ 42 | --negatives_x_device \ 43 | --positive_passage_no_shuffle \ 44 | 45 | # # --train_n_passages 8 or 12 is OK; 8 is faster. 
46 | 47 | # # # ******************************************************************** 48 | # # # If your CUDA memory is not enough, please set the following arguments 49 | # # # ******************************************************************** 50 | # --grad_cache \ 51 | # --gc_q_chunk_size 16 \ 52 | # --gc_p_chunk_size 128 \ 53 | 54 | ## Split a batch of queries into chunks of gc_q_chunk_size 55 | ## Split a batch of passages into chunks of gc_p_chunk_size -------------------------------------------------------------------------------- /ancetele/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("..") 4 | from arguments import DataArguments 5 | from transformers import ( 6 | PreTrainedTokenizer, 7 | ) 8 | from .dense_dataset import ( 9 | DenseTrainDataset, 10 | DenseEncodeDataset, 11 | DenseQPCollator, 12 | DenseEncodeCollator, 13 | ) 14 | from .hf_dataset import ( 15 | HFDataset, 16 | HFQueryDataset, 17 | HFCorpusDataset 18 | ) 19 | from .dataset_utils import ( 20 | TrainPreProcessor, 21 | QueryPreProcessor, 22 | CorpusPreProcessor 23 | ) 24 | 25 | from .loader_utils import (EncodeCollator) 26 | 27 | 28 | 29 | 30 | def get_train_dataset( 31 | tokenizer: PreTrainedTokenizer, 32 | data_args: DataArguments, 33 | ): 34 | 35 | ## Load the training split with HuggingFace datasets 36 | train_dataset = HFDataset( 37 | tokenizer=tokenizer, 38 | data_args=data_args, 39 | dataset_split="train", 40 | data_files=data_args.train_path, 41 | cache_dir=data_args.train_cache_dir, 42 | ) 43 | 44 | return ( 45 | DenseTrainDataset( 46 | data_args, 47 | train_dataset.process(), 48 | tokenizer), 49 | None, 50 | DenseQPCollator 51 | ) 52 | 53 | 54 | def get_encode_dataset( 55 | tokenizer: PreTrainedTokenizer, 56 | data_args: DataArguments, 57 | ): 58 | 59 | ## Dense-Retriever 60 | if data_args.encode_is_qry: 61 | encode_dataset = HFQueryDataset( 62 | tokenizer=tokenizer, 63 | data_args=data_args, 64 | dataset_split="encode", 65 | cache_dir=data_args.encode_in_path[0] + ".cache" 66 | ) 67 | else: 68 | encode_dataset = HFCorpusDataset( 69 | tokenizer=tokenizer, 70 | data_args=data_args, 71 | dataset_split="encode", 72 | cache_dir=data_args.encode_in_path[0] + ".cache" 73 | ) 74 | return ( 75 | DenseEncodeDataset( 76 | data_args, 77 | encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index), 78 | tokenizer), 79 | DenseEncodeCollator 80 | ) -------------------------------------------------------------------------------- /preprocess/tokenize_wikipedia_passages.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | from multiprocessing import Pool 5 | from dataclasses import dataclass 6 | from argparse import ArgumentParser 7 | from transformers import AutoTokenizer, PreTrainedTokenizer 8 | 9 | @dataclass 10 | class WikiCollectionPreProcessor: 11 | tokenizer: PreTrainedTokenizer 12 | separator: str = '\t' 13 | max_length: int = 256 14 | 15 | def process_line(self, line: str): 16 | xx = line.strip().split(self.separator) 17 | text_id, body, title = xx[0], xx[1], xx[2] 18 | 19 | if text_id == "id": 20 | return None 21 | 22 | title = "" if title is None else title 23 | 24 | text = title + self.tokenizer.sep_token + body 25 | text_encoded = self.tokenizer.encode( 26 | text, 27 | add_special_tokens=False, 28 | max_length=self.max_length, 29 | truncation=True 30 | ) 31 | 32 | encoded = { 33 | 'text_id': text_id, 34 | 'text': text_encoded 35 
| } 36 | return json.dumps(encoded) 37 | 38 | parser = ArgumentParser() 39 | parser.add_argument('--tokenizer_name', required=True) 40 | parser.add_argument('--truncate', type=int, default=256) 41 | parser.add_argument('--file', required=True) 42 | parser.add_argument('--save_to', required=True) 43 | parser.add_argument('--n_splits', type=int, default=20) 44 | 45 | args = parser.parse_args() 46 | 47 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) 48 | processor = WikiCollectionPreProcessor(tokenizer=tokenizer, max_length=args.truncate) 49 | 50 | with open(args.file, 'r') as f: 51 | lines = f.readlines() 52 | 53 | n_lines = len(lines) 54 | if n_lines % args.n_splits == 0: 55 | split_size = int(n_lines / args.n_splits) 56 | else: 57 | split_size = int(n_lines / args.n_splits) + 1 58 | 59 | 60 | os.makedirs(args.save_to, exist_ok=True) 61 | with Pool() as p: 62 | for i in range(args.n_splits): 63 | with open(os.path.join(args.save_to, f'split{i:02d}.json'), 'w') as f: 64 | pbar = tqdm(lines[i*split_size: (i+1)*split_size]) 65 | pbar.set_description(f'split - {i:02d}') 66 | for jitem in p.imap(processor.process_line, pbar, chunksize=500): 67 | if jitem is not None: 68 | f.write(jitem + '\n') 69 | 70 | 71 | -------------------------------------------------------------------------------- /shells/epi-3-train-msmarco.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/msmarco 2 | export OUTPUT_DIR=/home/sunsi/experiments/msmarco-results 3 | ## ************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-marco 6 | export train_data=epi-3-tele-neg.msmarco ## Mined Epi-3 Tele-Neg 7 | ## OUTPUT 8 | export train_job_name=epi-3.ance-tele.msmarco 9 | ## ************************************* 10 | ## TRAIN GPUs 11 | TOT_CUDA="0" ## multi-gpus: TOT_CUDA="0,1" 12 | CUDAs=(${TOT_CUDA//,/ }) 13 | CUDA_NUM=${#CUDAs[@]} 14 | PORT="1234" ## check that the port is not occupied 15 | ## ************************************* 16 | TOKENIZER=bert-base-uncased 17 | TOKENIZER_ID=bert 18 | SplitNum=10 19 | ## ************************************* 20 | 21 | ## ********************************************** 22 | ## Train 23 | ## ********************************************** 24 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} python ../ancetele/train.py \ 25 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 26 | --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 27 | --fp16 \ 28 | --save_steps 20000 \ 29 | --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 30 | --per_device_train_batch_size 8 \ 31 | --train_n_passages 32 \ 32 | --learning_rate 5e-6 \ 33 | --num_train_epochs 3 \ 34 | --dataloader_num_workers 2 \ 35 | 36 | 37 | # # ********************************************** 38 | # # Dist Train 39 | # # ********************************************** 40 | # CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 41 | # --output_dir ${OUTPUT_DIR}/${train_job_name} \ 42 | # --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 43 | # --fp16 \ 44 | # --save_steps 20000 \ 45 | # --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 46 | # --per_device_train_batch_size 4 \ 47 | # --train_n_passages 32 \ 48 | # --learning_rate 5e-6 \ 49 | # --num_train_epochs 3 \ 50 | # --dataloader_num_workers 2 \ 51 | # --negatives_x_device \ 52 | 53 | 54 | # # # 
******************************************************************** 55 | # # # If your CUDA memory is not enough, please set the following arguments 56 | # # # ******************************************************************** 57 | # --grad_cache \ 58 | # --gc_q_chunk_size 4 \ 59 | # --gc_p_chunk_size 8 \ 60 | 61 | ## Split a batch of queries into chunks of gc_q_chunk_size 62 | ## Split a batch of passages into chunks of gc_p_chunk_size -------------------------------------------------------------------------------- /shells/train_ance-tele_msmarco.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/msmarco 2 | export OUTPUT_DIR=/home/sunsi/experiments/msmarco-results 3 | ## ************************************* 4 | ## INPUT 5 | export prev_train_job_name=co-condenser-marco 6 | export train_data=ance-tele_msmarco_tokenized-train-data 7 | ## OUTPUT 8 | export train_job_name=ance-tele_msmarco_qry-psg-encoder_NEW 9 | export infer_job_name=inference.${train_job_name} 10 | ## ************************************* 11 | ## TRAIN GPUs 12 | TOT_CUDA="0" ## multi-gpus: TOT_CUDA="0,1" 13 | CUDAs=(${TOT_CUDA//,/ }) 14 | CUDA_NUM=${#CUDAs[@]} 15 | PORT="1234" ## check that the port is not occupied 16 | ## ************************************* 17 | TOKENIZER=bert-base-uncased 18 | TOKENIZER_ID=bert 19 | ## ************************************* 20 | 21 | ## ********************************************** 22 | ## Train 23 | ## ********************************************** 24 | CUDA_VISIBLE_DEVICES=${TOT_CUDA} python ../ancetele/train.py \ 25 | --output_dir ${OUTPUT_DIR}/${train_job_name} \ 26 | --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 27 | --fp16 \ 28 | --save_steps 20000 \ 29 | --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 30 | --per_device_train_batch_size 8 \ 31 | --train_n_passages 32 \ 32 | --learning_rate 5e-6 \ 33 | --num_train_epochs 3 \ 34 | --dataloader_num_workers 2 \ 35 | 36 | 37 | # # ********************************************** 38 | # # Dist Train 39 | # # ********************************************** 40 | # CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \ 41 | # --output_dir ${OUTPUT_DIR}/${train_job_name} \ 42 | # --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \ 43 | # --fp16 \ 44 | # --save_steps 20000 \ 45 | # --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \ 46 | # --per_device_train_batch_size 4 \ 47 | # --train_n_passages 32 \ 48 | # --learning_rate 5e-6 \ 49 | # --num_train_epochs 3 \ 50 | # --dataloader_num_workers 2 \ 51 | # --negatives_x_device \ 52 | 53 | 54 | # # # ******************************************************************** 55 | # # # If your CUDA memory is not enough, please set the following arguments 56 | # # # ******************************************************************** 57 | # --grad_cache \ 58 | # --gc_q_chunk_size 4 \ 59 | # --gc_p_chunk_size 8 \ 60 | 61 | ## Split a batch of queries into chunks of gc_q_chunk_size 62 | ## Split a batch of passages into chunks of gc_p_chunk_size -------------------------------------------------------------------------------- /preprocess/tokenize_marco_positives.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import json 4 | import random 5 | import datasets 6 | from tqdm import tqdm 7 | from datetime import datetime 8 | from multiprocessing 
--------------------------------------------------------------------------------
/ancetele/grad_cache/loss.py:
--------------------------------------------------------------------------------
from typing import Callable

import torch
from torch import Tensor
from torch.nn import functional as F
from torch import distributed as dist


class SimpleContrastiveLoss:
    def __init__(self, n_hard_negatives: int = 0):
        self.target_per_qry = n_hard_negatives + 1

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
        if target is None:
            assert x.size(0) * self.target_per_qry == y.size(0)
            target = torch.arange(0, x.size(0) * self.target_per_qry, self.target_per_qry, device=x.device)

        logits = torch.matmul(x, y.transpose(0, 1))
        return F.cross_entropy(logits, target, reduction=reduction)


class DistributedContrastiveLoss(SimpleContrastiveLoss):
    def __init__(self, n_hard_negatives: int = 0):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."

        super().__init__(n_hard_negatives=n_hard_negatives)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)

        return super().__call__(dist_x, dist_y, **kwargs)

    def gather_tensor(self, t):
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)


class ContrastiveLossWithQueryClosure(SimpleContrastiveLoss):
    def __call__(
        self,
        *reps: Tensor,
        query_closure: Callable[[], Tensor] = None,
        target: Tensor = None,
        reduction: str = 'mean'
    ):
        if len(reps) == 0 or len(reps) > 2:
            raise ValueError(f'Expecting 1 or 2 tensor input, got {len(reps)} tensors')

        # no closure evaluation
        if len(reps) == 2:
            assert query_closure is None, 'received 2 representation tensors while query_closure is also set'
            return super().__call__(*reps, target=target, reduction=reduction)

        # run the closure
        assert query_closure is not None
        x = query_closure()
        y = reps[0]
        return super().__call__(x, y, target=target, reduction=reduction)
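# ----------------------------------------------------------------------
# [Editor's example] A toy sketch (not part of the original file) of how
# SimpleContrastiveLoss lays out its targets. With 2 queries and
# n_hard_negatives=1, each query owns target_per_qry=2 consecutive rows
# of y ([pos_0, neg_0, pos_1, neg_1]), so the positive of query i sits at
# index i * 2 and the generated target is tensor([0, 2]):
#
#   loss_fn = SimpleContrastiveLoss(n_hard_negatives=1)
#   x = torch.randn(2, 768)   # query reps
#   y = torch.randn(4, 768)   # passage reps, one positive then one negative per query
#   loss = loss_fn(x, y)      # cross-entropy over the 2x4 logits matrix
# ----------------------------------------------------------------------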
--------------------------------------------------------------------------------
/shells/epi-2-train-msmarco.sh:
--------------------------------------------------------------------------------
export DATA_DIR=/home/sunsi/dataset/msmarco
export OUTPUT_DIR=/home/sunsi/experiments/msmarco-results
## *************************************
## INPUT
export prev_train_job_name=co-condenser-marco
export train_data=epi-2-tele-neg.msmarco ## Mined Epi-2 Tele-Neg
## OUTPUT
export train_job_name=epi-2.ance-tele.msmarco.checkp-20000
## *************************************
## TRAIN GPUs
TOT_CUDA="0" ## multi-gpus: TOT_CUDA="0,1"
CUDAs=(${TOT_CUDA//,/ })
CUDA_NUM=${#CUDAs[@]}
PORT="1234" ## check that the port is not occupied
## *************************************
TOKENIZER=bert-base-uncased
TOKENIZER_ID=bert
SplitNum=10
## *************************************

## **********************************************
## Train (Early Stop)
## **********************************************
CUDA_VISIBLE_DEVICES=${TOT_CUDA} python ../ancetele/train.py \
--output_dir ${OUTPUT_DIR}/${train_job_name} \
--model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \
--fp16 \
--save_strategy no \
--early_stop_step 20000 \
--train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \
--per_device_train_batch_size 8 \
--train_n_passages 32 \
--learning_rate 5e-6 \
--num_train_epochs 3 \
--dataloader_num_workers 2 \

# # **********************************************
# # Dist Train (Early Stop)
# # **********************************************
# CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \
# --output_dir ${OUTPUT_DIR}/${train_job_name} \
# --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \
# --fp16 \
# --save_strategy no \
# --early_stop_step 20000 \
# --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \
# --per_device_train_batch_size 4 \
# --train_n_passages 32 \
# --learning_rate 5e-6 \
# --num_train_epochs 3 \
# --dataloader_num_workers 2 \
# --negatives_x_device \

# # --train_n_passages 16 or 32 both work; 16 is faster.

# # # ********************************************************************
# # # If your CUDA memory is not enough, please set the following arguments
# # # ********************************************************************
# --grad_cache \
# --gc_q_chunk_size 4 \
# --gc_p_chunk_size 8 \

## Split a batch of queries into chunks of gc_q_chunk_size queries
## Split a batch of passages into chunks of gc_p_chunk_size passages
--------------------------------------------------------------------------------
/shells/epi-1-train-msmarco.sh:
--------------------------------------------------------------------------------
export DATA_DIR=/home/sunsi/dataset/msmarco
export OUTPUT_DIR=/home/sunsi/experiments/msmarco-results
## *************************************
## INPUT
export prev_train_job_name=co-condenser-marco
export train_data=epi-1-tele-neg.msmarco ## Mined Epi-1 Tele-Neg
## OUTPUT
export train_job_name=epi-1.ance-tele.msmarco.checkp-20000
## *************************************
## TRAIN GPUs
TOT_CUDA="0" ## multi-gpus: TOT_CUDA="0,1"
CUDAs=(${TOT_CUDA//,/ })
CUDA_NUM=${#CUDAs[@]}
PORT="1234" ## check that the port is not occupied
## *************************************
TOKENIZER=bert-base-uncased
TOKENIZER_ID=bert
SplitNum=10
## *************************************

## **********************************************
## Train (Early Stop)
## **********************************************
CUDA_VISIBLE_DEVICES=${TOT_CUDA} python ../ancetele/train.py \
--output_dir ${OUTPUT_DIR}/${train_job_name} \
--model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \
--fp16 \
--save_strategy no \
--early_stop_step 20000 \
--train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \
--per_device_train_batch_size 8 \
--train_n_passages 16 \
--learning_rate 5e-6 \
--num_train_epochs 3 \
--dataloader_num_workers 2 \

# # --train_n_passages 16 or 32 both work; 16 is faster.

# # **********************************************
# # Dist Train (Early Stop)
# # **********************************************
# CUDA_VISIBLE_DEVICES=${TOT_CUDA} OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=${CUDA_NUM} --master_port=${PORT} ../ancetele/train.py \
# --output_dir ${OUTPUT_DIR}/${train_job_name} \
# --model_name_or_path ${OUTPUT_DIR}/${prev_train_job_name} \
# --fp16 \
# --save_strategy no \
# --early_stop_step 20000 \
# --train_dir ${DATA_DIR}/${TOKENIZER_ID}/${train_data} \
# --per_device_train_batch_size 4 \
# --train_n_passages 16 \
# --learning_rate 5e-6 \
# --num_train_epochs 3 \
# --dataloader_num_workers 2 \
# --negatives_x_device \

# # --train_n_passages 16 or 32 both work; 16 is faster.
# # # ********************************************************************
# # # If your CUDA memory is not enough, please set the following arguments
# # # ********************************************************************
# --grad_cache \
# --gc_q_chunk_size 4 \
# --gc_p_chunk_size 8 \

## Split a batch of queries into chunks of gc_q_chunk_size queries
## Split a batch of passages into chunks of gc_p_chunk_size passages
--------------------------------------------------------------------------------
/preprocess/combine_marco_negative.py:
--------------------------------------------------------------------------------
import os
import json
import numpy as np
from tqdm import tqdm
from argparse import ArgumentParser

if __name__ == "__main__":

    parser = ArgumentParser()
    parser.add_argument('--data_dir', required=True)
    parser.add_argument('--input_folder_1', required=True)
    parser.add_argument('--input_folder_2', required=True)
    parser.add_argument('--output_folder', required=True)
    args = parser.parse_args()

    input_path_1 = os.path.join(args.data_dir, args.input_folder_1)
    input_path_2 = os.path.join(args.data_dir, args.input_folder_2)
    output_path = os.path.join(args.data_dir, args.output_folder)

    # create output data
    if not os.path.exists(output_path):
        os.mkdir(output_path)

    # load input 1
    file_list_1 = [listx for listx in os.listdir(input_path_1) if "json" in listx]

    query2negatives = {}
    for file in tqdm(file_list_1):
        file_1 = os.path.join(input_path_1, file)
        with open(file_1, "r", encoding="utf-8") as fi:
            for line in fi:
                data = json.loads(line)
                query = "_".join(str(ids) for ids in data["query"])
                negatives = data["negatives"]
                query2negatives[query] = negatives

    # load input 2 & write mix
    file_list_2 = [listx for listx in os.listdir(input_path_2) if "json" in listx]

    neg_num_list = []
    diff_num = 0
    for file in tqdm(file_list_2):
        file_2 = os.path.join(input_path_2, file)
        output_file = os.path.join(output_path, file)
        with open(file_2, "r", encoding="utf-8") as fi, \
                open(output_file, "w", encoding="utf-8") as fw:
            for line in fi:
                data = json.loads(line)
                query = data["query"]
                positives = data["positives"]
                qid = "_".join(str(ids) for ids in query)
                if qid in query2negatives:
                    negatives = data["negatives"] + query2negatives[qid]
                    neg_num_list.append(len(negatives))
                else:
                    negatives = data["negatives"]
                    diff_num += 1

                mix_example = {
                    'query': query,
                    'positives': positives,
                    'negatives': negatives,
                }
                fw.write(json.dumps(mix_example) + '\n')

    print("diff num = ", diff_num)
    print("combine neg num = ", np.mean(neg_num_list))
--------------------------------------------------------------------------------
/ancetele/dataloaders/dataset_utils.py:
--------------------------------------------------------------------------------
class TrainPreProcessor:
    def __init__(self, tokenizer, query_max_length=32, text_max_length=256, separator=' '):
        self.tokenizer = tokenizer
        self.query_max_length = query_max_length
        self.text_max_length = text_max_length
        self.separator = separator

    def __call__(self, example):
        query = self.tokenizer.encode(example['query'],
                                      add_special_tokens=False,
                                      max_length=self.query_max_length,
                                      truncation=True)
        positives = []
        for pos in example['positive_passages']:
            text = pos['title'] + self.separator + pos['text'] if 'title' in pos else pos['text']
            positives.append(self.tokenizer.encode(text,
                                                   add_special_tokens=False,
                                                   max_length=self.text_max_length,
                                                   truncation=True))
        negatives = []
        for neg in example['negative_passages']:
            text = neg['title'] + self.separator + neg['text'] if 'title' in neg else neg['text']
            negatives.append(self.tokenizer.encode(text,
                                                   add_special_tokens=False,
                                                   max_length=self.text_max_length,
                                                   truncation=True))
        return {'query': query, 'positives': positives, 'negatives': negatives}


class QueryPreProcessor:
    def __init__(self, tokenizer, query_max_length=32):
        self.tokenizer = tokenizer
        self.query_max_length = query_max_length

    def __call__(self, example):
        query_id = example['query_id']
        query = self.tokenizer.encode(example['query'],
                                      add_special_tokens=False,
                                      max_length=self.query_max_length,
                                      truncation=True)
        return {'text_id': query_id, 'text': query}


class CorpusPreProcessor:
    def __init__(self, tokenizer, text_max_length=256, separator=' '):
        self.tokenizer = tokenizer
        self.text_max_length = text_max_length
        self.separator = separator

    def __call__(self, example):
        docid = example['docid']
        text = example['title'] + self.separator + example['text'] if 'title' in example else example['text']
        text = self.tokenizer.encode(text,
                                     add_special_tokens=False,
                                     max_length=self.text_max_length,
                                     truncation=True)
        return {'text_id': docid, 'text': text}
--------------------------------------------------------------------------------
/preprocess/build_train_hn.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser
from transformers import AutoTokenizer
import os
import random
from tqdm import tqdm
from datetime import datetime
from multiprocessing import Pool
from preprocessor import SimpleTrainPreProcessor


def load_ranking(rank_file, relevance, n_sample, depth):
    with open(rank_file) as rf:
        lines = iter(rf)
        q_0, p_0, _ = next(lines).strip().split()

        curr_q = q_0
        negatives = [] if p_0 in relevance[q_0] else [p_0]

        while True:
            try:
                q, p, _ = next(lines).strip().split()
                if q != curr_q:
                    negatives = negatives[:depth]
                    random.shuffle(negatives)
                    yield curr_q, relevance[curr_q], negatives[:n_sample]
                    curr_q = q
                    negatives = [] if p in relevance[q] else [p]
                else:
                    if p not in relevance[q]:
                        negatives.append(p)
            except StopIteration:
                negatives = negatives[:depth]
                random.shuffle(negatives)
                yield curr_q, relevance[curr_q], negatives[:n_sample]
                return


random.seed(datetime.now().timestamp())  # seed from wall-clock time; random.seed() rejects raw datetime objects on Python 3.11+
parser = ArgumentParser()
parser.add_argument('--tokenizer_name', required=True)
parser.add_argument('--hn_file', required=True)
parser.add_argument('--qrels', required=True)
parser.add_argument('--queries', required=True)
parser.add_argument('--collection', required=True)
parser.add_argument('--save_to', required=True)

parser.add_argument('--truncate', type=int, default=128)
parser.add_argument('--n_sample', type=int, default=30)
parser.add_argument('--depth', type=int, default=200)
parser.add_argument('--mp_chunk_size', type=int, default=500)
parser.add_argument('--shard_size', type=int, default=45000)

args = parser.parse_args()

qrel = SimpleTrainPreProcessor.read_qrel(args.qrels)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
processor = SimpleTrainPreProcessor(
    query_file=args.queries,
    collection_file=args.collection,
    tokenizer=tokenizer,
    max_length=args.truncate,
)

counter = 0
shard_id = 0
f = None
os.makedirs(args.save_to, exist_ok=True)

pbar = tqdm(load_ranking(args.hn_file, qrel, args.n_sample, args.depth))
with Pool() as p:
    for x in p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size):
        counter += 1
        if f is None:
            f = open(os.path.join(args.save_to, f'split{shard_id:02d}.hn.json'), 'w')
            pbar.set_description(f'split - {shard_id:02d}')
        f.write(x + '\n')

        if counter == args.shard_size:
            f.close()
            f = None
            shard_id += 1
            counter = 0

if f is not None:
    f.close()
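# ----------------------------------------------------------------------
# [Editor's example] What load_ranking yields; a toy sketch with made-up
# ids, not part of the original file. Given a whitespace-separated rank
# file sorted by query id,
#
#   q1 d3 0.91
#   q1 d7 0.88
#   q1 d2 0.85
#
# and relevance = {'q1': ['d7']}, the generator keeps only non-relevant
# candidates within --depth, shuffles them, and yields up to --n_sample:
#
#   ('q1', ['d7'], ['d2', 'd3'])   # (query, positives, sampled negatives)
# ----------------------------------------------------------------------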
--------------------------------------------------------------------------------
/preprocess/combine_nq_triviaqa_negative.py:
--------------------------------------------------------------------------------
import os
import json
import numpy as np
from tqdm import tqdm
from argparse import ArgumentParser

if __name__ == "__main__":

    parser = ArgumentParser()
    parser.add_argument('--data_dir', required=True)
    parser.add_argument('--input_folder_1', required=True)
    parser.add_argument('--input_folder_2', required=True)
    parser.add_argument('--output_folder', required=True)
    args = parser.parse_args()

    input_path_1 = os.path.join(args.data_dir, args.input_folder_1)
    input_path_2 = os.path.join(args.data_dir, args.input_folder_2)
    output_path = os.path.join(args.data_dir, args.output_folder)

    # create output data
    if not os.path.exists(output_path):
        os.mkdir(output_path)

    # load input 1
    file_list_1 = [listx for listx in os.listdir(input_path_1) if "json" in listx]

    qid2negatives = {}
    qid2positives = {}

    for file in tqdm(file_list_1):
        file_1 = os.path.join(input_path_1, file)
        with open(file_1, "r", encoding="utf-8") as fi:
            for line in fi:
                data = json.loads(line)
                qid = data["qid"]
                positives = data["positives"]
                negatives = data["negatives"]

                qid2positives[qid] = positives
                qid2negatives[qid] = negatives

    # load input 2 & write mix
    file_list_2 = [listx for listx in os.listdir(input_path_2) if "json" in listx]

    neg_num_list = []
    diff_num = 0
    for file in tqdm(file_list_2):
        file_2 = os.path.join(input_path_2, file)
        output_file = os.path.join(output_path, file)
        with open(file_2, "r", encoding="utf-8") as fi, \
                open(output_file, "w", encoding="utf-8") as fw:
            for line in fi:
                data = json.loads(line)
                qid = data["qid"]
                query = data["query"]

                ## Positive
                if qid in qid2positives:
                    positives = data["positives"] + qid2positives[qid]
                else:
                    positives = data["positives"]

                ## Negative
                if qid in qid2negatives:
                    negatives = data["negatives"] + qid2negatives[qid]
                    neg_num_list.append(len(negatives))
                else:
                    negatives = data["negatives"]
                    diff_num += 1

                mix_example = {
                    'qid': qid,
                    'query': query,
                    'positives': positives,
                    'negatives': negatives,
                }
                fw.write(json.dumps(mix_example) + '\n')

    print("diff num = ", diff_num)
    print("combine neg num = ", np.mean(neg_num_list))
    print("success!")
--------------------------------------------------------------------------------
/preprocess/preprocessor.py:
--------------------------------------------------------------------------------
import json
import csv
import datasets
from transformers import PreTrainedTokenizer
from dataclasses import dataclass

@dataclass
class SimpleCollectionPreProcessor:
    tokenizer: PreTrainedTokenizer
    separator: str = '\t'
    max_length: int = 128

    def process_line(self, line: str):
        xx = line.strip().split(self.separator)
        text_id, text = xx[0], xx[1:]
        text_encoded = self.tokenizer.encode(
            self.tokenizer.sep_token.join(text),
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )
        encoded = {
            'text_id': text_id,
            'text': text_encoded
        }
        return json.dumps(encoded)


@dataclass
class SimpleTrainPreProcessor:
    query_file: str
    collection_file: str
    tokenizer: PreTrainedTokenizer

    max_length: int = 128
    columns = ['text_id', 'title', 'text']
    title_field = 'title'
    text_field = 'text'

    def __post_init__(self):
        self.queries = self.read_queries(self.query_file)
        self.collection = datasets.load_dataset(
            'csv',
            data_files=self.collection_file,
            cache_dir=self.collection_file+".cache",
            column_names=self.columns,
            delimiter='\t',
        )['train']

    @staticmethod
    def read_queries(queries):
        qmap = {}
        with open(queries) as f:
            for l in f:
                qid, qry = l.strip().split('\t')
                qmap[qid] = qry
        return qmap

    @staticmethod
    def read_qrel(relevance_file):
        qrel = {}
        with open(relevance_file, encoding='utf8') as f:
            tsvreader = csv.reader(f, delimiter="\t")
            for [topicid, _, docid, rel] in tsvreader:
                assert rel == "1"
                if topicid in qrel:
                    qrel[topicid].append(docid)
                else:
                    qrel[topicid] = [docid]
        return qrel

    def get_query(self, q):
        query_encoded = self.tokenizer.encode(
            self.queries[q],
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )
        return query_encoded

    def get_passage(self, p):
        entry = self.collection[int(p)]
        title = entry[self.title_field]
        title = "" if title is None else title
        body = entry[self.text_field]
        content = title + self.tokenizer.sep_token + body

        passage_encoded = self.tokenizer.encode(
            content,
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )

        return passage_encoded

    def process_one(self, train):
        q, pp, nn = train
        train_example = {
            'query': self.get_query(q),
            'positives': [self.get_passage(p) for p in pp],
            'negatives': [self.get_passage(n) for n in nn],
        }

        return json.dumps(train_example)
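# ----------------------------------------------------------------------
# [Editor's example] How build_train_hn.py drives this class; a minimal
# sketch with illustrative paths and ids, not part of the original file:
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
#   proc = SimpleTrainPreProcessor(
#       query_file='train.query.txt',      # qid \t query text
#       collection_file='corpus.tsv',      # docid \t title \t text
#       tokenizer=tok,
#   )
#   line = proc.process_one(('qid_1', ['pos_docid'], ['neg_docid_1', 'neg_docid_2']))
#   # -> one JSON line: {"query": [...], "positives": [[...]], "negatives": [[...], [...]]}
# ----------------------------------------------------------------------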
--------------------------------------------------------------------------------
/ancetele/grad_cache/functional.py:
--------------------------------------------------------------------------------
from functools import wraps
from typing import Callable, Union, Tuple, Any

import torch
from torch import Tensor
from torch import distributed as dist

from .context_managers import RandContext


def cached(func: Callable[..., Tensor]):
    """
    A decorator that turns a model call function into a cache-compatible version.
    :param func: A function that calls the model and returns a representation tensor.
    :return: A function that returns 1) representation leaf tensors for cache construction, 2) a closure function for
    the 2nd forward and the cached backward. Call 2) with 1) as argument after calling backward on the loss Tensor.
    """
    @wraps(func)
    def cache_func(*args, **kwargs):
        rnd_state = RandContext()
        with torch.no_grad():
            reps_no_grad = func(*args, **kwargs)
        if isinstance(reps_no_grad, Tensor):
            reps_no_grad = (reps_no_grad, )
        else:
            assert all(isinstance(v, Tensor) for v in reps_no_grad)
        leaf_reps = tuple(t.detach().requires_grad_() for t in reps_no_grad)

        @wraps(func)
        def forward_backward_func(cache_reps: Union[Tensor, Tuple[Tensor]]):
            with rnd_state:
                reps = func(*args, **kwargs)
            if isinstance(reps, Tensor):
                reps = (reps,)
            if isinstance(cache_reps, Tensor):
                cache_reps = (cache_reps,)
            assert len(reps) == len(cache_reps)

            surrogate = sum(map(lambda u, v: torch.dot(u.flatten(), v.grad.flatten()), reps, cache_reps), 0)
            surrogate.backward()

        return leaf_reps + (forward_backward_func,)
    return cache_func


def _cat_tensor_list(xx):
    if isinstance(xx, list) and len(xx) > 0 and all(isinstance(x, Tensor) for x in xx):
        return torch.cat(xx)
    else:
        return xx


def cat_input_tensor(func: Callable[..., Tensor]):
    """
    A decorator that concatenates positional and keyword arguments of type List[Tensor] into a single Tensor
    on the 0th dimension. This can come in handy dealing with results of representation tensors from multiple
    cached forward passes.
    :param func: A loss function
    :return: Decorated loss function for cached results.
    """
    @wraps(func)
    def cat_f(*args, **kwargs):
        args_cat = [_cat_tensor_list(x) for x in args]
        kwargs_cat = dict((k, _cat_tensor_list(v)) for k, v in kwargs.items())
        return func(*args_cat, **kwargs_cat)
    return cat_f


def _maybe_gather_tensor(t: Any, axis: int):
    if not isinstance(t, Tensor):
        return t
    gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    gathered[dist.get_rank()] = t
    return torch.cat(gathered, dim=axis)


def gather_input_tensor(func: Callable[..., Tensor], axis=0):
    """
    A decorator that all-gathers positional and keyword arguments of type Tensor and concatenates them on axis.
    Intended to be used with distributed contrastive learning loss.
    :param func: A loss function
    :param axis: The axis the gathered tensors are concatenated on.
    :return: Decorated loss function for distributed training.
    """
    @wraps(func)
    def f(*args, **kwargs):
        args_gathered = [_maybe_gather_tensor(x, axis=axis) for x in args]
        kwargs_gathered = dict((k, _maybe_gather_tensor(v, axis=axis)) for k, v in kwargs.items())
        return func(*args_gathered, **kwargs_gathered)
    return f
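# ----------------------------------------------------------------------
# [Editor's example] A minimal sketch of the decorator-style API above,
# following its docstrings; all names are illustrative, not part of this
# repo, and the model is assumed to return a representation tensor:
#
#   @cached
#   def encode(model, batch):
#       return model(**batch)
#
#   @cat_input_tensor
#   def contrastive_loss(q_reps, p_reps):
#       ...                              # loss over the full, concatenated batch
#
#   q_rep_1, q_bwd_1 = encode(model, q_chunk_1)   # no-grad forward, leaf reps
#   q_rep_2, q_bwd_2 = encode(model, q_chunk_2)
#   loss = contrastive_loss([q_rep_1, q_rep_2], [p_rep_1, p_rep_2])
#   loss.backward()                      # populates .grad on the leaf reps
#   q_bwd_1(q_rep_1)                     # re-forward chunk 1, cached backward
#   q_bwd_2(q_rep_2)
# ----------------------------------------------------------------------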
85 | """ 86 | @wraps(func) 87 | def f(*args, **kwargs): 88 | args_gathered = [_maybe_gather_tensor(x, axis=axis) for x in args] 89 | kwargs_gathered = dict((k, _maybe_gather_tensor(v, axis=axis)) for k, v in kwargs.values()) 90 | return func(*args_gathered, **kwargs_gathered) 91 | return f 92 | -------------------------------------------------------------------------------- /shells/infer_msmarco.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/msmarco 2 | export OUTPUT_DIR=/home/sunsi/experiments/msmarco-results 3 | ## ************************************* 4 | ## INPUT/OUTPUT 5 | export train_job_name=ance-tele_msmarco_qry-psg-encoder 6 | export infer_job_name=inference.${train_job_name} 7 | ## ************************************* 8 | ## ENCODE Corpus GPUs 9 | ENCODE_CUDA="0,1,2,3,4" ## ENCODE_CUDA="0" 10 | ENCODE_CUDAs=(${ENCODE_CUDA//,/ }) 11 | ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]} 12 | ## Search Top-k GPUs 13 | SEARCH_CUDA="0,1,2,3,4" 14 | ## ************************************* 15 | TOKENIZER=bert-base-uncased 16 | TOKENIZER_ID=bert 17 | SplitNum=10 18 | ## ************************************* 19 | 20 | ## ********************************************** 21 | ## Infer 22 | ## ********************************************** 23 | ## Create Folder 24 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus 25 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query 26 | 27 | ## Encoding Corpus 28 | for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM)) 29 | do 30 | ## ************************************* 31 | for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++)) 32 | do 33 | ## ************************************* 34 | if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ] 35 | then 36 | break 2 37 | fi 38 | 39 | ## ************************************* 40 | printf -v i "%02g" $[CUDA_INDEX + $tmp] && 41 | CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} && 42 | echo ${OUTPUT_DIR}/${train_job_name} && 43 | echo split-${i} on gpu-${CUDA} && 44 | 45 | CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \ 46 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \ 47 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \ 48 | --fp16 \ 49 | --per_device_eval_batch_size 1024 \ 50 | --dataloader_num_workers 2 \ 51 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \ 52 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \ 53 | ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log && 54 | ## ************************************* 55 | sleep 3 & 56 | [ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait 57 | done 58 | done 59 | 60 | ## ************************************* 61 | ## Encoding Dev query 62 | ## ************************************* 63 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \ 64 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \ 65 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \ 66 | --fp16 \ 67 | --q_max_len 32 \ 68 | --encode_is_qry \ 69 | --per_device_eval_batch_size 1024 \ 70 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/dev.query.json \ 71 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/qry.pt \ 72 | 73 | ## ************************************* 74 | ## Search Dev (GPU/CPU) 75 | ## ************************************* 76 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \ 77 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/qry.pt \ 78 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \ 
--------------------------------------------------------------------------------
/ancetele/encode.py:
--------------------------------------------------------------------------------
import logging
import os
import pickle
import sys
from contextlib import nullcontext

import numpy as np
from tqdm import tqdm

import json
import torch

from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers import (
    HfArgumentParser,
)

import networks
import dataloaders

from arguments import ModelArguments, DataArguments, \
    DenseTrainingArguments as TrainingArguments


logger = logging.getLogger(__name__)


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    if training_args.local_rank > 0 or training_args.n_gpu > 1:
        raise NotImplementedError('Multi-GPU encoding is not supported.')

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )

    num_labels = 1
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )

    ## Model
    model = networks.get_network(
        model_args,
        data_args,
        training_args,
        config=config,
        cache_dir=model_args.cache_dir,
        do_train=False,
    )

    ## Encode dataset and batchify
    encode_dataset, EncodeCollator = dataloaders.get_encode_dataset(
        tokenizer=tokenizer,
        data_args=data_args,
    )

    text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len

    encode_loader = DataLoader(
        encode_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=EncodeCollator(
            tokenizer,
            max_length=text_max_length,
            padding='max_length'
        ),
        shuffle=False,
        drop_last=False,
        num_workers=training_args.dataloader_num_workers,
    )

    model = model.to(training_args.device)
    model.eval()

    ## ***********************************
    ## Dense-Encoder
    ## ***********************************
    encoded = []
    lookup_indices = []
    for (batch_ids, batch) in tqdm(encode_loader):
        lookup_indices.extend(batch_ids)

        with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext():
            with torch.no_grad():
                for k, v in batch.items():
                    batch[k] = v.to(training_args.device)
                if data_args.encode_is_qry:
                    model_output = model(query=batch)
                    encoded.append(model_output.q_reps.cpu().detach().numpy())
                else:
                    model_output = model(passage=batch)
                    encoded.append(model_output.p_reps.cpu().detach().numpy())

    encoded = np.concatenate(encoded)

    with open(data_args.encoded_save_path, 'wb') as f:
        pickle.dump((encoded, lookup_indices), f)


if __name__ == "__main__":
    main()
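# ----------------------------------------------------------------------
# [Editor's note] Each *.pt shard written above is a plain pickle of the
# tuple (encoded, lookup_indices); a minimal sketch for inspecting one
# (the path is illustrative, not part of this repo):
#
#   import pickle
#   with open('corpus/split00.pt', 'rb') as f:
#       reps, ids = pickle.load(f)
#   # reps: np.ndarray of shape (n_texts, hidden_dim)
#   # ids:  the matching text_id list, consumed by the faiss retriever
# ----------------------------------------------------------------------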
--------------------------------------------------------------------------------
/shells/infer_nq.sh:
--------------------------------------------------------------------------------
export DATA_DIR=/home/sunsi/dataset/nq
export OUTPUT_DIR=/home/sunsi/experiments/nq-results
export CORPUS_DATA_DIR=/home/sunsi/dataset/wikipedia-corpus-index
export pyserini_eval_topics=dpr-nq-test
## *************************************
## INPUT/OUTPUT
export qry_encoder_name=ance-tele_nq_qry-encoder
export psg_encoder_name=ance-tele_nq_psg-encoder
export infer_job_name=inference.ance-tele.nq
## *************************************
## ENCODE Corpus GPUs
ENCODE_CUDA="0,1,2,3,4" ## ENCODE_CUDA="0"
ENCODE_CUDAs=(${ENCODE_CUDA//,/ })
ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]}
## Search Top-k GPUs
SEARCH_CUDA="0,1,2,3,4"
## *************************************
## Length Setup
export q_max_len=32
export p_max_len=156
## *************************************
TOKENIZER=bert-base-uncased
TOKENIZER_ID=bert
SplitNum=20 ## Wikipedia is split into 20 sub-files
## *************************************

## **********************************************
## Infer
## **********************************************
## Create Folder
mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus
mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query

## Encoding Corpus
for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM))
do
## *************************************
for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++))
do
## *************************************
if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ]
then
break 2
fi

## *************************************
printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
echo ${OUTPUT_DIR}/${psg_encoder_name} &&
echo split-${i} on gpu-${CUDA} &&

CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
--output_dir ${OUTPUT_DIR}/${infer_job_name} \
--model_name_or_path ${OUTPUT_DIR}/${psg_encoder_name} \
--fp16 \
--per_device_eval_batch_size 1024 \
--dataloader_num_workers 2 \
--p_max_len ${p_max_len} \
--encode_in_path ${CORPUS_DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
--encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
## *************************************
sleep 3 &
[ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
done
done


## *************************************
## Encode [Test-Query]
## *************************************
CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
--output_dir ${OUTPUT_DIR}/${infer_job_name} \
--model_name_or_path ${OUTPUT_DIR}/${qry_encoder_name} \
--fp16 \
--q_max_len ${q_max_len} \
--encode_is_qry \
--per_device_eval_batch_size 1024 \
--encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/test.query.json \
--encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/qry.pt \

## *************************************
## Search [Test]
## *************************************
CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
--query_reps ${OUTPUT_DIR}/${infer_job_name}/query/qry.pt \
--passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
--index_num ${SplitNum} \
--batch_size 1024 \
--use_gpu \
--save_text \
--depth 100 \
--save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv \
# --sub_split_num 5 \
## if CUDA memory is not enough, set this argument.


## *************************************
## Eval [Test]
## *************************************
python ../scripts/convert_result_to_trec.py --input ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv \

python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run \
--topics ${pyserini_eval_topics} \
--index ${CORPUS_DATA_DIR}/index-wikipedia-dpr-20210120-d1b9e6 \
--input ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv.teIn \
--output ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv.json \

python -m pyserini.eval.evaluate_dpr_retrieval \
--retrieval ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv.json --topk 5 20 100 &> ${OUTPUT_DIR}/${infer_job_name}/test-hits.log

## The Test R@5/20/100 results are saved in test-hits.log
--------------------------------------------------------------------------------
/shells/infer_triviaqa.sh:
--------------------------------------------------------------------------------
export DATA_DIR=/home/sunsi/dataset/triviaqa
export OUTPUT_DIR=/home/sunsi/experiments/triviaqa-results
export CORPUS_DATA_DIR=/home/sunsi/dataset/wikipedia-corpus-index
export pyserini_eval_topics=dpr-trivia-test
## *************************************
## INPUT/OUTPUT
export qry_encoder_name=ance-tele_triviaqa_qry-encoder
export psg_encoder_name=ance-tele_triviaqa_psg-encoder
export infer_job_name=inference.ance-tele.triviaqa
## *************************************
## ENCODE Corpus GPUs
ENCODE_CUDA="0,1,2,3,4" ## ENCODE_CUDA="0"
ENCODE_CUDAs=(${ENCODE_CUDA//,/ })
ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]}
## Search Top-k GPUs
SEARCH_CUDA="0,1,2,3,4"
## *************************************
## Length Setup
export q_max_len=64 ## TriviaQA queries are longer, so we use length 64 for inference (training still uses length 32 to save cost)
export p_max_len=156
## *************************************
TOKENIZER=bert-base-uncased
TOKENIZER_ID=bert
SplitNum=20 ## Wikipedia is split into 20 sub-files
## *************************************

## **********************************************
## Infer
## **********************************************
## Create Folder
mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus
mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query

## Encoding Corpus
for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM))
do
## *************************************
for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++))
do
## *************************************
if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ]
then
break 2
fi

## *************************************
printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
echo ${OUTPUT_DIR}/${psg_encoder_name} &&
echo split-${i} on gpu-${CUDA} &&

CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
--output_dir ${OUTPUT_DIR}/${infer_job_name} \
--model_name_or_path ${OUTPUT_DIR}/${psg_encoder_name} \
--fp16 \
--per_device_eval_batch_size 1024 \
--dataloader_num_workers 2 \
--p_max_len ${p_max_len} \
--encode_in_path ${CORPUS_DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
--encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
## *************************************
sleep 3 &
[ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
done
done


## *************************************
## Encode [Test-Query]
## *************************************
CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
--output_dir ${OUTPUT_DIR}/${infer_job_name} \
--model_name_or_path ${OUTPUT_DIR}/${qry_encoder_name} \
--fp16 \
--q_max_len ${q_max_len} \
--encode_is_qry \
--per_device_eval_batch_size 1024 \
--encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/test.query.json \
--encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/qry.pt \

## *************************************
## Search [Test]
## *************************************
CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
--query_reps ${OUTPUT_DIR}/${infer_job_name}/query/qry.pt \
--passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
--index_num ${SplitNum} \
--batch_size 1024 \
--use_gpu \
--save_text \
--depth 100 \
--save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv \
# --sub_split_num 5 \
## if CUDA memory is not enough, set this argument.


## *************************************
## Eval [Test]
## *************************************
python ../scripts/convert_result_to_trec.py --input ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv \

python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run \
--topics ${pyserini_eval_topics} \
--index ${CORPUS_DATA_DIR}/index-wikipedia-dpr-20210120-d1b9e6 \
--input ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv.teIn \
--output ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv.json \

python -m pyserini.eval.evaluate_dpr_retrieval \
--retrieval ${OUTPUT_DIR}/${infer_job_name}/test.rank.tsv.json --topk 5 20 100 &> ${OUTPUT_DIR}/${infer_job_name}/test-hits.log

## The Test R@5/20/100 results are saved in test-hits.log
--------------------------------------------------------------------------------
/ancetele/trainers/dense_trainer.py:
--------------------------------------------------------------------------------
import os
from itertools import repeat
from typing import Dict, List, Tuple, Optional, Any, Union

from transformers.trainer import Trainer

import torch
from torch.utils.data import DataLoader
import torch.distributed as dist


import sys
sys.path.append("..")
from losses import SimpleContrastiveLoss, DistributedContrastiveLoss

import logging
logger = logging.getLogger(__name__)

# from ..grad_cache import GradCache
# _grad_cache_available = True

try:
    from grad_cache import GradCache
    _grad_cache_available = True
except ModuleNotFoundError:
    _grad_cache_available = False


class DenseTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super(DenseTrainer, self).__init__(*args, **kwargs)
        self._dist_loss_scale_factor = dist.get_world_size() if self.args.negatives_x_device else 1

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        self.model.save(output_dir)

    def _prepare_inputs(
        self,
        inputs: Tuple[Dict[str, Union[torch.Tensor, Any]], ...]
    ) -> List[Dict[str, Union[torch.Tensor, Any]]]:
        prepared = []
        for x in inputs:
            if isinstance(x, torch.Tensor):
                prepared.append(x.to(self.args.device))
            else:
                prepared.append(super()._prepare_inputs(x))
        return prepared

    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=True,
            num_workers=self.args.dataloader_num_workers,
        )

    ## ------------ Prev Script ------------ ##
    def compute_loss(self, model, inputs):
        query, passage = inputs
        return model(query=query, passage=passage).loss
    ## ------------ Prev Script ------------ ##

    # ## ------------ SS Modified ------------ ##
    # def compute_loss(self, model, inputs):
    #     query, passage, distil_scores = inputs
    #     return model(query=query, passage=passage, distil_scores=distil_scores).loss
    # ## ------------ SS Modified ------------ ##

    def training_step(self, *args):
        return super(DenseTrainer, self).training_step(*args) / self._dist_loss_scale_factor


def split_dense_inputs(model_input: dict, chunk_size: int):
    assert len(model_input) == 1
    arg_key = list(model_input.keys())[0]
    arg_val = model_input[arg_key]

    keys = list(arg_val.keys())
    chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys]
    chunked_arg_val = [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]

    return [{arg_key: c} for c in chunked_arg_val]


def get_dense_rep(x):
    if x.q_reps is None:
        return x.p_reps
    else:
        return x.q_reps


class GCDenseTrainer(DenseTrainer):
    def __init__(self, *args, **kwargs):
        logger.info('Initializing Gradient Cache Trainer')
        if not _grad_cache_available:
            raise ValueError(
                'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.')
        super(GCDenseTrainer, self).__init__(*args, **kwargs)

        loss_fn_cls = DistributedContrastiveLoss if self.args.negatives_x_device else SimpleContrastiveLoss
        loss_fn = loss_fn_cls(self.model.data_args.train_n_passages)

        self.gc = GradCache(
            models=[self.model, self.model],
            chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size],
            loss_fn=loss_fn,
            split_input_fn=split_dense_inputs,
            get_rep_fn=get_dense_rep,
            fp16=self.args.fp16,
            scaler=self.scaler if self.args.fp16 else None
        )

    def training_step(self, model, inputs) -> torch.Tensor:
        model.train()
        queries, passages = self._prepare_inputs(inputs)
        queries, passages = {'query': queries}, {'passage': passages}

        _distributed = self.args.local_rank > -1
        self.gc.models = [model, model]
        loss = self.gc(queries, passages, no_sync_except_last=_distributed)

        return loss / self._dist_loss_scale_factor
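# ----------------------------------------------------------------------
# [Editor's example] What split_dense_inputs produces; a toy sketch, not
# part of the original file. With a query batch of 8 and chunk_size 4:
#
#   batch = {'query': {'input_ids': torch.zeros(8, 32, dtype=torch.long),
#                      'attention_mask': torch.ones(8, 32, dtype=torch.long)}}
#   chunks = split_dense_inputs(batch, 4)
#   # -> [{'query': {'input_ids': (4, 32), 'attention_mask': (4, 32)}},
#   #     {'query': {'input_ids': (4, 32), 'attention_mask': (4, 32)}}]
# ----------------------------------------------------------------------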
--------------------------------------------------------------------------------
/ancetele/train.py:
--------------------------------------------------------------------------------
import logging
import os
import sys

from transformers import AutoConfig, AutoTokenizer
from transformers import (
    HfArgumentParser,
    set_seed,
)
from transformers.integrations import TrainerCallback, TensorBoardCallback

# ## ------- Modified by SS.
# import sys
# sys.path.append("..")
# sys.path.append(os.getcwd()) ## cloud
# ## ------- Modified by SS.

# from arguments import ModelArguments, DataArguments
# from arguments import DenseTrainingArguments as TrainingArguments
# from ancetele import trainers
# from ancetele import utils
# from ancetele import networks
# from ancetele import dataloaders

from arguments import ModelArguments, DataArguments
from arguments import DenseTrainingArguments as TrainingArguments
import trainers
import networks
import dataloaders


logger = logging.getLogger(__name__)


class MyStopTrainCallback(TrainerCallback):
    "A callback that stops training once args.early_stop_step training steps have finished"

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step == args.early_stop_step:
            logger.info("End training at step: %d", state.global_step)
            control.should_training_stop = True

        return control


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    if (
            os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) \
            already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("MODEL parameters %s", model_args)

    set_seed(training_args.seed)

    num_labels = 1
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )

    ## Model
    model = networks.get_network(
        model_args,
        data_args,
        training_args,
        config=config,
        cache_dir=model_args.cache_dir,
        do_train=True,
    )

    ## Train dataset and batchify
    train_dataset, eval_dataset, QPCollator = dataloaders.get_train_dataset(
        tokenizer=tokenizer,
        data_args=data_args,
    )

    ## early-stop or tensorboard
    callbacks = []
    if training_args.early_stop_step > 0:
        logger.info("Setting early stop step at: %d", training_args.early_stop_step)
        callbacks.append(MyStopTrainCallback)
    if training_args.tensorboard:
        logger.info("Setting Tensorboard ...")
        callbacks.append(TensorBoardCallback())

    ## training func
    trainer = trainers.get_trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=QPCollator(
            tokenizer,
            max_p_len=data_args.p_max_len,
            max_q_len=data_args.q_max_len
        ),
        callbacks=callbacks,
    )

    train_dataset.trainer = trainer

    trainer.train()  # TODO: resume training

    trainer.save_model()
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/ancetele/dataloaders/dense_dataset.py:
--------------------------------------------------------------------------------
import random
from dataclasses import dataclass
from typing import List, Tuple

import torch
import datasets
from torch.utils.data import Dataset
from transformers import (
    PreTrainedTokenizer,
    BatchEncoding,
    DataCollatorWithPadding
)

import sys
sys.path.append("..")
from arguments import DataArguments
from trainers import DenseTrainer
# from .trainer import DenseTrainer

import logging
logger = logging.getLogger(__name__)


class DenseTrainDataset(Dataset):
    def __init__(
        self,
        data_args: DataArguments,
        dataset: datasets.Dataset,
        tokenizer: PreTrainedTokenizer,
        trainer: DenseTrainer = None,
    ):
        self.train_data = dataset
        self.tok = tokenizer
        self.trainer = trainer

        self.data_args = data_args
        self.total_len = len(self.train_data)

    def create_one_example(self, text_encoding: List[int], is_query=False):
        item = self.tok.encode_plus(
            text_encoding,
            truncation='only_first',
            max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
        )
        return item

    def __len__(self):
        return self.total_len

    def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
        group = self.train_data[item]
        epoch = int(self.trainer.state.epoch)

        _hashed_seed = hash(item + self.trainer.args.seed)

        qry = group['query']
        encoded_query = self.create_one_example(qry, is_query=True)

        encoded_passages = []
        group_positives = group['positives']
        group_negatives = group['negatives']

        if self.data_args.positive_passage_no_shuffle:
            pos_psg = group_positives[0]
        else:
            pos_psg = group_positives[(_hashed_seed + epoch) % len(group_positives)]
        encoded_passages.append(self.create_one_example(pos_psg))

        negative_size = self.data_args.train_n_passages - 1
        if len(group_negatives) < negative_size:
            negs = random.choices(group_negatives, k=negative_size)
        elif self.data_args.train_n_passages == 1:
            negs = []
        elif self.data_args.negative_passage_no_shuffle:
            negs = group_negatives[:negative_size]
        else:
            _offset = epoch * negative_size % len(group_negatives)
            negs = [x for x in group_negatives]
            random.Random(_hashed_seed).shuffle(negs)
            negs = negs * 2
            negs = negs[_offset: _offset + negative_size]

        for neg_psg in negs:
            encoded_passages.append(self.create_one_example(neg_psg))

        return encoded_query, encoded_passages


@dataclass
class DenseQPCollator(DataCollatorWithPadding):
    """
    Wrapper that converts a List[Tuple[encode_qry, encode_psg]] into a List[qry] and a List[psg]
    and passes each batch separately to the actual collator.
    Abstracts out data details for the model.
    """
    max_q_len: int = 32
    max_p_len: int = 128

    def __call__(self, features):
        qq = [f[0] for f in features]
        dd = [f[1] for f in features]

        if isinstance(qq[0], list):
            qq = sum(qq, [])
        if isinstance(dd[0], list):
            dd = sum(dd, [])

        q_collated = self.tokenizer.pad(
            qq,
            padding='max_length',
            max_length=self.max_q_len,
            return_tensors="pt",
        )
        d_collated = self.tokenizer.pad(
            dd,
            padding='max_length',
            max_length=self.max_p_len,
            return_tensors="pt",
        )

        return q_collated, d_collated


class DenseEncodeDataset(Dataset):
    input_keys = ['text_id', 'text']

    def __init__(
        self,
        data_args: DataArguments,
        dataset: datasets.Dataset,
        tokenizer: PreTrainedTokenizer,
    ):
        self.encode_data = dataset
        self.tok = tokenizer
        self.max_len = data_args.q_max_len if data_args.encode_is_qry \
            else data_args.p_max_len

    def __len__(self):
        return len(self.encode_data)

    def __getitem__(self, item) -> Tuple[str, BatchEncoding]:
        text_id, text = (self.encode_data[item][f] for f in self.input_keys)
        encoded_text = self.tok.encode_plus(
            text,
            max_length=self.max_len,
            truncation='only_first',
            padding=False,
            return_token_type_ids=False,
        )
        return text_id, encoded_text


@dataclass
class DenseEncodeCollator(DataCollatorWithPadding):
    def __call__(self, features):
        text_ids = [x[0] for x in features]
        text_features = [x[1] for x in features]
        collated_features = super().__call__(text_features)
        return text_ids, collated_features
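# ----------------------------------------------------------------------
# [Editor's note] The negative-sampling branch in DenseTrainDataset above
# rotates through the mined negatives across epochs; a toy sketch, not
# part of the original file. With 4 negatives [a, b, c, d],
# train_n_passages=3 (negative_size=2), and a per-example shuffle of,
# say, [c, a, d, b]:
#
#   epoch 0: _offset = 0 -> negs = [c, a]
#   epoch 1: _offset = 2 -> negs = [d, b]
#   epoch 2: _offset = 0 -> negs = [c, a]   # (2 * 2) % 4 wraps around
#
# Doubling the list (negs * 2) lets the slice wrap past the end.
# ----------------------------------------------------------------------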
99 | """
100 | max_q_len: int = 32
101 | max_p_len: int = 128
102 | 
103 | def __call__(self, features):
104 | qq = [f[0] for f in features]
105 | dd = [f[1] for f in features]
106 | 
107 | if isinstance(qq[0], list):
108 | qq = sum(qq, [])
109 | if isinstance(dd[0], list):
110 | dd = sum(dd, [])
111 | 
112 | q_collated = self.tokenizer.pad(
113 | qq,
114 | padding='max_length',
115 | max_length=self.max_q_len,
116 | return_tensors="pt",
117 | )
118 | d_collated = self.tokenizer.pad(
119 | dd,
120 | padding='max_length',
121 | max_length=self.max_p_len,
122 | return_tensors="pt",
123 | )
124 | 
125 | return q_collated, d_collated
126 | 
127 | 
128 | 
129 | 
130 | class DenseEncodeDataset(Dataset):
131 | input_keys = ['text_id', 'text']
132 | 
133 | def __init__(
134 | self,
135 | data_args: DataArguments,
136 | dataset: datasets.Dataset,
137 | tokenizer: PreTrainedTokenizer,
138 | ):
139 | self.encode_data = dataset
140 | self.tok = tokenizer
141 | self.max_len = data_args.q_max_len if data_args.encode_is_qry \
142 | else data_args.p_max_len
143 | 
144 | def __len__(self):
145 | return len(self.encode_data)
146 | 
147 | def __getitem__(self, item) -> Tuple[str, BatchEncoding]:
148 | text_id, text = (self.encode_data[item][f] for f in self.input_keys)
149 | encoded_text = self.tok.encode_plus(
150 | text,
151 | max_length=self.max_len,
152 | truncation='only_first',
153 | padding=False,
154 | return_token_type_ids=False,
155 | )
156 | return text_id, encoded_text
157 | 
158 | 
159 | @dataclass
160 | class DenseEncodeCollator(DataCollatorWithPadding):
161 | def __call__(self, features):
162 | text_ids = [x[0] for x in features]
163 | text_features = [x[1] for x in features]
164 | collated_features = super().__call__(text_features)
165 | return text_ids, collated_features
166 | 
167 | 
--------------------------------------------------------------------------------
/ancetele/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from typing import Optional, List
4 | from transformers import TrainingArguments
5 | 
6 | 
7 | @dataclass
8 | class ModelArguments:
9 | model_name_or_path: str = field(
10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
11 | )
12 | target_model_path: str = field(
13 | default=None,
14 | metadata={"help": "Path to pretrained reranker target model"}
15 | )
16 | config_name: Optional[str] = field(
17 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
18 | )
19 | tokenizer_name: Optional[str] = field(
20 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
21 | )
22 | cache_dir: Optional[str] = field(
23 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
24 | )
25 | 
26 | # modeling
27 | untie_encoder: bool = field(
28 | default=False,
29 | metadata={"help": "no weight sharing between query and passage encoders"}
30 | )
31 | 
32 | # out projection
33 | add_pooler: bool = field(default=False)
34 | projection_in_dim: int = field(default=768)
35 | projection_out_dim: int = field(default=768)
36 | 
37 | 
38 | @dataclass
39 | class DataArguments:
40 | train_dir: str = field(
41 | default=None, metadata={"help": "Path to train directory"}
42 | )
43 | eval_dir: str = field(
44 | default=None, metadata={"help": "Path to eval directory"}
45 | )
46 | dataset_name: str = field(
47 | default=None, metadata={"help": "huggingface dataset name"}
48 | )
49 | passage_field_separator: str = field(default=' ')
50 | dataset_proc_num: int = field(
51 | default=12, metadata={"help": "number of processes used in dataset preprocessing"}
52 | )
53 | train_n_passages: int = field(default=8)
54 | positive_passage_no_shuffle: bool = field(
55 | default=False, metadata={"help": "always use the first positive passage"})
56 | negative_passage_no_shuffle: bool = field(
57 | default=False, metadata={"help": "always use the first negative passages, in listed order"})
58 | 
59 | encode_in_path: List[str] = field(default=None, metadata={"help": "Path to data to encode"})
60 | encoded_save_path: str = field(default=None, metadata={"help": "where to save the encoded representations"})
61 | encode_is_qry: bool = field(default=False)
62 | encode_num_shard: int = field(default=1)
63 | encode_shard_index: int = field(default=0)
64 | 
65 | q_max_len: int = field(
66 | default=32,
67 | metadata={
68 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
69 | "than this will be truncated, sequences shorter will be padded."
70 | },
71 | )
72 | p_max_len: int = field(
73 | default=128,
74 | metadata={
75 | "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
76 | "than this will be truncated, sequences shorter will be padded."
77 | },
78 | )
79 | train_cache_dir: Optional[str] = field(
80 | default=None, metadata={"help": "Where to store the train data downloaded from huggingface; if None, the data is re-downloaded on every run"}
81 | )
82 | 
83 | eval_cache_dir: Optional[str] = field(
84 | default=None, metadata={"help": "Where to store the eval data downloaded from huggingface; if None, the data is re-downloaded on every run"}
85 | )
86 | 
87 | ## split load
88 | split_load_data: bool = field(default=False)
89 | 
90 | def __post_init__(self):
91 | if self.dataset_name is not None:
92 | info = self.dataset_name.split('/')
93 | # self.dataset_split = info[-1] if len(info) == 3 else 'train'
94 | self.dataset_name = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)
95 | self.dataset_language = 'default'
96 | if ':' in self.dataset_name:
97 | self.dataset_name, self.dataset_language = self.dataset_name.split(':')
98 | else:
99 | self.dataset_name = 'json'
100 | # self.dataset_split = 'train'
101 | self.dataset_language = 'default'
102 | 
103 | if self.train_dir is not None:
104 | ## SS Modified --------------------
105 | self.train_cache_dir = os.path.join(self.train_dir, "cache")
106 | ## SS Modified --------------------
107 | files = os.listdir(self.train_dir)
108 | self.train_path = [
109 | os.path.join(self.train_dir, f)
110 | for f in files
111 | if f.endswith('jsonl') or f.endswith('json')
112 | ]
113 | else:
114 | self.train_path = None
115 | 
116 | if self.eval_dir is not None:
117 | ## SS Modified --------------------
118 | self.eval_cache_dir = os.path.join(self.eval_dir, "cache")
119 | ## SS Modified --------------------
120 | files = os.listdir(self.eval_dir)
121 | self.eval_path = [
122 | os.path.join(self.eval_dir, f)
123 | for f in files
124 | if f.endswith('jsonl') or f.endswith('json')
125 | ]
126 | else:
127 | self.eval_path = None
128 | 
129 | 
130 | @dataclass
131 | class DenseTrainingArguments(TrainingArguments):
132 | warmup_ratio: float = field(default=0.1)
133 | negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
134 | do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
135 | 
136 | grad_cache: bool =
field(default=False, metadata={"help": "Use gradient cache update"}) 137 | gc_q_chunk_size: int = field(default=4) 138 | gc_p_chunk_size: int = field(default=32) 139 | 140 | ## SS added 141 | early_stop_step: int = field(default=-1) 142 | tensorboard: bool = field(default=False) 143 | -------------------------------------------------------------------------------- /ancetele/dataloaders/hf_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets import load_dataset, concatenate_datasets 3 | from transformers import PreTrainedTokenizer 4 | from .dataset_utils import TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor 5 | 6 | import sys 7 | sys.path.append("..") 8 | from arguments import DataArguments 9 | 10 | DEFAULT_PROCESSORS = [TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor] 11 | PROCESSOR_INFO = { 12 | 'json': [None, None, None] 13 | } 14 | 15 | 16 | class HFDataset: 17 | def __init__( 18 | self, 19 | tokenizer: PreTrainedTokenizer, 20 | data_args: DataArguments, 21 | dataset_split: str, 22 | data_files: str, 23 | cache_dir: str 24 | ): 25 | 26 | ## ******************************************************** 27 | if data_args.split_load_data: 28 | dataset_list = [] 29 | for filepath in data_files: 30 | dataset_list.append( 31 | load_dataset( 32 | data_args.dataset_name, 33 | data_args.dataset_language, 34 | data_files=[filepath], 35 | cache_dir=os.path.join(filepath+".cache"), 36 | )[dataset_split] 37 | ) 38 | 39 | self.dataset = concatenate_datasets(dataset_list) 40 | ## ******************************************************** 41 | else: 42 | 43 | if data_files: 44 | data_files = {dataset_split: data_files} 45 | ## {"train/eval": [filepath_1, filepath_2, ...]} 46 | 47 | self.dataset = load_dataset( 48 | data_args.dataset_name, 49 | data_args.dataset_language, 50 | data_files=data_files, 51 | cache_dir=cache_dir 52 | )[dataset_split] 53 | 54 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][0] if data_args.dataset_name in PROCESSOR_INFO \ 55 | else DEFAULT_PROCESSORS[0] ## None 56 | 57 | self.tokenizer = tokenizer 58 | self.q_max_len = data_args.q_max_len 59 | self.p_max_len = data_args.p_max_len 60 | self.proc_num = data_args.dataset_proc_num 61 | self.neg_num = data_args.train_n_passages - 1 62 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator) 63 | 64 | def process(self, shard_num=1, shard_idx=0): 65 | self.dataset = self.dataset.shard(shard_num, shard_idx) 66 | if self.preprocessor is not None: 67 | self.dataset = self.dataset.map( 68 | self.preprocessor(self.tokenizer, self.q_max_len, self.p_max_len, self.separator), 69 | batched=False, 70 | num_proc=self.proc_num, 71 | remove_columns=self.dataset.column_names, 72 | desc="Running tokenizer on train dataset", 73 | ) 74 | return self.dataset 75 | 76 | 77 | class HFQueryDataset: 78 | def __init__( 79 | self, 80 | tokenizer: PreTrainedTokenizer, 81 | data_args: DataArguments, 82 | cache_dir: str, 83 | dataset_split: str, 84 | ): 85 | data_files = data_args.encode_in_path 86 | if data_files: 87 | data_files = {dataset_split: data_files} 88 | self.dataset = load_dataset(data_args.dataset_name, 89 | data_args.dataset_language, 90 | data_files=data_files, cache_dir=cache_dir)[dataset_split] 91 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][1] if data_args.dataset_name in PROCESSOR_INFO \ 92 | else DEFAULT_PROCESSORS[1] 93 | self.tokenizer = tokenizer 94 | self.q_max_len = 
data_args.q_max_len 95 | self.proc_num = data_args.dataset_proc_num 96 | 97 | def process(self, shard_num=1, shard_idx=0): 98 | self.dataset = self.dataset.shard(shard_num, shard_idx) 99 | if self.preprocessor is not None: 100 | self.dataset = self.dataset.map( 101 | self.preprocessor(self.tokenizer, self.q_max_len), 102 | batched=False, 103 | num_proc=self.proc_num, 104 | remove_columns=self.dataset.column_names, 105 | desc="Running tokenization", 106 | ) 107 | return self.dataset 108 | 109 | 110 | class HFCorpusDataset: 111 | def __init__( 112 | self, 113 | tokenizer: PreTrainedTokenizer, 114 | data_args: DataArguments, 115 | cache_dir: str, 116 | dataset_split: str, 117 | ): 118 | data_files = data_args.encode_in_path 119 | if data_files: 120 | data_files = {dataset_split: data_files} 121 | self.dataset = load_dataset(data_args.dataset_name, 122 | data_args.dataset_language, 123 | data_files=data_files, cache_dir=cache_dir)[dataset_split] 124 | script_prefix = data_args.dataset_name 125 | if script_prefix.endswith('-corpus'): 126 | script_prefix = script_prefix[:-7] 127 | self.preprocessor = PROCESSOR_INFO[script_prefix][2] \ 128 | if script_prefix in PROCESSOR_INFO else DEFAULT_PROCESSORS[2] 129 | self.tokenizer = tokenizer 130 | self.p_max_len = data_args.p_max_len 131 | self.proc_num = data_args.dataset_proc_num 132 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator) 133 | 134 | def process(self, shard_num=1, shard_idx=0): 135 | self.dataset = self.dataset.shard(shard_num, shard_idx) 136 | if self.preprocessor is not None: 137 | self.dataset = self.dataset.map( 138 | self.preprocessor(self.tokenizer, self.p_max_len, self.separator), 139 | batched=False, 140 | num_proc=self.proc_num, 141 | remove_columns=self.dataset.column_names, 142 | desc="Running tokenization", 143 | ) 144 | return self.dataset 145 | -------------------------------------------------------------------------------- /shells/epi-1-mine-msmarco.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/msmarco 2 | export OUTPUT_DIR=/home/sunsi/experiments/msmarco-results 3 | ## ************************************* 4 | ## INPUT/OUTPUT 5 | export train_job_name=co-condenser-marco 6 | export infer_job_name=inference.${train_job_name} 7 | ## OUTPUT 8 | export new_ann_hn_file_name=ann-neg.${train_job_name} 9 | export new_la_hn_file_name=la-neg.${train_job_name} 10 | export new_tele_file_name_wo_mom=epi-1-tele-neg.msmarco 11 | ## ************************************* 12 | ## ************************************* 13 | TOKENIZER=bert-base-uncased 14 | TOKENIZER_ID=bert 15 | SplitNum=10 16 | ## ************************************* 17 | ## ENCODE Corpus GPUs 18 | ENCODE_CUDA="0,1,2,3,4" 19 | ENCODE_CUDAs=(${ENCODE_CUDA//,/ }) 20 | ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]} 21 | ## Search Top-k GPUs 22 | SEARCH_CUDA="0,1,2,3,4" 23 | 24 | ## ********************************************** 25 | ## Infer 26 | ## ********************************************** 27 | ## Create Folder 28 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus 29 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query 30 | 31 | ## Encoding Corpus 32 | for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM)) 33 | do 34 | ## ************************************* 35 | for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++)) 36 | do 37 | ## ************************************* 38 | if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ] 39 | then 40 | break 2 41 | 
fi
42 | 
43 | ## *************************************
44 | printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
45 | CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
46 | echo ${OUTPUT_DIR}/${train_job_name} &&
47 | echo split-${i} on gpu-${CUDA} &&
48 | 
49 | CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
50 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
51 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
52 | --fp16 \
53 | --per_device_eval_batch_size 1024 \
54 | --dataloader_num_workers 2 \
55 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
56 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
57 | ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
58 | ## *************************************
59 | sleep 3 &
60 | [ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
61 | done
62 | done
63 | 
64 | ## *************************************
65 | ## Encoding Train-Queries
66 | ## *************************************
67 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
68 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
69 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
70 | --fp16 \
71 | --q_max_len 32 \
72 | --encode_is_qry \
73 | --per_device_eval_batch_size 2048 \
74 | --dataloader_num_workers 2 \
75 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \
76 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.pt \
77 | 
78 | ## *************************************
79 | ## Encoding Train-Positives
80 | ## *************************************
81 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
82 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
83 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
84 | --fp16 \
85 | --per_device_eval_batch_size 1024 \
86 | --dataloader_num_workers 2 \
87 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.positives.json \
88 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
89 | 
90 | 
91 | ## *************************************
92 | ## Search Train (GPU)
93 | ## *************************************
94 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
95 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.pt \
96 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
97 | --index_num ${SplitNum} \
98 | --use_gpu \
99 | --batch_size 1024 \
100 | --save_text \
101 | --depth 200 \
102 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
103 | # --sub_split_num 5 \
104 | # ## sub_split_num: if CUDA memory is not enough, set this argument.
105 | 
106 | ## *************************************
107 | ## Search Train-Positives (GPU)
108 | ## *************************************
109 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
110 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
111 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
112 | --index_num ${SplitNum} \
113 | --use_gpu \
114 | --batch_size 1024 \
115 | --save_text \
116 | --depth 200 \
117 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
118 | # --sub_split_num 5 \
119 | # ## sub_split_num: if CUDA memory is not enough, set this argument.
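## *************************************
## Optional sanity check (illustrative sketch, not part of the original
## pipeline). With --save_text the ranking files hold whitespace-separated
## "qid docid score" lines, so two quick commands confirm both searches
## finished and how many distinct queries they cover:
## *************************************
# head -n 3 ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv
# awk '{print $1}' ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv | sort -u | wc -l    ## distinct train queries
# awk '{print $1}' ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv | sort -u | wc -l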
120 | 121 | ## ************************************* 122 | ## Mine Train Negative 123 | ## ************************************* 124 | python ../preprocess/build_train_hn.py \ 125 | --tokenizer_name ${TOKENIZER} \ 126 | --hn_file ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \ 127 | --qrels ${DATA_DIR}/qrels.train.tsv \ 128 | --queries ${DATA_DIR}/train.query.txt \ 129 | --collection ${DATA_DIR}/corpus.tsv \ 130 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_ann_hn_file_name} \ 131 | --depth 200 \ 132 | --n_sample 30 \ 133 | 134 | ## ************************************* 135 | ## Mine Train-Positive Negative 136 | ## ************************************* 137 | python ../preprocess/build_train_hn.py \ 138 | --tokenizer_name ${TOKENIZER} \ 139 | --hn_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \ 140 | --qrels ${DATA_DIR}/qrels.train.tsv \ 141 | --queries ${DATA_DIR}/train.query.txt \ 142 | --collection ${DATA_DIR}/corpus.tsv \ 143 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_la_hn_file_name} \ 144 | --depth 200 \ 145 | --n_sample 30 \ 146 | 147 | # # ************************************* 148 | # # Combine ANN + LA Negatives 149 | # # ************************************* 150 | python ../preprocess/combine_marco_negative.py \ 151 | --data_dir ${DATA_DIR}/${TOKENIZER_ID} \ 152 | --input_folder_1 ${new_la_hn_file_name} \ 153 | --input_folder_2 ${new_ann_hn_file_name} \ 154 | --output_folder ${new_tele_file_name_wo_mom} \ -------------------------------------------------------------------------------- /shells/epi-2-mine-msmarco.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/msmarco 2 | export OUTPUT_DIR=/home/sunsi/experiments/msmarco-results 3 | ## ************************************* 4 | ## INPUT/OUTPUT 5 | export train_job_name=epi-1.ance-tele.msmarco.checkp-20000 6 | export infer_job_name=inference.${train_job_name} 7 | ## OUTPUT 8 | export new_ann_hn_file_name=ann-neg.${train_job_name} 9 | export new_la_hn_file_name=la-neg.${train_job_name} 10 | export new_tele_file_name_wo_mom=ann-la-neg.${train_job_name} 11 | 12 | export mom_tele_file_name=epi-1-tele-neg.msmarco 13 | export new_tele_file_name=epi-2-tele-neg.msmarco 14 | ## ************************************* 15 | ## ************************************* 16 | TOKENIZER=bert-base-uncased 17 | TOKENIZER_ID=bert 18 | SplitNum=10 19 | ## ************************************* 20 | ## ENCODE Corpus GPUs 21 | ENCODE_CUDA="0,1,2,3,4" 22 | ENCODE_CUDAs=(${ENCODE_CUDA//,/ }) 23 | ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]} 24 | ## Search Top-k GPUs 25 | SEARCH_CUDA="0,1,2,3,4" 26 | 27 | ## ********************************************** 28 | ## Infer 29 | ## ********************************************** 30 | ## Create Folder 31 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus 32 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query 33 | 34 | ## Encoding Corpus 35 | for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM)) 36 | do 37 | ## ************************************* 38 | for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++)) 39 | do 40 | ## ************************************* 41 | if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ] 42 | then 43 | break 2 44 | fi 45 | 46 | ## ************************************* 47 | printf -v i "%02g" $[CUDA_INDEX + $tmp] && 48 | CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} && 49 | echo ${OUTPUT_DIR}/${train_job_name} && 50 | echo split-${i} on gpu-${CUDA} && 51 | 52 | CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \ 53 | 
--output_dir ${OUTPUT_DIR}/${infer_job_name} \
54 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
55 | --fp16 \
56 | --per_device_eval_batch_size 1024 \
57 | --dataloader_num_workers 2 \
58 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
59 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
60 | ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
61 | ## *************************************
62 | sleep 3 &
63 | [ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
64 | done
65 | done
66 | 
67 | ## *************************************
68 | ## Encoding Train-Queries
69 | ## *************************************
70 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
71 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
72 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
73 | --fp16 \
74 | --q_max_len 32 \
75 | --encode_is_qry \
76 | --per_device_eval_batch_size 2048 \
77 | --dataloader_num_workers 2 \
78 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \
79 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.pt \
80 | 
81 | ## *************************************
82 | ## Encoding Train-Positives
83 | ## *************************************
84 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
85 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
86 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
87 | --fp16 \
88 | --per_device_eval_batch_size 1024 \
89 | --dataloader_num_workers 2 \
90 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.positives.json \
91 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
92 | 
93 | 
94 | ## *************************************
95 | ## Search Train (GPU)
96 | ## *************************************
97 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
98 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.pt \
99 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
100 | --index_num ${SplitNum} \
101 | --use_gpu \
102 | --batch_size 1024 \
103 | --save_text \
104 | --depth 200 \
105 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
106 | --sub_split_num 5 \
107 | ## sub_split_num: if CUDA memory is not enough, set this argument.
108 | 
109 | ## *************************************
110 | ## Search Train-Positives (GPU)
111 | ## *************************************
112 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
113 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
114 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
115 | --index_num ${SplitNum} \
116 | --use_gpu \
117 | --batch_size 1024 \
118 | --save_text \
119 | --depth 200 \
120 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
121 | --sub_split_num 5 \
122 | ## sub_split_num: if CUDA memory is not enough, set this argument.
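## *************************************
## Illustrative guard (not in the original script): the combine steps at
## the end of this file fold the episode-1 "momentum" TELE negatives into
## the new ones, so ${mom_tele_file_name} must already exist under
## ${DATA_DIR}/${TOKENIZER_ID}. Failing fast here avoids mining for hours
## and then crashing at the final merge:
## *************************************
# if [ ! -d "${DATA_DIR}/${TOKENIZER_ID}/${mom_tele_file_name}" ]; then
#     echo "missing momentum negatives: ${mom_tele_file_name}" >&2
#     exit 1
# fi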
123 | 124 | ## ************************************* 125 | ## Mine Train Negative 126 | ## ************************************* 127 | python ../preprocess/build_train_hn.py \ 128 | --tokenizer_name ${TOKENIZER} \ 129 | --hn_file ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \ 130 | --qrels ${DATA_DIR}/qrels.train.tsv \ 131 | --queries ${DATA_DIR}/train.query.txt \ 132 | --collection ${DATA_DIR}/corpus.tsv \ 133 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_ann_hn_file_name} \ 134 | --depth 200 \ 135 | --n_sample 30 \ 136 | 137 | ## ************************************* 138 | ## Mine Train-Positive Negative 139 | ## ************************************* 140 | python ../preprocess/build_train_hn.py \ 141 | --tokenizer_name ${TOKENIZER} \ 142 | --hn_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \ 143 | --qrels ${DATA_DIR}/qrels.train.tsv \ 144 | --queries ${DATA_DIR}/train.query.txt \ 145 | --collection ${DATA_DIR}/corpus.tsv \ 146 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_la_hn_file_name} \ 147 | --depth 200 \ 148 | --n_sample 30 \ 149 | 150 | # # ************************************* 151 | # # Combine (ANN + LA) Negatives 152 | # # ************************************* 153 | python ../preprocess/combine_marco_negative.py \ 154 | --data_dir ${DATA_DIR}/${TOKENIZER_ID} \ 155 | --input_folder_1 ${new_la_hn_file_name} \ 156 | --input_folder_2 ${new_ann_hn_file_name} \ 157 | --output_folder ${new_tele_file_name_wo_mom} \ 158 | 159 | # # ************************************* 160 | # # Combine (ANN + LA + Mom) Negatives 161 | # # ************************************* 162 | python ../preprocess/combine_marco_negative.py \ 163 | --data_dir ${DATA_DIR}/${TOKENIZER_ID} \ 164 | --input_folder_1 ${new_tele_file_name_wo_mom} \ 165 | --input_folder_2 ${mom_tele_file_name} \ 166 | --output_folder ${new_tele_file_name} \ 167 | -------------------------------------------------------------------------------- /shells/epi-3-mine-msmarco.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/home/sunsi/dataset/msmarco 2 | export OUTPUT_DIR=/home/sunsi/experiments/msmarco-results 3 | ## ************************************* 4 | ## INPUT/OUTPUT 5 | export train_job_name=epi-2.ance-tele.msmarco.checkp-20000 6 | export infer_job_name=inference.${train_job_name} 7 | ## OUTPUT 8 | export new_ann_hn_file_name=ann-neg.${train_job_name} 9 | export new_la_hn_file_name=la-neg.${train_job_name} 10 | export new_tele_file_name_wo_mom=ann-la-neg.${train_job_name} 11 | 12 | export mom_tele_file_name=epi-2-tele-neg.msmarco 13 | export new_tele_file_name=epi-3-tele-neg.msmarco 14 | ## ************************************* 15 | ## ************************************* 16 | TOKENIZER=bert-base-uncased 17 | TOKENIZER_ID=bert 18 | SplitNum=10 19 | ## ************************************* 20 | ## ENCODE Corpus GPUs 21 | ENCODE_CUDA="0,1,2,3,4" 22 | ENCODE_CUDAs=(${ENCODE_CUDA//,/ }) 23 | ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]} 24 | ## Search Top-k GPUs 25 | SEARCH_CUDA="0,1,2,3,4" 26 | 27 | ## ********************************************** 28 | ## Infer 29 | ## ********************************************** 30 | ## Create Folder 31 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus 32 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query 33 | 34 | ## Encoding Corpus 35 | for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM)) 36 | do 37 | ## ************************************* 38 | for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++)) 39 | do 40 | ## 
*************************************
41 | if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ]
42 | then
43 | break 2
44 | fi
45 | 
46 | ## *************************************
47 | printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
48 | CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
49 | echo ${OUTPUT_DIR}/${train_job_name} &&
50 | echo split-${i} on gpu-${CUDA} &&
51 | 
52 | CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
53 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
54 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
55 | --fp16 \
56 | --per_device_eval_batch_size 1024 \
57 | --dataloader_num_workers 2 \
58 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
59 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
60 | ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
61 | ## *************************************
62 | sleep 3 &
63 | [ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
64 | done
65 | done
66 | 
67 | ## *************************************
68 | ## Encoding Train-Queries
69 | ## *************************************
70 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
71 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
72 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
73 | --fp16 \
74 | --q_max_len 32 \
75 | --encode_is_qry \
76 | --per_device_eval_batch_size 2048 \
77 | --dataloader_num_workers 2 \
78 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \
79 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.pt \
80 | 
81 | ## *************************************
82 | ## Encoding Train-Positives
83 | ## *************************************
84 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
85 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
86 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
87 | --fp16 \
88 | --per_device_eval_batch_size 1024 \
89 | --dataloader_num_workers 2 \
90 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.positives.json \
91 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
92 | 
93 | 
94 | ## *************************************
95 | ## Search Train (GPU)
96 | ## *************************************
97 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
98 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.pt \
99 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
100 | --index_num ${SplitNum} \
101 | --use_gpu \
102 | --batch_size 1024 \
103 | --save_text \
104 | --depth 200 \
105 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
106 | --sub_split_num 5 \
107 | ## sub_split_num: if CUDA memory is not enough, set this argument.
108 | 
109 | ## *************************************
110 | ## Search Train-Positives (GPU)
111 | ## *************************************
112 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
113 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
114 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
115 | --index_num ${SplitNum} \
116 | --use_gpu \
117 | --batch_size 1024 \
118 | --save_text \
119 | --depth 200 \
120 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
121 | --sub_split_num 5 \
122 | ## sub_split_num: if CUDA memory is not enough, set this argument.
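## *************************************
## Illustrative check (not part of the original script): the corpus
## encoders run as background batches and log per split, so a failed
## shard is easy to miss. Confirm all ${SplitNum} embedding files exist
## and grep the logs before trusting the rankings:
## *************************************
# ls ${OUTPUT_DIR}/${infer_job_name}/corpus/split*.pt | wc -l    ## expect ${SplitNum}
# grep -il "error" ${OUTPUT_DIR}/${infer_job_name}/corpus/*.log   ## any split that crashed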
123 | 
124 | ## *************************************
125 | ## Mine Train Negative
126 | ## *************************************
127 | python ../preprocess/build_train_hn.py \
128 | --tokenizer_name ${TOKENIZER} \
129 | --hn_file ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
130 | --qrels ${DATA_DIR}/qrels.train.tsv \
131 | --queries ${DATA_DIR}/train.query.txt \
132 | --collection ${DATA_DIR}/corpus.tsv \
133 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_ann_hn_file_name} \
134 | --depth 200 \
135 | --n_sample 30 \
136 | 
137 | ## *************************************
138 | ## Mine Train-Positive Negative
139 | ## *************************************
140 | python ../preprocess/build_train_hn.py \
141 | --tokenizer_name ${TOKENIZER} \
142 | --hn_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
143 | --qrels ${DATA_DIR}/qrels.train.tsv \
144 | --queries ${DATA_DIR}/train.query.txt \
145 | --collection ${DATA_DIR}/corpus.tsv \
146 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_la_hn_file_name} \
147 | --depth 200 \
148 | --n_sample 30 \
149 | 
150 | # # *************************************
151 | # # Combine (ANN + LA) Negatives
152 | # # *************************************
153 | python ../preprocess/combine_marco_negative.py \
154 | --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
155 | --input_folder_1 ${new_la_hn_file_name} \
156 | --input_folder_2 ${new_ann_hn_file_name} \
157 | --output_folder ${new_tele_file_name_wo_mom} \
158 | 
159 | # # *************************************
160 | # # Combine (ANN + LA + Mom) Negatives
161 | # # *************************************
162 | python ../preprocess/combine_marco_negative.py \
163 | --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
164 | --input_folder_1 ${new_tele_file_name_wo_mom} \
165 | --input_folder_2 ${mom_tele_file_name} \
166 | --output_folder ${new_tele_file_name} \
167 | 
--------------------------------------------------------------------------------
/shells/epi-1-mine-nq.sh:
--------------------------------------------------------------------------------
1 | export DATA_DIR=/home/sunsi/dataset/nq
2 | export OUTPUT_DIR=/home/sunsi/experiments/nq-results
3 | export CORPUS_DATA_DIR=/home/sunsi/dataset/wikipedia-corpus-index
4 | ## *************************************
5 | ## INPUT/OUTPUT
6 | export train_job_name=co-condenser-wiki
7 | export infer_job_name=inference.${train_job_name}
8 | ## OUTPUT
9 | export new_ann_hn_file_name=ann-neg.${train_job_name}
10 | export new_la_hn_file_name=la-neg.${train_job_name}
11 | export new_tele_file_name_wo_mom=epi-1-tele-neg.nq
12 | ## *************************************
13 | 
14 | ## *************************************
15 | ## ENCODE Corpus GPUs
16 | ENCODE_CUDA="0,1,2,3,4" ## ENCODE_CUDA="0"
17 | ENCODE_CUDAs=(${ENCODE_CUDA//,/ })
18 | ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]}
19 | ## Search Top-k GPUs
20 | SEARCH_CUDA="0,1,2,3,4"
21 | ## *************************************
22 | ## Length SetUp
23 | export q_max_len=32
24 | export p_max_len=156
25 | ## *************************************
26 | TOKENIZER=bert-base-uncased
27 | TOKENIZER_ID=bert
28 | SplitNum=20 ## Wikipedia is split into 20 sub-files
29 | ## *************************************
30 | 
31 | ## **********************************************
32 | ## Infer
33 | ## **********************************************
34 | ## Create Folder
35 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus
36 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query
37 | 
38 | ## Encoding Corpus
39 | for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM))
40 | do
41 | ## *************************************
42 | for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++))
43 | do
44 | ## *************************************
45 | if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ]
46 | then
47 | break 2
48 | fi
49 | 
50 | ## *************************************
51 | printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
52 | CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
53 | echo ${OUTPUT_DIR}/${train_job_name} &&
54 | echo split-${i} on gpu-${CUDA} &&
55 | 
56 | CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
57 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
58 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
59 | --fp16 \
60 | --per_device_eval_batch_size 1024 \
61 | --dataloader_num_workers 2 \
62 | --p_max_len ${p_max_len} \
63 | --encode_in_path ${CORPUS_DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
64 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
65 | ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
66 | ## *************************************
67 | sleep 3 &
68 | [ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
69 | done
70 | done
71 | 
72 | 
73 | ## *************************************
74 | ## Encode [Train Query]
75 | ## *************************************
76 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
77 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
78 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
79 | --fp16 \
80 | --q_max_len ${q_max_len} \
81 | --encode_is_qry \
82 | --per_device_eval_batch_size 1024 \
83 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \
84 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
85 | 
86 | 
87 | ## *************************************
88 | ## Search [Train]
89 | ## *************************************
90 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
91 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
92 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
93 | --index_num ${SplitNum} \
94 | --batch_size 1024 \
95 | --use_gpu \
96 | --save_text \
97 | --depth 200 \
98 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
99 | --sub_split_num 5 \
100 | ## if CUDA memory is not enough, set this argument.
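## ***************************************************
## Note (an assumption inferred from the script name and flags, not
## confirmed by the source): unlike MS MARCO, the NQ pipeline has no
## pre-mined positives file. build_train_em_hn.py below appears to filter
## the retrieved passages by exact match against the answers in
## nq-train-qrels.jsonl, saves hard negatives, and writes
## train.positives.json via --gen_pos_file; the next steps encode and
## search that file as if the positives themselves were queries.
## ***************************************************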
101 | 
102 | 
103 | # # ***************************************************
104 | # # Filter [Train] & Generate ANN Negatives & Generate [Train-Positive]
105 | # # ***************************************************
106 | python ../preprocess/build_train_em_hn.py \
107 | --tokenizer_name ${TOKENIZER} \
108 | --input_file ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
109 | --queries ${DATA_DIR}/nq-train-qrels.jsonl \
110 | --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
111 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_ann_hn_file_name} \
112 | --n_sample 80 \
113 | --depth 200 \
114 | --gen_pos_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
115 | --mark hn \
116 | 
117 | 
118 | # # ***************************************************
119 | # # Encode [Train-Positive]
120 | # # ***************************************************
121 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
122 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
123 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
124 | --fp16 \
125 | --p_max_len ${p_max_len} \
126 | --per_device_eval_batch_size 1024 \
127 | --encode_in_path ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
128 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
129 | 
130 | 
131 | ## ***************************************************
132 | ## Search [Train-Positive]
133 | ## ***************************************************
134 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
135 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
136 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
137 | --index_num ${SplitNum} \
138 | --batch_size 1024 \
139 | --use_gpu \
140 | --save_text \
141 | --depth 200 \
142 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
143 | --sub_split_num 5 \
144 | ## if CUDA memory is not enough, set this argument.
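## ***************************************************
## Note: each episode mines two negative pools -- ANN negatives from the
## query->corpus ranking, and "LA" negatives (presumably look-ahead) from
## the positive-passage->corpus ranking just produced. The combine step
## at the end of this script merges the two folders into the episode's
## TELE negatives.
## ***************************************************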
145 | 
146 | # ***************************************************
147 | # Filter [Train-Positive] & Generate LA Negatives
148 | # ***************************************************
149 | python ../preprocess/build_train_em_hn.py \
150 | --tokenizer_name ${TOKENIZER} \
151 | --input_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
152 | --queries ${DATA_DIR}/nq-train-qrels.jsonl \
153 | --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
154 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_la_hn_file_name} \
155 | --n_sample 80 \
156 | --depth 200 \
157 | --mark la.hn \
158 | 
159 | 
160 | # # *************************************
161 | # # Combine ANN + LA Negatives
162 | # # *************************************
163 | python ../preprocess/combine_nq_triviaqa_negative.py \
164 | --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
165 | --input_folder_1 ${new_la_hn_file_name} \
166 | --input_folder_2 ${new_ann_hn_file_name} \
167 | --output_folder ${new_tele_file_name_wo_mom} \
--------------------------------------------------------------------------------
/shells/epi-1-mine-triviaqa.sh:
--------------------------------------------------------------------------------
1 | export DATA_DIR=/home/sunsi/dataset/triviaqa
2 | export OUTPUT_DIR=/home/sunsi/experiments/triviaqa-results
3 | export CORPUS_DATA_DIR=/home/sunsi/dataset/wikipedia-corpus-index
4 | ## *************************************
5 | ## INPUT/OUTPUT
6 | export train_job_name=co-condenser-wiki
7 | export infer_job_name=inference.${train_job_name}
8 | ## OUTPUT
9 | export new_ann_hn_file_name=ann-neg.${train_job_name}
10 | export new_la_hn_file_name=la-neg.${train_job_name}
11 | export new_tele_file_name_wo_mom=epi-1-tele-neg.triviaqa
12 | ## *************************************
13 | 
14 | ## *************************************
15 | ## ENCODE Corpus GPUs
16 | ENCODE_CUDA="0,1,2,3,4" ## ENCODE_CUDA="0"
17 | ENCODE_CUDAs=(${ENCODE_CUDA//,/ })
18 | ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]}
19 | ## Search Top-k GPUs
20 | SEARCH_CUDA="0,1,2,3,4"
21 | ## *************************************
22 | ## Length SetUp
23 | export q_max_len=32
24 | export p_max_len=156
25 | ## *************************************
26 | TOKENIZER=bert-base-uncased
27 | TOKENIZER_ID=bert
28 | SplitNum=20 ## Wikipedia is split into 20 sub-files
29 | ## *************************************
30 | 
31 | ## **********************************************
32 | ## Infer
33 | ## **********************************************
34 | ## Create Folder
35 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus
36 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query
37 | 
38 | ## Encoding Corpus
39 | for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM))
40 | do
41 | ## *************************************
42 | for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++))
43 | do
44 | ## *************************************
45 | if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ]
46 | then
47 | break 2
48 | fi
49 | 
50 | ## *************************************
51 | printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
52 | CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
53 | echo ${OUTPUT_DIR}/${train_job_name} &&
54 | echo split-${i} on gpu-${CUDA} &&
55 | 
56 | CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
57 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
58 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
59 | --fp16 \
60 | --per_device_eval_batch_size 1024 \
61 | --dataloader_num_workers 2 \
62 | --p_max_len ${p_max_len} \
63 | --encode_in_path ${CORPUS_DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
64 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
65 | ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
66 | ## *************************************
67 | sleep 3 &
68 | [ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
69 | done
70 | done
71 | 
72 | 
73 | ## *************************************
74 | ## Encode [Train Query]
75 | ## *************************************
76 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
77 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
78 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
79 | --fp16 \
80 | --q_max_len ${q_max_len} \
81 | --encode_is_qry \
82 | --per_device_eval_batch_size 1024 \
83 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \
84 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
85 | 
86 | 
87 | ## *************************************
88 | ## Search [Train]
89 | ## *************************************
90 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
91 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
92 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
93 | --index_num ${SplitNum} \
94 | --batch_size 1024 \
95 | --use_gpu \
96 | --save_text \
97 | --depth 200 \
98 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
99 | --sub_split_num 5 \
100 | ## if CUDA memory is not enough, set this argument.
101 | 
102 | 
103 | # # ***************************************************
104 | # # Filter [Train] & Generate ANN Negatives & Generate [Train-Positive]
105 | # # ***************************************************
106 | python ../preprocess/build_train_em_hn.py \
107 | --tokenizer_name ${TOKENIZER} \
108 | --input_file ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
109 | --queries ${DATA_DIR}/triviaqa-train-qrels.jsonl \
110 | --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
111 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_ann_hn_file_name} \
112 | --n_sample 80 \
113 | --depth 200 \
114 | --gen_pos_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
115 | --mark hn \
116 | 
117 | 
118 | # # ***************************************************
119 | # # Encode [Train-Positive]
120 | # # ***************************************************
121 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
122 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
123 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name} \
124 | --fp16 \
125 | --p_max_len ${p_max_len} \
126 | --per_device_eval_batch_size 1024 \
127 | --encode_in_path ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
128 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
129 | 
130 | 
131 | ## ***************************************************
132 | ## Search [Train-Positive]
133 | ## ***************************************************
134 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
135 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
136 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
137 | --index_num ${SplitNum} \
138 | --batch_size 1024 \
139 | --use_gpu \
140 | --save_text \
141 | --depth 200 \
142 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
143 | --sub_split_num 5 \
144 | ## if CUDA memory is not enough, set this argument.
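## ***************************************************
## Note: this script mirrors epi-1-mine-nq.sh step for step; only the
## dataset paths, the qrels file, and the output names differ. An
## illustrative way to confirm that before editing either copy:
## ***************************************************
# diff epi-1-mine-nq.sh epi-1-mine-triviaqa.sh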
145 | 
146 | # ***************************************************
147 | # Filter [Train-Positive] & Generate LA Negatives
148 | # ***************************************************
149 | python ../preprocess/build_train_em_hn.py \
150 | --tokenizer_name ${TOKENIZER} \
151 | --input_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
152 | --queries ${DATA_DIR}/triviaqa-train-qrels.jsonl \
153 | --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
154 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_la_hn_file_name} \
155 | --n_sample 80 \
156 | --depth 200 \
157 | --mark la.hn \
158 | 
159 | 
160 | # # *************************************
161 | # # Combine ANN + LA Negatives
162 | # # *************************************
163 | python ../preprocess/combine_nq_triviaqa_negative.py \
164 | --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
165 | --input_folder_1 ${new_la_hn_file_name} \
166 | --input_folder_2 ${new_ann_hn_file_name} \
167 | --output_folder ${new_tele_file_name_wo_mom} \
--------------------------------------------------------------------------------
/shells/epi-2-mine-nq.sh:
--------------------------------------------------------------------------------
1 | export DATA_DIR=/home/sunsi/dataset/nq
2 | export OUTPUT_DIR=/home/sunsi/experiments/nq-results
3 | export CORPUS_DATA_DIR=/home/sunsi/dataset/wikipedia-corpus-index
4 | ## *************************************
5 | ## INPUT/OUTPUT
6 | export train_job_name=epi-1.ance-tele.nq.checkp-2000
7 | export infer_job_name=inference.${train_job_name}
8 | ## OUTPUT
9 | export new_ann_hn_file_name=ann-neg.${train_job_name}
10 | export new_la_hn_file_name=la-neg.${train_job_name}
11 | export new_tele_file_name_wo_mom=ann-la-neg.${train_job_name}
12 | 
13 | export mom_tele_file_name=epi-1-tele-neg.nq
14 | export new_tele_file_name=epi-2-tele-neg.nq
15 | ## *************************************
16 | 
17 | ## *************************************
18 | ## ENCODE Corpus GPUs
19 | ENCODE_CUDA="0,1,2,3,4" ## ENCODE_CUDA="0"
20 | ENCODE_CUDAs=(${ENCODE_CUDA//,/ })
21 | ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]}
22 | ## Search Top-k GPUs
23 | SEARCH_CUDA="0,1,2,3,4"
24 | ## *************************************
25 | ## Length SetUp
26 | export q_max_len=32
27 | export p_max_len=156
28 | ## *************************************
29 | TOKENIZER=bert-base-uncased
30 | TOKENIZER_ID=bert
31 | SplitNum=20 ## Wikipedia is split into 20 sub-files
32 | ## *************************************
33 | 
34 | ## **********************************************
35 | ## Infer
36 | ## **********************************************
37 | ## Create Folder
38 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus
39 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query
40 | 
41 | ## Encoding Corpus
42 | for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM))
43 | do
44 | ## *************************************
45 | for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++))
46 | do
47 | ## *************************************
48 | if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ]
49 | then
50 | break 2
51 | fi
52 | 
53 | ## *************************************
54 | printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
55 | CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
56 | echo ${OUTPUT_DIR}/${train_job_name} &&
57 | echo split-${i} on gpu-${CUDA} &&
58 | 
59 | CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
60 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
61 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/passage_model \
62 | --fp16 \
63 | --per_device_eval_batch_size 1024 \
64 | --dataloader_num_workers 2 \
65 | --p_max_len ${p_max_len} \
66 | --encode_in_path ${CORPUS_DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
67 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
68 | ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
69 | ## *************************************
70 | sleep 3 &
71 | [ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
72 | done
73 | done
74 | 
75 | 
76 | ## *************************************
77 | ## Encode [Train Query]
78 | ## *************************************
79 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
80 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
81 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/query_model \
82 | --fp16 \
83 | --q_max_len ${q_max_len} \
84 | --encode_is_qry \
85 | --per_device_eval_batch_size 1024 \
86 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \
87 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
88 | 
89 | 
90 | ## *************************************
91 | ## Search [Train]
92 | ## *************************************
93 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
94 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
95 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
96 | --index_num ${SplitNum} \
97 | --batch_size 1024 \
98 | --use_gpu \
99 | --save_text \
100 | --depth 200 \
101 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
102 | --sub_split_num 5 \
103 | ## if CUDA memory is not enough, set this argument.
104 | 
105 | 
106 | # # ***************************************************
107 | # # Filter [Train] & Generate ANN Negatives & Generate [Train-Positive]
108 | # # ***************************************************
109 | python ../preprocess/build_train_em_hn.py \
110 | --tokenizer_name ${TOKENIZER} \
111 | --input_file ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
112 | --queries ${DATA_DIR}/nq-train-qrels.jsonl \
113 | --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
114 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_ann_hn_file_name} \
115 | --n_sample 80 \
116 | --depth 200 \
117 | --gen_pos_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
118 | --mark hn \
119 | 
120 | 
121 | # # ***************************************************
122 | # # Encode [Train-Positive]
123 | # # ***************************************************
124 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
125 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
126 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/passage_model \
127 | --fp16 \
128 | --p_max_len ${p_max_len} \
129 | --per_device_eval_batch_size 1024 \
130 | --encode_in_path ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
131 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
132 | 
133 | 
134 | ## ***************************************************
135 | ## Search [Train-Positive]
136 | ## ***************************************************
137 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
138 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
139 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
140 | --index_num ${SplitNum} \
141 | --batch_size 1024 \
142 | --use_gpu \
143 | --save_text \
144 | --depth 200 \
145 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
146 | --sub_split_num 5 \
147 | ## if CUDA memory is not enough, set this argument.
148 | 
149 | # ***************************************************
150 | # Filter [Train-Positive] & Generate LA Negatives
151 | # ***************************************************
152 | python ../preprocess/build_train_em_hn.py \
153 | --tokenizer_name ${TOKENIZER} \
154 | --input_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
155 | --queries ${DATA_DIR}/nq-train-qrels.jsonl \
156 | --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
157 | --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_la_hn_file_name} \
158 | --n_sample 80 \
159 | --depth 200 \
160 | --mark la.hn \
161 | 
162 | 
163 | # # *************************************
164 | # # Combine ANN + LA Negatives
165 | # # *************************************
166 | python ../preprocess/combine_nq_triviaqa_negative.py \
167 | --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
168 | --input_folder_1 ${new_la_hn_file_name} \
169 | --input_folder_2 ${new_ann_hn_file_name} \
170 | --output_folder ${new_tele_file_name_wo_mom} \
171 | 
172 | 
173 | # # *************************************
174 | # # Combine (ANN + LA + Mom) Negatives
175 | # # *************************************
176 | python ../preprocess/combine_nq_triviaqa_negative.py \
177 | --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
178 | --input_folder_1 ${mom_tele_file_name} \
179 | --input_folder_2 ${new_tele_file_name_wo_mom} \
180 | --output_folder ${new_tele_file_name} \
--------------------------------------------------------------------------------
/shells/epi-3-mine-nq.sh:
--------------------------------------------------------------------------------
1 | export DATA_DIR=/home/sunsi/dataset/nq
2 | export OUTPUT_DIR=/home/sunsi/experiments/nq-results
3 | export CORPUS_DATA_DIR=/home/sunsi/dataset/wikipedia-corpus-index
4 | ## *************************************
5 | ## INPUT/OUTPUT
6 | export train_job_name=epi-2.ance-tele.nq.checkp-2000
7 | export infer_job_name=inference.${train_job_name}
8 | ## OUTPUT
9 | export new_ann_hn_file_name=ann-neg.${train_job_name}
10 | export new_la_hn_file_name=la-neg.${train_job_name}
11 | export new_tele_file_name_wo_mom=ann-la-neg.${train_job_name}
12 | 
13 | export mom_tele_file_name=epi-2-tele-neg.nq
14 | export new_tele_file_name=epi-3-tele-neg.nq
15 | ## *************************************
16 | 
17 | ## *************************************
18 | ## ENCODE Corpus GPUs
19 | ENCODE_CUDA="0,1,2,3,4" ## ENCODE_CUDA="0"
20 | ENCODE_CUDAs=(${ENCODE_CUDA//,/ })
21 | ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]}
22 | ## Search Top-k GPUs
23 | SEARCH_CUDA="0,1,2,3,4"
24 | ## *************************************
25 | ## Length SetUp
26 | export q_max_len=32
27 | export p_max_len=156
28 | ## *************************************
29 | TOKENIZER=bert-base-uncased
30 | TOKENIZER_ID=bert
31 | SplitNum=20 ## Wikipedia is split into 20 sub-files
32 | ## *************************************
33 | 
34 | ## **********************************************
35 | ## Infer
36 | ## **********************************************
37 | ## Create Folder
38 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus
39 | mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query
40 | 
41 | ## Encoding Corpus
42 | for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM))
43 | do
44 | ## *************************************
45 | for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++))
46 | do
47 | ## *************************************
48 | if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ]
49 | then
50 | break 2
51 | fi
52 | 
53 | ## *************************************
54 | printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
55 | CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
56 | echo ${OUTPUT_DIR}/${train_job_name} &&
57 | echo split-${i} on gpu-${CUDA} &&
58 | 
59 | CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
60 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
61 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/passage_model \
62 | --fp16 \
63 | --per_device_eval_batch_size 1024 \
64 | --dataloader_num_workers 2 \
65 | --p_max_len ${p_max_len} \
66 | --encode_in_path ${CORPUS_DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
67 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
68 | ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
69 | ## *************************************
70 | sleep 3 &
71 | [ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
72 | done
73 | done
74 | 
75 | 
76 | ## *************************************
77 | ## Encode [Train Query]
78 | ## *************************************
79 | CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
80 | --output_dir ${OUTPUT_DIR}/${infer_job_name} \
81 | --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/query_model \
82 | --fp16 \
83 | --q_max_len ${q_max_len} \
84 | --encode_is_qry \
85 | --per_device_eval_batch_size 1024 \
86 | --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \
87 | --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
88 | 
89 | 
90 | ## *************************************
91 | ## Search [Train]
92 | ## *************************************
93 | CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
94 | --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
95 | --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
96 | --index_num ${SplitNum} \
97 | --batch_size 1024 \
98 | --use_gpu \
99 | --save_text \
100 | --depth 200 \
101 | --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
102 | --sub_split_num 5 \
103 | ## if CUDA memory is not enough, set this argument.
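## ***************************************************
## Note: ${ENCODE_CUDAs[-1]} (used above to pin single-process encode
## jobs to the last GPU) relies on negative array subscripts, which older
## bash releases (pre-4.2) reject. An equivalent, more portable spelling
## -- LAST_GPU is an illustrative name, not used elsewhere in the script:
## ***************************************************
# LAST_GPU=${ENCODE_CUDAs[${#ENCODE_CUDAs[@]}-1]}
# echo "encode jobs pinned to GPU ${LAST_GPU}"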


## ***************************************************
## Filter [Train] & Generate ANN Negatives & Generate [Train-Positive]
## ***************************************************
python ../preprocess/build_train_em_hn.py \
    --tokenizer_name ${TOKENIZER} \
    --input_file ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
    --queries ${DATA_DIR}/nq-train-qrels.jsonl \
    --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
    --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_ann_hn_file_name} \
    --n_sample 80 \
    --depth 200 \
    --gen_pos_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
    --mark hn \


## ***************************************************
## Encode [Train-Positive]
## ***************************************************
CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
    --output_dir ${OUTPUT_DIR}/${infer_job_name} \
    --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/passage_model \
    --fp16 \
    --p_max_len ${p_max_len} \
    --per_device_eval_batch_size 1024 \
    --encode_in_path ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
    --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \


## ***************************************************
## Search [Train-Positive]
## ***************************************************
CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
    --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
    --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
    --index_num ${SplitNum} \
    --batch_size 1024 \
    --use_gpu \
    --save_text \
    --depth 200 \
    --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
    --sub_split_num 5 \
    ## if CUDA memory is not enough, set this argument.
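
## ---------------------------------------------------------------------
## Note: this second pass retrieves with each query's positive passage
## (train.positives.json) as the query. In the filter step below,
## retrieved passages that do not contain the answer string (per
## has_answers in build_train_em_hn.py) are kept as the "la.hn"
## negatives; we read "LA" as the lookahead-style negatives of
## ANCE-Tele, mined from the neighborhood of the positives.
## ---------------------------------------------------------------------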

## ***************************************************
## Filter [Train-Positive] & Generate LA Negatives
## ***************************************************
python ../preprocess/build_train_em_hn.py \
    --tokenizer_name ${TOKENIZER} \
    --input_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
    --queries ${DATA_DIR}/nq-train-qrels.jsonl \
    --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
    --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_la_hn_file_name} \
    --n_sample 80 \
    --depth 200 \
    --mark la.hn \


## *************************************
## Combine ANN + LA Negatives
## *************************************
python ../preprocess/combine_nq_triviaqa_negative.py \
    --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
    --input_folder_1 ${new_la_hn_file_name} \
    --input_folder_2 ${new_ann_hn_file_name} \
    --output_folder ${new_tele_file_name_wo_mom} \


## *************************************
## Combine (ANN + LA + Mom) Negatives
## *************************************
python ../preprocess/combine_nq_triviaqa_negative.py \
    --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
    --input_folder_1 ${mom_tele_file_name} \
    --input_folder_2 ${new_tele_file_name_wo_mom} \
    --output_folder ${new_tele_file_name} \
--------------------------------------------------------------------------------
/shells/epi-2-mine-triviaqa.sh:
--------------------------------------------------------------------------------
export DATA_DIR=/home/sunsi/dataset/triviaqa
export OUTPUT_DIR=/home/sunsi/experiments/triviaqa-results
export CORPUS_DATA_DIR=/home/sunsi/dataset/wikipedia-corpus-index
## *************************************
## INPUT/OUTPUT
export train_job_name=epi-1.ance-tele.triviaqa.checkp-2000
export infer_job_name=inference.${train_job_name}
## OUTPUT
export new_ann_hn_file_name=ann-neg.${train_job_name}
export new_la_hn_file_name=la-neg.${train_job_name}
export new_tele_file_name_wo_mom=ann-la-neg.${train_job_name}

export mom_tele_file_name=epi-1-tele-neg.triviaqa
export new_tele_file_name=epi-2-tele-neg.triviaqa
## *************************************

## *************************************
## ENCODE Corpus GPUs
ENCODE_CUDA="0,1,2,3,4" ## ENCODE_CUDA="0"
ENCODE_CUDAs=(${ENCODE_CUDA//,/ })
ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]}
## Search Top-k GPUs
SEARCH_CUDA="0,1,2,3,4"
## *************************************
## Length Setup
export q_max_len=32
export p_max_len=156
## *************************************
TOKENIZER=bert-base-uncased
TOKENIZER_ID=bert
SplitNum=20 ## Wikipedia is split into 20 sub-files
## *************************************

## **********************************************
## Infer
## **********************************************
## Create Folder
mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus
mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query

## Encoding Corpus
for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM))
do
## *************************************
for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++))
do
## *************************************
if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ]
then
    break 2
fi

## *************************************
printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
echo ${OUTPUT_DIR}/${train_job_name} &&
echo split-${i} on gpu-${CUDA} &&

CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
    --output_dir ${OUTPUT_DIR}/${infer_job_name} \
    --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/passage_model \
    --fp16 \
    --per_device_eval_batch_size 1024 \
    --dataloader_num_workers 2 \
    --p_max_len ${p_max_len} \
    --encode_in_path ${CORPUS_DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
    --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
    ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
## *************************************
sleep 3 &
[ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
done
done


## *************************************
## Encode [Train Query]
## *************************************
CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
    --output_dir ${OUTPUT_DIR}/${infer_job_name} \
    --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/query_model \
    --fp16 \
    --q_max_len ${q_max_len} \
    --encode_is_qry \
    --per_device_eval_batch_size 1024 \
    --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \
    --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \


## *************************************
## Search [Train]
## *************************************
CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
    --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
    --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
    --index_num ${SplitNum} \
    --batch_size 1024 \
    --use_gpu \
    --save_text \
    --depth 200 \
    --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
    --sub_split_num 5 \
    ## if CUDA memory is not enough, set this argument.
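
## ---------------------------------------------------------------------
## Note: with SplitNum=20 shards and --sub_split_num 5, do_retrieval.py
## runs 4 search passes of 5 shards each, accumulating per-query
## (score, docid) pairs across passes and trimming to --depth only when
## the ranking is written, so at most 5 shard embedding files are
## resident in memory at a time.
## ---------------------------------------------------------------------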


## ***************************************************
## Filter [Train] & Generate ANN Negatives & Generate [Train-Positive]
## ***************************************************
python ../preprocess/build_train_em_hn.py \
    --tokenizer_name ${TOKENIZER} \
    --input_file ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
    --queries ${DATA_DIR}/triviaqa-train-qrels.jsonl \
    --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
    --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_ann_hn_file_name} \
    --n_sample 80 \
    --depth 200 \
    --gen_pos_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
    --mark hn \


## ***************************************************
## Encode [Train-Positive]
## ***************************************************
CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
    --output_dir ${OUTPUT_DIR}/${infer_job_name} \
    --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/passage_model \
    --fp16 \
    --p_max_len ${p_max_len} \
    --per_device_eval_batch_size 1024 \
    --encode_in_path ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
    --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \


## ***************************************************
## Search [Train-Positive]
## ***************************************************
CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
    --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
    --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
    --index_num ${SplitNum} \
    --batch_size 1024 \
    --use_gpu \
    --save_text \
    --depth 200 \
    --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
    --sub_split_num 5 \
    ## if CUDA memory is not enough, set this argument.
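
## ---------------------------------------------------------------------
## Note: the two combine steps below first merge this episode's ANN and
## LA negatives, then fold in ${mom_tele_file_name} -- the tele-negative
## file produced by the previous episode -- so each episode's training
## negatives carry momentum from the episode before it.
## ---------------------------------------------------------------------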

## ***************************************************
## Filter [Train-Positive] & Generate LA Negatives
## ***************************************************
python ../preprocess/build_train_em_hn.py \
    --tokenizer_name ${TOKENIZER} \
    --input_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
    --queries ${DATA_DIR}/triviaqa-train-qrels.jsonl \
    --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
    --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_la_hn_file_name} \
    --n_sample 80 \
    --depth 200 \
    --mark la.hn \


## *************************************
## Combine ANN + LA Negatives
## *************************************
python ../preprocess/combine_nq_triviaqa_negative.py \
    --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
    --input_folder_1 ${new_la_hn_file_name} \
    --input_folder_2 ${new_ann_hn_file_name} \
    --output_folder ${new_tele_file_name_wo_mom} \


## *************************************
## Combine (ANN + LA + Mom) Negatives
## *************************************
python ../preprocess/combine_nq_triviaqa_negative.py \
    --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
    --input_folder_1 ${mom_tele_file_name} \
    --input_folder_2 ${new_tele_file_name_wo_mom} \
    --output_folder ${new_tele_file_name} \
--------------------------------------------------------------------------------
/shells/epi-3-mine-triviaqa.sh:
--------------------------------------------------------------------------------
export DATA_DIR=/home/sunsi/dataset/triviaqa
export OUTPUT_DIR=/home/sunsi/experiments/triviaqa-results
export CORPUS_DATA_DIR=/home/sunsi/dataset/wikipedia-corpus-index
## *************************************
## INPUT/OUTPUT
export train_job_name=epi-2.ance-tele.triviaqa.checkp-2000
export infer_job_name=inference.${train_job_name}
## OUTPUT
export new_ann_hn_file_name=ann-neg.${train_job_name}
export new_la_hn_file_name=la-neg.${train_job_name}
export new_tele_file_name_wo_mom=ann-la-neg.${train_job_name}

export mom_tele_file_name=epi-2-tele-neg.triviaqa
export new_tele_file_name=epi-3-tele-neg.triviaqa
## *************************************

## *************************************
## ENCODE Corpus GPUs
ENCODE_CUDA="0,1,2,3,4" ## ENCODE_CUDA="0"
ENCODE_CUDAs=(${ENCODE_CUDA//,/ })
ENCODE_CUDA_NUM=${#ENCODE_CUDAs[@]}
## Search Top-k GPUs
SEARCH_CUDA="0,1,2,3,4"
## *************************************
## Length Setup
export q_max_len=32
export p_max_len=156
## *************************************
TOKENIZER=bert-base-uncased
TOKENIZER_ID=bert
SplitNum=20 ## Wikipedia is split into 20 sub-files
## *************************************

## **********************************************
## Infer
## **********************************************
## Create Folder
mkdir -p ${OUTPUT_DIR}/${infer_job_name}/corpus
mkdir -p ${OUTPUT_DIR}/${infer_job_name}/query

## Encoding Corpus
for((tmp=0; tmp<$SplitNum; tmp+=$ENCODE_CUDA_NUM))
do
## *************************************
for((CUDA_INDEX=0; CUDA_INDEX<$ENCODE_CUDA_NUM; CUDA_INDEX++))
do
## *************************************
if [ $[CUDA_INDEX + $tmp] -eq $SplitNum ]
then
    break 2
fi

## *************************************
printf -v i "%02g" $[CUDA_INDEX + $tmp] &&
CUDA=${ENCODE_CUDAs[$CUDA_INDEX]} &&
echo ${OUTPUT_DIR}/${train_job_name} &&
echo split-${i} on gpu-${CUDA} &&

CUDA_VISIBLE_DEVICES=${CUDA} python ../ancetele/encode.py \
    --output_dir ${OUTPUT_DIR}/${infer_job_name} \
    --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/passage_model \
    --fp16 \
    --per_device_eval_batch_size 1024 \
    --dataloader_num_workers 2 \
    --p_max_len ${p_max_len} \
    --encode_in_path ${CORPUS_DATA_DIR}/${TOKENIZER_ID}/corpus/split${i}.json \
    --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.pt &> \
    ${OUTPUT_DIR}/${infer_job_name}/corpus/split${i}.log &&
## *************************************
sleep 3 &
[ $CUDA_INDEX -eq `expr $ENCODE_CUDA_NUM - 1` ] && wait
done
done


## *************************************
## Encode [Train Query]
## *************************************
CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
    --output_dir ${OUTPUT_DIR}/${infer_job_name} \
    --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/query_model \
    --fp16 \
    --q_max_len ${q_max_len} \
    --encode_is_qry \
    --per_device_eval_batch_size 1024 \
    --encode_in_path ${DATA_DIR}/${TOKENIZER_ID}/query/train.query.json \
    --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \


## *************************************
## Search [Train]
## *************************************
CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
    --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.query.pt \
    --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
    --index_num ${SplitNum} \
    --batch_size 1024 \
    --use_gpu \
    --save_text \
    --depth 200 \
    --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
    --sub_split_num 5 \
    ## if CUDA memory is not enough, set this argument.
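
## ---------------------------------------------------------------------
## Note: the filter below (build_train_em_hn.py) splits each query's
## top-200 retrieved passages by exact answer match: non-matching
## passages become ANN negatives (up to --n_sample 80 kept per query),
## and one positive per query (the annotated positive when available,
## otherwise the top exact-match hit) is written to train.positives.json
## for the follow-up retrieval pass.
## ---------------------------------------------------------------------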


## ***************************************************
## Filter [Train] & Generate ANN Negatives & Generate [Train-Positive]
## ***************************************************
python ../preprocess/build_train_em_hn.py \
    --tokenizer_name ${TOKENIZER} \
    --input_file ${OUTPUT_DIR}/${infer_job_name}/train.rank.tsv \
    --queries ${DATA_DIR}/triviaqa-train-qrels.jsonl \
    --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
    --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_ann_hn_file_name} \
    --n_sample 80 \
    --depth 200 \
    --gen_pos_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
    --mark hn \


## ***************************************************
## Encode [Train-Positive]
## ***************************************************
CUDA_VISIBLE_DEVICES=${ENCODE_CUDAs[-1]} python ../ancetele/encode.py \
    --output_dir ${OUTPUT_DIR}/${infer_job_name} \
    --model_name_or_path ${OUTPUT_DIR}/${train_job_name}/passage_model \
    --fp16 \
    --p_max_len ${p_max_len} \
    --per_device_eval_batch_size 1024 \
    --encode_in_path ${OUTPUT_DIR}/${infer_job_name}/train.positives.json \
    --encoded_save_path ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \


## ***************************************************
## Search [Train-Positive]
## ***************************************************
CUDA_VISIBLE_DEVICES=${SEARCH_CUDA} python ../ancetele/faiss_retriever/do_retrieval.py \
    --query_reps ${OUTPUT_DIR}/${infer_job_name}/query/train.positives.pt \
    --passage_reps ${OUTPUT_DIR}/${infer_job_name}/corpus/'*.pt' \
    --index_num ${SplitNum} \
    --batch_size 1024 \
    --use_gpu \
    --save_text \
    --depth 200 \
    --save_ranking_to ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
    --sub_split_num 5 \
    ## if CUDA memory is not enough, set this argument.
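
## ---------------------------------------------------------------------
## Note: in build_train_em_hn.py, --depth 200 caps the candidates kept
## per query and --n_sample 80 is how many shuffled negatives survive;
## queries left with no positive or no negative after the exact-match
## filter are dropped entirely (minimum_negatives=1).
## ---------------------------------------------------------------------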

## ***************************************************
## Filter [Train-Positive] & Generate LA Negatives
## ***************************************************
python ../preprocess/build_train_em_hn.py \
    --tokenizer_name ${TOKENIZER} \
    --input_file ${OUTPUT_DIR}/${infer_job_name}/train.positives.rank.tsv \
    --queries ${DATA_DIR}/triviaqa-train-qrels.jsonl \
    --collection ${CORPUS_DATA_DIR}/psgs_w100.tsv \
    --save_to ${DATA_DIR}/${TOKENIZER_ID}/${new_la_hn_file_name} \
    --n_sample 80 \
    --depth 200 \
    --mark la.hn \


## *************************************
## Combine ANN + LA Negatives
## *************************************
python ../preprocess/combine_nq_triviaqa_negative.py \
    --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
    --input_folder_1 ${new_la_hn_file_name} \
    --input_folder_2 ${new_ann_hn_file_name} \
    --output_folder ${new_tele_file_name_wo_mom} \


## *************************************
## Combine (ANN + LA + Mom) Negatives
## *************************************
python ../preprocess/combine_nq_triviaqa_negative.py \
    --data_dir ${DATA_DIR}/${TOKENIZER_ID} \
    --input_folder_1 ${mom_tele_file_name} \
    --input_folder_2 ${new_tele_file_name_wo_mom} \
    --output_folder ${new_tele_file_name} \
--------------------------------------------------------------------------------
/ancetele/faiss_retriever/do_retrieval.py:
--------------------------------------------------------------------------------
import pickle
import torch
import gc
import time
import numpy as np
import glob
from argparse import ArgumentParser
from itertools import chain
from tqdm import tqdm

from retriever import BaseFaissIPRetriever

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)


def search_queries(retriever, q_reps, p_lookup, args):
    if args.batch_size > 0:
        all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size)
    else:
        all_scores, all_indices = retriever.search(q_reps, args.depth)

    psg_indices = [[str(p_lookup[x]) for x in q_dd] for q_dd in tqdm(all_indices)]
    psg_indices = np.array(psg_indices)
    return all_scores, psg_indices


def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file):
    with open(ranking_save_file, 'w') as f:
        for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices):
            score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)]
            score_list = sorted(score_list, key=lambda x: x[0], reverse=True)
            for s, idx in score_list:
                f.write(f'{qid}\t{idx}\t{s}\n')


def write_trec(q_lookup, qid2docids, ranking_save_file, depth):
    with open(ranking_save_file, 'w') as f:
        for qid in q_lookup:
            score_list = qid2docids[qid]
            score_list = sorted(score_list, key=lambda x: x[0], reverse=True)
            for s, idx in score_list[:depth]:
                f.write(f'{qid}\t{idx}\t{s}\n')

def pickle_load(path):
    with open(path, 'rb') as f:
        obj = pickle.load(f)
    return obj


def pickle_save(obj, path):
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def main():
    parser = ArgumentParser()
    parser.add_argument('--query_reps', required=True)
    parser.add_argument('--passage_reps', required=True)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--index_num', type=int, required=True)
    parser.add_argument('--use_gpu', action='store_true')
    parser.add_argument('--depth', type=int, default=1000)
    parser.add_argument('--save_ranking_to', required=True)
    parser.add_argument('--save_text', action='store_true')
    parser.add_argument('--sub_split_num', type=int, default=None)

    args = parser.parse_args()

    ## *******************************************
    ## Single Search
    ## *******************************************
    if args.sub_split_num is None:
        index_files = glob.glob(args.passage_reps)
        logger.info(f'Pattern match found {len(index_files)} files; loading them into index.')

        p_reps_0, p_lookup_0 = pickle_load(index_files[0])
        retriever = BaseFaissIPRetriever(p_reps_0, args.use_gpu)

        shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:]))
        if len(index_files) > 1:
            shards = tqdm(shards, desc='Loading shards into index', total=len(index_files))

        assert len(index_files) == args.index_num

        p_reps = []
        look_up = []
        for _p_reps, p_lookup in shards:
            p_reps.append(_p_reps)
            look_up += p_lookup

        p_reps = np.concatenate(p_reps, axis=0)
        retriever.add(p_reps)

        q_reps, q_lookup = pickle_load(args.query_reps)

        logger.info('Index Search Start')
        all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args)
        logger.info('Index Search Finished')

        if args.save_text:
            write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to)
        else:
            pickle_save((all_scores, psg_indices), args.save_ranking_to)

    ## *******************************************
    ## Split Search
    ## *******************************************
    else:
        print("split corpus search!")

        ## Load qry
        q_reps, q_lookup = pickle_load(args.query_reps)

        ## Load corpus shards in split-id order
        filenames = [[filename, int(filename.split("/")[-1].strip(".pt")[-2:])] for filename in glob.glob(args.passage_reps)]
        sorted_filenames = sorted(filenames, key=lambda item: item[1])
        tot_index_files = [item[0] for item in sorted_filenames]

        assert len(tot_index_files) == args.index_num
        logger.info(f'Pattern match found {len(tot_index_files)} files; loading them into index.')

        ## container
        merge_qid2docids = {qid: [] for qid in q_lookup}

        ## ceiling division: a trailing partial batch of shards must still be searched
        search_time = -(-len(tot_index_files) // args.sub_split_num)
        for search_idx in range(search_time):
            index_files = tot_index_files[args.sub_split_num*search_idx:args.sub_split_num*(search_idx+1)]
            print("searching ", search_idx+1, " total: ", search_time)

            p_reps_0, p_lookup_0 = pickle_load(index_files[0])
            retriever = BaseFaissIPRetriever(p_reps_0, args.use_gpu)

            shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:]))
            if len(index_files) > 1:
                shards = tqdm(shards, desc='Loading shards into index', total=len(index_files))

            assert 0 < len(index_files) <= args.sub_split_num

            p_reps = []
            look_up = []
            for _p_reps, p_lookup in shards:
                p_reps.append(_p_reps)
                look_up += p_lookup

            p_reps = np.concatenate(p_reps, axis=0)
            retriever.add(p_reps)

            logger.info('Index Search Start: {}/{}'.format(search_idx+1, search_time))
            sub_scores, sub_psg_indices = search_queries(retriever, q_reps, look_up, args)
            logger.info('Index Search Finished: {}/{}'.format(search_idx+1, search_time))

            ## Merge
            for qid, q_doc_scores, q_doc_indices in zip(q_lookup, sub_scores, sub_psg_indices):
                score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)]
                merge_qid2docids[qid].extend(score_list)

                # merge_qid2docids[qid] = sorted(merge_qid2docids[qid], key=lambda x: x[0], reverse=True)
                # merge_qid2docids[qid] = merge_qid2docids[qid][:args.depth]

            del retriever
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            time.sleep(5) # just in case the gpu has not cleaned up the memory
            torch.cuda.reset_peak_memory_stats()


        write_trec(q_lookup, merge_qid2docids, args.save_ranking_to, depth=args.depth)



if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/scripts/ms_marco_eval.py:
--------------------------------------------------------------------------------
"""
This module computes evaluation metrics for MSMARCO dataset on the ranking task. Internal hard-coded eval files version. DO NOT PUBLISH!
Command line:
python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>

Creation Date : 06/12/2018
Last Modified : 4/09/2019
Authors : Daniel Campos , Rutger van Haasteren
"""
import sys
import statistics

from collections import Counter

MaxMRRRank = 10

def load_reference_from_stream(f):
    """Load reference relevant passages
    Args:f (stream): stream to load.
    Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
    """
    qids_to_relevant_passageids = {}
    for l in f:
        try:
            l = l.strip().split('\t')
            qid = int(l[0])
            if qid in qids_to_relevant_passageids:
                pass
            else:
                qids_to_relevant_passageids[qid] = []
            qids_to_relevant_passageids[qid].append(int(l[2]))
        except:
            raise IOError('\"%s\" is not valid format' % l)
    return qids_to_relevant_passageids

def load_reference(path_to_reference):
    """Load reference relevant passages
    Args:path_to_reference (str): path to a file to load.
    Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
    """
    with open(path_to_reference, 'r') as f:
        qids_to_relevant_passageids = load_reference_from_stream(f)
    return qids_to_relevant_passageids

def load_candidate_from_stream(f):
    """Load candidate data from a stream.
    Args:f (stream): stream to load.
    Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids (int) ranked by relevance and importance
    """
    qid_to_ranked_candidate_passages = {}
    for l in f:
        try:
            l = l.strip().split('\t')
            qid = int(l[0])
            pid = int(l[1])
            rank = int(l[2])
            if qid in qid_to_ranked_candidate_passages:
                pass
            else:
                # By default, all PIDs in the list of 1000 are 0. Only override those that are given
                tmp = [0] * 1000
                qid_to_ranked_candidate_passages[qid] = tmp
            qid_to_ranked_candidate_passages[qid][rank-1] = pid
        except:
            raise IOError('\"%s\" is not valid format' % l)
    return qid_to_ranked_candidate_passages

def load_candidate(path_to_candidate):
    """Load candidate data from a file.
    Args:path_to_candidate (str): path to file to load.
    Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids (int) ranked by relevance and importance
    """

    with open(path_to_candidate, 'r') as f:
        qid_to_ranked_candidate_passages = load_candidate_from_stream(f)
    return qid_to_ranked_candidate_passages

def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Perform quality checks on the dictionaries

    Args:
    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
        Dict as read in with load_reference or load_reference_from_stream
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
    Returns:
        bool,str: Boolean whether allowed, message to be shown in case of a problem
    """
    message = ''
    allowed = True

    # Create sets of the QIDs for the submitted and reference queries
    candidate_set = set(qids_to_ranked_candidate_passages.keys())
    ref_set = set(qids_to_relevant_passageids.keys())

    # Check that we do not have multiple passages per query
    for qid in qids_to_ranked_candidate_passages:
        # Remove all zeros from the candidates
        duplicate_pids = set([item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1])

        if len(duplicate_pids - set([0])) > 0:
            message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
                qid=qid, pid=list(duplicate_pids)[0])
            allowed = False

    return allowed, message

def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute MRR metric
    Args:
    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
        Dict as read in with load_reference or load_reference_from_stream
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
    Returns:
        dict: dictionary of metrics {'MRR': }
    """
    all_scores = {}
    MRR = 0
    qids_with_relevant_passages = 0
    ranking = []
    for qid in qids_to_ranked_candidate_passages:
        if qid in qids_to_relevant_passageids:
            ranking.append(0)
            target_pid = qids_to_relevant_passageids[qid]
            candidate_pid = qids_to_ranked_candidate_passages[qid]
            for i in range(0, MaxMRRRank):
                if candidate_pid[i] in target_pid:
                    MRR += 1/(i + 1)
                    ranking.pop()
                    ranking.append(i+1)
                    break
    if len(ranking) == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    MRR = MRR/len(qids_to_relevant_passageids)
    all_scores['MRR @10'] = MRR
    all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
    return all_scores

def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
    """Compute MRR metric
    Args:
    p_path_to_reference_file (str): path to reference file.
        Reference file should contain lines in the following format:
        QUERYID\tPASSAGEID
        Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs
    p_path_to_candidate_file (str): path to candidate file.
        Candidate file should contain lines in the following format:
        QUERYID\tPASSAGEID1\tRank
        If a user wishes to use the TREC format please run the script with a -t flag at the end. If this flag is used the expected format is
        QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID
        Where the values are separated by tabs and ranked in order of relevance
    Returns:
        dict: dictionary of metrics {'MRR': }
    """

    qids_to_relevant_passageids = load_reference(path_to_reference)
    qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
    if perform_checks:
        allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
        if message != '': print(message)

    # ## SS
    # print("qids_to_ranked_candidate_passages")
    # print(qids_to_ranked_candidate_passages)

    return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)

def main():
    """Command line:
    python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
    """
    path_to_reference = sys.argv[1]
    path_to_candidate = sys.argv[2]
    metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
    print('#####################')
    for metric in sorted(metrics):
        print('{}: {}'.format(metric, metrics[metric]))
    print('#####################')
if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/preprocess/build_train_em_hn.py:
--------------------------------------------------------------------------------
import os
import json
import random
import datasets
from tqdm import tqdm
from datetime import datetime
from multiprocessing import Pool
from dataclasses import dataclass
from argparse import ArgumentParser
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer
from pyserini.eval.evaluate_dpr_retrieval import SimpleTokenizer, has_answers


@dataclass
class WikiTrainPreProcessor:
    query_file: str
    collection_file: str
    tokenizer: PreTrainedTokenizer

    max_length: int = 256
    columns = ['text_id', 'text', 'title'] ## psgs_w100.tsv split
    title_field = 'title'
    text_field = 'text'

    def __post_init__(self):

        ## qid: qry, ans, pos
        self.queries = self.read_queries(self.query_file)

        ## Corpus
        self.collection = datasets.load_dataset(
            'csv',
            data_files=self.collection_file,
            cache_dir=self.collection_file+".cache",
            column_names=self.columns,
            delimiter='\t',
        )['train']


    @staticmethod
    def read_queries(query_file):
        queries = datasets.load_dataset(
            'json',
            'default',
            data_files=query_file,
            cache_dir=query_file+".cache",
        )['train']

        qid2queries = {}
        for item in queries:
            qid = int(item["qid"])
            qid2queries[qid] = item
        return qid2queries


    def get_query(self, qid):
        query_encoded = self.tokenizer.encode(
            self.queries[qid]["question"],
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )
        return query_encoded


    def get_passage(self, p):
        entry = self.collection[p]
        title = entry[self.title_field]
        title = "" if title is None else title
        body = entry[self.text_field]
        content = title + self.tokenizer.sep_token + body

        passage_encoded = self.tokenizer.encode(
            content,
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )

        return passage_encoded


    def tokenize_passage(self, title, body):
        title = "" if title is None else title
        content = title + self.tokenizer.sep_token + body

        passage_encoded = self.tokenizer.encode(
            content,
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )
        return passage_encoded


    def process_one(self, train):
        q, origin_pp, pp, nn = train

        if len(origin_pp) > 0:
            positives = [
                self.tokenize_passage(
                    origin_p["title"],
                    origin_p["text"]) for origin_p in origin_pp
            ]
        else:
            positives = [self.get_passage(p) for p in pp]

        train_example = {
            'qid': q,
            'query': self.get_query(q),
            'positives': positives,
            'negatives': [self.get_passage(n) for n in nn],
        }

        return json.dumps(train_example)



def load_ranking(
    rank_file,
    queries,
    collection,
    em_tokenizer,
    n_sample,
    depth,
    minimum_negatives=1
):
    with open(rank_file) as rf:
        lines = iter(rf)
        q_0, p_0, _ = next(lines).strip().split()
        q_0 = int(q_0)
        p_0 = int(p_0)

        curr_q = q_0
        content = collection[p_0]['text'] ## only match main body text, not title
        answers = queries[q_0]["answers"]

        negatives = []
        new_positives = []
        if not has_answers(content, answers, em_tokenizer, regex=False):
            negatives.append(p_0)
        else:
            new_positives.append(p_0)

        while True:
            try:
                q, p, _ = next(lines).strip().split()
                q = int(q)
                p = int(p)
                ## *************************
                ## Time to finish curr_q !
                ## *************************
                if q != curr_q:

                    ## Positive
                    origin_positives = queries[curr_q]["positive_ctxs"] ## {"title": , "text", }

                    ## Negative
                    negatives = negatives[:depth]
                    random.shuffle(negatives)

                    if (len(origin_positives) + len(new_positives)) >= 1 and len(negatives) >= minimum_negatives:
                        yield curr_q, origin_positives, new_positives[:1], negatives[:n_sample]

                    ## *************************
                    ## Time to next q !
                    ## *************************
                    curr_q = q
                    content = collection[p]['text']
                    answers = queries[q]["answers"]

                    negatives = []
                    new_positives = []
                    if not has_answers(content, answers, em_tokenizer, regex=False):
                        negatives.append(p)
                    else:
                        new_positives.append(p)

                ## *************************
                ## Continue curr_q ...
                ## *************************
                else:
                    content = collection[p]['text']
                    answers = queries[q]["answers"]
                    if not has_answers(content, answers, em_tokenizer, regex=False):
                        negatives.append(p)
                    else:
                        new_positives.append(p)

            ## *************************
            ## END
            ## *************************
            except StopIteration:

                ## Positive
                origin_positives = queries[curr_q]["positive_ctxs"] ## {"title": , "text", }

                ## Negative
                negatives = negatives[:depth]
                random.shuffle(negatives)

                if (len(origin_positives) + len(new_positives)) >= 1 and len(negatives) >= minimum_negatives:
                    yield curr_q, origin_positives, new_positives[:1], negatives[:n_sample]
                return

if __name__ == "__main__":

    ## seed with the current time; datetime objects themselves are not valid seeds on Python 3.11+
    random.seed(datetime.now().timestamp())
    parser = ArgumentParser()
    parser.add_argument('--tokenizer_name', required=True)
    parser.add_argument('--input_file', required=True)
    parser.add_argument('--queries', required=True)
    parser.add_argument('--collection', required=True)
    parser.add_argument('--save_to', required=True)
    parser.add_argument('--mark', type=str, default="hn")

    parser.add_argument('--truncate', type=int, default=156)
    parser.add_argument('--n_sample', type=int, default=30)
    parser.add_argument('--depth', type=int, default=200)
    parser.add_argument('--mp_chunk_size', type=int, default=500)
    parser.add_argument('--shard_size', type=int, default=45000)
    parser.add_argument('--gen_pos_file', type=str, default=None)


    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name,
        use_fast=True
    )

    processor = WikiTrainPreProcessor(
        query_file=args.queries,
        collection_file=args.collection,
        tokenizer=tokenizer,
        max_length=args.truncate,
    )

    counter = 0
    shard_id = 0
    f = None
    os.makedirs(args.save_to, exist_ok=True)


    pbar = tqdm(
        load_ranking(
            rank_file=args.input_file,
            queries=processor.queries,
            collection=processor.collection,
            em_tokenizer=SimpleTokenizer(),
            n_sample=args.n_sample,
            depth=args.depth,
        )
    )

    with Pool() as p:
        for x in p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size):
            counter += 1
            if f is None:
                f = open(os.path.join(args.save_to, f'split{shard_id:02d}.{args.mark}.json'), 'w')
                pbar.set_description(f'split - {shard_id:02d}')
            f.write(x + '\n')

            if counter == args.shard_size:
                f.close()
                f = None
                shard_id += 1
                counter = 0

    if f is not None:
        f.close()


    if args.gen_pos_file:
        file_list = [os.path.join(args.save_to, listx) for listx in os.listdir(args.save_to) \
                     if "json" in listx and "cache" not in listx]
        with open(args.gen_pos_file, "w", encoding="utf-8") as fw:
            for file_path in tqdm(file_list):
                with open(file_path, "r", encoding="utf-8") as fi:
                    for line in fi:
                        data = json.loads(line)

                        qid = data["qid"]
                        positives = data["positives"][0]

                        save_item = {"text_id": qid, "text": positives}
                        fw.write(json.dumps(save_item) + '\n')
--------------------------------------------------------------------------------