├── .gitignore
├── FIRST.png
├── README.md
├── bash
│   ├── run_1st_retrieval.sh
│   ├── run_2nd_retrieval.sh
│   ├── run_convert_results.sh
│   ├── run_distill.sh
│   ├── run_eval.sh
│   ├── run_prepare_distill.sh
│   ├── run_rerank_CE.sh
│   ├── run_rerank_llm.sh
│   └── run_train.sh
├── scripts
│   ├── convert_results.py
│   ├── distill.py
│   ├── eval.py
│   ├── prepare_distill.py
│   ├── rerank_CE.py
│   ├── rerank_llm.py
│   ├── train_ranking.py
│   └── utils
│       ├── __init__.py
│       ├── dataset.py
│       ├── llm_util.py
│       ├── loss.py
│       ├── rank_listwise_os_llm.py
│       ├── rankllm.py
│       ├── reranker.py
│       ├── result.py
│       └── train_utils.py
├── tevatron
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── mkdocs.yml
│   ├── setup.py
│   └── src
│       ├── .DS_Store
│       └── tevatron
│           ├── __init__.py
│           ├── arguments.py
│           ├── data.py
│           ├── datasets
│           │   ├── __init__.py
│           │   ├── dataset.py
│           │   └── preprocessor.py
│           ├── driver
│           │   ├── __init__.py
│           │   ├── encode.py
│           │   ├── encode_new.py
│           │   ├── jax_encode.py
│           │   ├── jax_train.py
│           │   └── train.py
│           ├── faiss_retriever
│           │   ├── __init__.py
│           │   ├── __main__.py
│           │   ├── reducer.py
│           │   └── retriever.py
│           ├── loss.py
│           ├── modeling
│           │   ├── __init__.py
│           │   ├── colbert.py
│           │   ├── dense.py
│           │   ├── encoder.py
│           │   ├── splade.py
│           │   └── unicoil.py
│           ├── preprocessor
│           │   ├── __init__.py
│           │   ├── normalize_text.py
│           │   └── preprocessor_tsv.py
│           ├── tevax
│           │   ├── __init__.py
│           │   ├── loss.py
│           │   └── training.py
│           ├── trainer.py
│           └── utils
│               ├── __init__.py
│               ├── convert_from_dpr.py
│               ├── evaluation.py
│               ├── format
│               │   ├── __init__.py
│               │   ├── convert_result_to_marco.py
│               │   └── convert_result_to_trec.py
│               └── normalize_text.py
└── train_configs
    ├── accel_config_deepspeed.yaml
    └── zero3_bf16.json
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
datasets/*
data/*
scripts/__pycache__/*
scripts/utils/__pycache__/*
models/*
scripts/*.ipynb
wandb/*
logs/*
qrels/*
outputs/*
scripts/latency_test.py
scripts/logits_reranking_test.py
temp/*
--------------------------------------------------------------------------------
/FIRST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gangiswag/llm-reranker/2d7cba423ad555064bdfc719313570b5f9525887/FIRST.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FIRST: Faster Improved Listwise Reranking with Single Token Decoding

This repository contains the code for the paper [FIRST: Faster Improved Listwise Reranking with Single Token Decoding](https://arxiv.org/pdf/2406.15657).

FIRST is a novel listwise LLM reranking approach that leverages the output logits of the first generated identifier to directly obtain a ranked ordering of the input candidates. FIRST incorporates a learning-to-rank loss during training, prioritizing ranking accuracy for the more relevant passages.

<img src="FIRST.png">

## Installation
You need to install the tevatron library (original source [here](https://github.com/texttron/tevatron)), which provides the framework for retrieval.

```
git clone https://github.com/gangiswag/llm-reranker.git
cd llm-reranker
conda create --name reranker python=3.9.18
conda activate reranker
cd tevatron
pip install --editable .
pip install beir
```
**Note:** You also need to install the vLLM library (instructions [here](https://docs.vllm.ai/en/latest/getting_started/installation.html)), which provides optimization for LLM inference.
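For intuition, the speed-up comes from where the ranking is read off. A standard listwise reranker decodes an entire permutation string (e.g., "[C] > [A] > [B] > ..."), while FIRST reads the logits of the very first generated identifier token and sorts the candidates by them. The snippet below is an illustrative sketch only (the actual inference code is in scripts/utils/rank_listwise_os_llm.py; the model choice and prompt here are placeholders):

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# A listwise prompt over candidates [A]..[D]; the trailing "[" cues the
# model so its next token is the first (top-ranked) identifier.
prompt = "Rank the passages [A]..[D] by relevance to the query ... ["
inputs = tok(prompt, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits[0, -1]  # logits of the FIRST new token

# Sort candidates by the logit of their identifier token: one forward
# pass yields the full ordering, with no autoregressive permutation decoding.
cand_ids = [tok.convert_tokens_to_ids(c) for c in "ABCD"]
order = sorted(range(len(cand_ids)), key=lambda i: logits[cand_ids[i]].item(), reverse=True)
print([chr(ord("A") + i) for i in order])
```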
Before running the scripts below, do
```
export REPO_DIR=<path to the llm-reranker repository>
```

## 1. Retrieval
To run first-stage retrieval over the pre-encoded BEIR queries and corpora, run:

```
bash bash/run_1st_retrieval.sh <path to encoded queries and corpus>
```

To get the retrieval scores, run:

```
bash bash/run_eval.sh rank
```

## 2. Reranking
### 2a. Baseline cross-encoder reranking

To run the baseline cross-encoder reranking, run:
```
bash bash/run_rerank_CE.sh
```
### 2b. FIRST LLM reranking

To convert the retrieval results into the input format for LLM reranking, run:

```
bash bash/run_convert_results.sh
```

We provide the trained FIRST reranker [here](https://huggingface.co/rryisthebest/First_Model).

To run the FIRST reranking, run:

```
bash bash/run_rerank_llm.sh
```

To evaluate the reranking performance, run:

```
bash bash/run_eval.sh rerank
```
**Note:** Set the --suffix flag to "llm_FIRST_alpha" to evaluate the FIRST reranker, or to "ce" for the cross-encoder reranker.

## 3. Model Training
We also provide the data and scripts to train the LLM reranker yourself, if you wish to do so.
### 3a. Training dataset
The converted training dataset (alphabetic IDs) is on [HF](https://huggingface.co/datasets/rryisthebest/rank_zephyr_training_data_alpha). The standard numeric training dataset can be found [here](https://huggingface.co/datasets/castorini/rank_zephyr_training_data).

### 3b. Training
We support three training objectives:

- **Ranking**: The Ranking objective uses a learning-to-rank algorithm on the output logits for the highest-ranked passage ID.
- **Generation**: The Generation objective follows the principles of Causal Language Modeling, focusing on permutation generation.
- **Combined**: The Combined objective, which we introduce in our paper, is a novel weighted approach that integrates both ranking and generation principles, and is the setting applied to the FIRST model (a sketch of this weighting appears after bash/run_train.sh below).

To train the model, run:
```
bash bash/run_train.sh
```

To train a gated model, log in to Hugging Face and create an access token at huggingface.co/settings/tokens.
```
huggingface-cli login
```
## 4. Relevance Feedback
We also provide scripts to use the LLM reranker for a downstream task, such as relevance feedback. [Inference-time relevance feedback](https://arxiv.org/pdf/2305.11744) uses the reranker's output to distill the retriever's query embedding and improve recall.
### 4a. Dataset preparation for relevance feedback
To prepare the dataset(s) for relevance feedback, run:
```
bash bash/run_prepare_distill.sh <path to encoded queries and corpus>
```
### 4b. Distillation (relevance feedback step)
You can run distillation with the cross-encoder reranker, the LLM reranker, or both sequentially.
To perform the relevance feedback distillation step, run:
```
bash bash/run_distill.sh
```
This step creates new query embeddings after distillation (the idea is sketched below, after 4c).

### 4c. 2nd Retrieval
To perform the retrieval step with the new query embeddings after distillation, run:
```
bash bash/run_2nd_retrieval.sh <path to encoded queries and corpus>
```
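The distillation step (4b) treats the reranker's scores as a teacher signal and gradient-updates the query embedding so that its dot products with the passage embeddings match them. A minimal sketch of the idea, with made-up tensors standing in for the real embeddings (the actual implementation is scripts/distill.py):

```
import torch
import torch.nn.functional as F

query = torch.randn(768, requires_grad=True)   # query embedding to refit
passages = torch.randn(100, 768)               # frozen passage embeddings
teacher = torch.randn(100)                     # reranker scores (teacher)

opt = torch.optim.Adam([query], lr=5e-3)
target = F.log_softmax(teacher, dim=-1)
for _ in range(100):
    opt.zero_grad()
    pred = F.log_softmax(query @ passages.T, dim=-1)
    F.kl_div(pred, target, log_target=True, reduction="sum").backward()
    opt.step()
# `query` now plays the role of the refit embedding used in step 4c.
```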
### 4d. Relevance feedback evaluation
To evaluate the 2nd retrieval step, run:
```
bash bash/run_eval.sh rank_refit
```

## Citation

If you find this repo useful for your work, please consider citing our papers:

```
@article{reddy2024first,
  title={FIRST: Faster Improved Listwise Reranking with Single Token Decoding},
  author={Reddy, Revanth Gangi and Doo, JaeHyeok and Xu, Yifei and Sultan, Md Arafat and Swain, Deevya and Sil, Avirup and Ji, Heng},
  journal={arXiv preprint arXiv:2406.15657},
  year={2024}
}
```

```
@article{reddy2023inference,
  title={Inference-time Re-ranker Relevance Feedback for Neural Information Retrieval},
  author={Reddy, Revanth Gangi and Dasigi, Pradeep and Sultan, Md Arafat and Cohan, Arman and Sil, Avirup and Ji, Heng and Hajishirzi, Hannaneh},
  journal={arXiv preprint arXiv:2305.11744},
  year={2023}
}
```

We also acknowledge the following open-source repos, which were instrumental for this work:
- [Tevatron](https://github.com/texttron/tevatron) for the retrieval framework
- [RankLLM](https://github.com/castorini/rank_llm/) for the LLM reranking inference backbone
--------------------------------------------------------------------------------
/bash/run_1st_retrieval.sh:
--------------------------------------------------------------------------------
#!/bin/bash

if [ -z "$1" ]; then
    echo "Usage: $0 <input_dir>"
    exit 1
fi

input_dir="$1"

output_dir="${REPO_DIR}/outputs/beir"
data_dir="${REPO_DIR}/datasets/beir"

mkdir -p "$output_dir" "$data_dir"

# Datasets to process
datasets=('trec-covid') # 'climate-fever' 'dbpedia-entity' 'fever' 'fiqa' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'scidocs' 'scifact'

# Iterate over datasets
for dataset in "${datasets[@]}"; do
    echo "Processing dataset: ${dataset}"

    dataset_output_dir="${output_dir}/${dataset}"
    mkdir -p "$dataset_output_dir"

    python -m tevatron.faiss_retriever \
        --query_reps "${input_dir}/${dataset}/original_query/qry.pt" \
        --passage_reps "${input_dir}/${dataset}/original_corpus/*.pt" \
        --depth 1000 \
        --batch_size -1 \
        --save_text \
        --save_ranking_to "${dataset_output_dir}/rank.tsv"
done
--------------------------------------------------------------------------------
/bash/run_2nd_retrieval.sh:
--------------------------------------------------------------------------------
#!/bin/bash

if [ -z "$1" ]; then
    echo "Usage: $0 <input_dir>"
    exit 1
fi

input_dir=$1

# Make sure the output directory exists
mkdir -p "${REPO_DIR}/outputs/beir"

# List of datasets to process
datasets=('trec-covid' 'dbpedia-entity') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact'

# Process each dataset
for dataset in "${datasets[@]}"; do
    echo "Processing dataset: ${dataset}"

    dataset_output_dir="${REPO_DIR}/outputs/beir/${dataset}"
    mkdir -p "$dataset_output_dir"

    python -m tevatron.faiss_retriever \
        --query_reps "${dataset_output_dir}/qry_refit.pt" \
        --passage_reps "${input_dir}/${dataset}/original_corpus/*.pt" \
        --depth 1000 \
        --batch_size -1 \
        --save_text \
        --save_ranking_to "${dataset_output_dir}/rank_refit.tsv"
    if [ $? -ne 0 ]; then
        echo "Error processing dataset: ${dataset}"
        exit 1
    fi

    echo "Finished processing dataset: ${dataset}"
done

echo "All datasets processed successfully."
--------------------------------------------------------------------------------
/bash/run_convert_results.sh:
--------------------------------------------------------------------------------
#!/bin/bash

data_dir=${REPO_DIR}/datasets/beir/
output_dir=${REPO_DIR}/outputs/beir/

# List of datasets to process
datasets=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity'

# Iterate over datasets and process each one
for dataset in "${datasets[@]}"; do
    echo "Processing dataset: ${dataset}"

    if python "${REPO_DIR}/scripts/convert_results.py" \
        --dataset "${dataset}" \
        --output_dir "${output_dir}" \
        --data_type "beir" \
        --data_dir "${data_dir}" \
        --top_k 100; then
        echo "Successfully processed ${dataset}"
    else
        echo "Failed to process ${dataset}" >&2
        exit 1
    fi
done
--------------------------------------------------------------------------------
/bash/run_distill.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Make sure the output directory exists
mkdir -p "${REPO_DIR}/outputs/beir"
output_dir="${REPO_DIR}/outputs/beir/"
data_dir="${REPO_DIR}/datasets/beir/"

# Configuration flags
use_logits=1 # Whether to use FIRST single-token logit decoding
use_alpha=1  # Whether to use alphabetic identifiers

# List of datasets to process
datasets=('trec-covid' 'dbpedia-entity') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact'

# Process each dataset
for dataset in "${datasets[@]}"; do
    echo "Processing dataset: ${dataset}"

    python "${REPO_DIR}/scripts/distill.py" \
        --inp_path "${output_dir}/${dataset}/distill_input.pt" \
        --rerank_path "${output_dir}/${dataset}" \
        --output_path "${output_dir}/${dataset}/qry_refit.pt" \
        --ce_top_k 100 \
        --llm_top_k 100 \
        --use_logits ${use_logits} \
        --use_alpha ${use_alpha} \
        --loss_path "${output_dir}/${dataset}" \
        --llm_loss ranknet

    if [ $? -ne 0 ]; then
        echo "Error processing dataset: ${dataset}"
        exit 1
    fi

    echo "Finished processing dataset: ${dataset}"
done

echo "All datasets processed successfully."
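run_distill.sh selects RankNet as the loss for the LLM distillation stage (--llm_loss ranknet). The implementation actually used by scripts/distill.py lives in scripts/utils/loss.py (truncated at the end of this dump); for reference, an unweighted pairwise RankNet can be sketched as:

```
import torch
import torch.nn.functional as F

def ranknet_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """For every pair (i, j) with y_true[i] > y_true[j], push y_pred[i]
    above y_pred[j] via a logistic loss on the score difference."""
    diff_pred = y_pred.unsqueeze(-1) - y_pred.unsqueeze(-2)
    diff_true = y_true.unsqueeze(-1) - y_true.unsqueeze(-2)
    mask = diff_true > 0
    return F.binary_cross_entropy_with_logits(
        diff_pred[mask], torch.ones_like(diff_pred[mask]))
```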
--------------------------------------------------------------------------------
/bash/run_eval.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Check that the eval_type argument is provided
if [ -z "$1" ]; then
    echo "Usage: $0 <eval_type>"
    exit 1
fi

EVAL_TYPE=$1
DATA_DIR="${REPO_DIR}/datasets/beir/"
OUTPUT_DIR="${REPO_DIR}/outputs/beir/"

# List of datasets to process
DATASETS=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity'

# Iterate over datasets and process each one
for DATASET in "${DATASETS[@]}"; do
    echo "Evaluating dataset: ${DATASET}"

    # suffix: ce -> cross-encoder reranker | llm_FIRST_alpha -> FIRST model
    if python "${REPO_DIR}/scripts/eval.py" \
        --dataset "${DATASET}" \
        --output_path "${OUTPUT_DIR}" \
        --data_type "beir" \
        --suffix "llm_FIRST_alpha" \
        --eval_type "${EVAL_TYPE}" \
        --data_dir "${DATA_DIR}"; then
        echo "Successfully evaluated ${DATASET}"
    else
        echo "Failed to evaluate ${DATASET}" >&2
        exit 1
    fi
done
--------------------------------------------------------------------------------
/bash/run_prepare_distill.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Check that the input directory is provided
if [ -z "$1" ]; then
    echo "Usage: $0 <input_dir>"
    exit 1
fi

input_dir=$1
output_dir="${REPO_DIR}/outputs/beir/"

# List of datasets to process
datasets=('trec-covid' 'dbpedia-entity') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact'

# Process each dataset
for dataset in "${datasets[@]}"; do
    echo "Processing dataset: ${dataset}"

    python "${REPO_DIR}/scripts/prepare_distill.py" \
        --output_path "${output_dir}/${dataset}/distill_input.pt" \
        --rank_path "${output_dir}/${dataset}/rank.tsv" \
        --psg_embs_dir "${input_dir}/${dataset}/original_corpus/" \
        --qry_embs_path "${input_dir}/${dataset}/original_query/qry.pt"
    if [ $? -ne 0 ]; then
        echo "Error processing dataset: ${dataset}"
        exit 1
    fi

    echo "Finished processing dataset: ${dataset}"
done
--------------------------------------------------------------------------------
/bash/run_rerank_CE.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Set directories
DATA_DIR="${REPO_DIR}/datasets/beir/"
OUTPUT_DIR="${REPO_DIR}/outputs/beir/"

# List of datasets to rerank
DATASETS=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity'

# Iterate over datasets and rerank each one
for DATASET in "${DATASETS[@]}"; do
    echo "Reranking dataset: ${DATASET}"

    if python "${REPO_DIR}/scripts/rerank_CE.py" \
        --dataset "${DATASET}" \
        --output_dir "${OUTPUT_DIR}" \
        --data_dir "${DATA_DIR}" \
        --data_type "beir" \
        --top_k 100; then
        echo "Successfully reranked ${DATASET} with CE reranker"
    else
        echo "Failed to rerank ${DATASET}" >&2
        exit 1
    fi
done
--------------------------------------------------------------------------------
/bash/run_rerank_llm.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Set directories and model
DATA_DIR="${REPO_DIR}/datasets/beir/"
OUTPUT_DIR="${REPO_DIR}/outputs/beir/"
MODEL_IN_USE="rryisthebest/First_Model"

# Configuration flags
USE_LOGITS=1 # Whether to use FIRST single-token logit decoding
USE_ALPHA=1  # Whether to use alphabetic identifiers

# List of datasets to rerank
DATASETS=('dbpedia-entity') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'trec-covid'

# Iterate over datasets and rerank each one
for DATASET in "${DATASETS[@]}"; do
    echo "Reranking dataset: ${DATASET}"

    if python "${REPO_DIR}/scripts/rerank_llm.py" \
        --model "${MODEL_IN_USE}" \
        --dataset "${DATASET}" \
        --output_dir "${OUTPUT_DIR}" \
        --data_type "beir" \
        --data_dir "${DATA_DIR}" \
        --use_logits "${USE_LOGITS}" \
        --use_alpha "${USE_ALPHA}" \
        --llm_top_k 100 \
        --window_size 20 \
        --step_size 10 \
        --do_batched 1; then
        echo "Successfully reranked ${DATASET} with LLM reranker"
    else
        echo "Failed to rerank ${DATASET} with LLM reranker" >&2
        exit 1
    fi
done
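run_rerank_llm.sh reranks the top 100 candidates with window_size=20 and step_size=10: the LLM scores 20 passages at a time, sliding from the bottom of the candidate list to the top with an overlap of 10 so that strong passages can bubble upward across windows. A small sketch of that window schedule (the repository's actual scheduling lives in scripts/utils/reranker.py):

```
def sliding_windows(rank_end: int, window_size: int, step: int):
    """Yield (start, end) windows from the bottom of the candidate list
    up to the top, overlapping by window_size - step."""
    end = rank_end
    while end > 0:
        start = max(0, end - window_size)
        yield (start, end)
        if start == 0:
            break
        end -= step

print(list(sliding_windows(100, 20, 10)))
# [(80, 100), (70, 90), ..., (10, 30), (0, 20)]
```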
"${BEIR_DATA_DIR}" \ 16 | --per_device_eval_batch_size 1 \ 17 | --num_train_epochs 3 \ 18 | --seed 42 \ 19 | --per_device_train_batch_size 2 \ 20 | --eval_steps 400 \ 21 | --gradient_checkpointing \ 22 | --gradient_accumulation_steps 16 \ 23 | --lr_scheduler_type cosine \ 24 | --num_warmup_steps 50 \ 25 | --output_dir "${OUTPUT_DIR}" \ 26 | --noisy_embedding_alpha 5 \ 27 | --objective combined 28 | -------------------------------------------------------------------------------- /scripts/convert_results.py: -------------------------------------------------------------------------------- 1 | from beir.datasets.data_loader import GenericDataLoader 2 | import csv 3 | import os 4 | import json 5 | from collections import defaultdict 6 | from argparse import ArgumentParser 7 | from utils.result import Result, ResultsWriter 8 | 9 | def convert_results(output_path, data_dir, dataset, data_type, top_k, rerank_type="text"): 10 | """Convert ranking results to format suitable for reranking 11 | 12 | Args: 13 | rerank_type (str): Whether this is for "text" or "code" reranking 14 | """ 15 | print(f"Loading {dataset} dataset") 16 | 17 | try: 18 | # Load datasets based on type 19 | if rerank_type == "code": 20 | if dataset in ('swebench_function', 'swebench_file'): 21 | return convert_results_swebench(output_path, data_dir, dataset, data_type, top_k) 22 | 23 | data_path = os.path.join(data_dir, dataset) 24 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test") 25 | 26 | # Load rank data 27 | rank_path = os.path.join(output_path, dataset, "rank.tsv") 28 | dataset_output_path = os.path.join(output_path, dataset) 29 | 30 | else: # text reranking 31 | out_dir = os.path.join(data_dir, "beir") 32 | data_path = os.path.join(out_dir, dataset) 33 | 34 | split = "dev" if dataset == "msmarco" else "test" 35 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split) 36 | 37 | # Load rank data 38 | dataset_output_path = os.path.join(output_path, "beir", dataset) 39 | rank_path = os.path.join(dataset_output_path, "rank.tsv") 40 | 41 | print("Loading rank data") 42 | if not os.path.exists(rank_path): 43 | print(f"Rank file not found: {rank_path}") 44 | return 45 | 46 | results = {} 47 | with open(rank_path, 'r') as rank_file: 48 | csv_reader = csv.reader(rank_file, delimiter="\t", quotechar='|') 49 | if rerank_type == "code": 50 | next(csv_reader) # Skip header for code files 51 | for row in csv_reader: 52 | qid = str(row[0]) 53 | pid = str(row[1]) 54 | score = float(row[2]) 55 | if qid not in results: 56 | results[qid] = {} 57 | results[qid][pid] = score 58 | 59 | print("Converting to reranker results") 60 | # Remove dummy entries if present (for code reranking) 61 | if 'dummy' in results: 62 | results.pop('dummy') 63 | 64 | results_to_rerank = to_reranker_results(results, queries, corpus, top_k) 65 | 66 | # Ensure output directory exists 67 | os.makedirs(dataset_output_path, exist_ok=True) 68 | 69 | results_output_path = os.path.join(dataset_output_path, f'rank_{top_k}.json') 70 | results_writer = ResultsWriter(results_to_rerank) 71 | results_writer.write_in_json_format(results_output_path) 72 | print(f"Results saved to {results_output_path}") 73 | 74 | except Exception as e: 75 | print(f"Error in convert_results: {e}") 76 | raise 77 | 78 | def convert_results_swebench(output_path, data_dir, dataset, data_type, top_k): 79 | """Special handling for swebench datasets""" 80 | prefx = f"csn_{dataset.split('_')[1]}" 81 | instance_list = [instance for instance in 
--------------------------------------------------------------------------------
/scripts/convert_results.py:
--------------------------------------------------------------------------------
from beir.datasets.data_loader import GenericDataLoader
import csv
import os
from collections import defaultdict
from argparse import ArgumentParser
from utils.result import Result, ResultsWriter

def convert_results(output_path, data_dir, dataset, data_type, top_k, rerank_type="text"):
    """Convert ranking results to a format suitable for reranking.

    Args:
        rerank_type (str): Whether this is for "text" or "code" reranking
    """
    print(f"Loading {dataset} dataset")

    try:
        # Load datasets based on type
        if rerank_type == "code":
            if dataset in ('swebench_function', 'swebench_file'):
                return convert_results_swebench(output_path, data_dir, dataset, data_type, top_k)

            data_path = os.path.join(data_dir, dataset)
            corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

            # Load rank data
            rank_path = os.path.join(output_path, dataset, "rank.tsv")
            dataset_output_path = os.path.join(output_path, dataset)

        else:  # text reranking
            out_dir = os.path.join(data_dir, "beir")
            data_path = os.path.join(out_dir, dataset)

            split = "dev" if dataset == "msmarco" else "test"
            corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)

            # Load rank data
            dataset_output_path = os.path.join(output_path, "beir", dataset)
            rank_path = os.path.join(dataset_output_path, "rank.tsv")

        print("Loading rank data")
        if not os.path.exists(rank_path):
            print(f"Rank file not found: {rank_path}")
            return

        results = {}
        with open(rank_path, 'r') as rank_file:
            csv_reader = csv.reader(rank_file, delimiter="\t", quotechar='|')
            if rerank_type == "code":
                next(csv_reader)  # Skip header for code files
            for row in csv_reader:
                qid = str(row[0])
                pid = str(row[1])
                score = float(row[2])
                if qid not in results:
                    results[qid] = {}
                results[qid][pid] = score

        print("Converting to reranker results")
        # Remove dummy entries if present (for code reranking)
        if 'dummy' in results:
            results.pop('dummy')

        results_to_rerank = to_reranker_results(results, queries, corpus, top_k)

        # Ensure the output directory exists
        os.makedirs(dataset_output_path, exist_ok=True)

        results_output_path = os.path.join(dataset_output_path, f'rank_{top_k}.json')
        results_writer = ResultsWriter(results_to_rerank)
        results_writer.write_in_json_format(results_output_path)
        print(f"Results saved to {results_output_path}")

    except Exception as e:
        print(f"Error in convert_results: {e}")
        raise

def convert_results_swebench(output_path, data_dir, dataset, data_type, top_k):
    """Special handling for swebench datasets."""
    prefix = f"csn_{dataset.split('_')[1]}"
    instance_list = [instance for instance in os.listdir(data_dir) if instance.startswith(prefix)]

    for dataset_instance in instance_list:
        print(f"Loading {dataset_instance} dataset")
        data_path = os.path.join(data_dir, "code_datasets", dataset_instance)

        try:
            corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

            print("Loading rank data")
            rank_path = os.path.join(output_path, "code_datasets", dataset_instance, "rank.tsv")

            results = {}
            with open(rank_path, 'r') as rank_file:
                csv_reader = csv.reader(rank_file, delimiter="\t", quotechar='|')
                next(csv_reader)  # Skip header
                for row in csv_reader:
                    qid = str(row[0])
                    pid = str(row[1])
                    score = float(row[2])
                    if qid not in results:
                        results[qid] = {}
                    results[qid][pid] = score

            print("Converting to reranker results")
            results_to_rerank = to_reranker_results(results, queries, corpus, top_k)

            dataset_output_path = os.path.join(output_path, "code_datasets", dataset_instance)
            os.makedirs(dataset_output_path, exist_ok=True)

            results_output_path = os.path.join(dataset_output_path, f'rank_{top_k}.json')
            results_writer = ResultsWriter(results_to_rerank)
            results_writer.write_in_json_format(results_output_path)
            print(f"Results saved to {results_output_path}")

        except Exception as e:
            print(f"Error processing {dataset_instance}: {e}")
            continue

def to_reranker_results(results, queries, corpus, top_k):
    """Convert results to the format needed by the reranker."""
    retrieved_results_with_text = []
    for qid, docs_scores in results.items():
        query_text = queries[qid]
        for doc_id, score in docs_scores.items():
            doc_text = corpus[doc_id]
            result_with_text = {
                'qid': qid,
                'query_text': query_text,
                'doc_id': doc_id,
                'doc_text': doc_text,
                'score': score
            }
            retrieved_results_with_text.append(result_with_text)

    hits_by_query = defaultdict(list)
    for result in retrieved_results_with_text:
        content_string = ''
        if isinstance(result['doc_text'], dict):
            if result['doc_text'].get('title'):
                content_string += result['doc_text']['title'] + ". "
            content_string += result['doc_text']['text']
        else:
            content_string = result['doc_text']

        hits_by_query[result['query_text']].append({
            'qid': result['qid'],
            'docid': result['doc_id'],
            'score': result['score'],
            'content': content_string
        })

    results_to_rerank = []
    for query_text, hits in hits_by_query.items():
        sorted_hits = sorted(hits, reverse=True, key=lambda x: x['score'])[:top_k]
        result = Result(query=query_text, hits=sorted_hits)
        results_to_rerank.append(result)

    return results_to_rerank

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--data_dir', required=True)
    parser.add_argument('--data_type', required=True)
    parser.add_argument('--top_k', required=True, type=int)
    parser.add_argument('--rerank_type', type=str, default="text", choices=["text", "code"],
                        help="Whether to convert for code or text reranking")
    args = parser.parse_args()

    assert args.data_type in ["beir", "codedataset"], "Invalid data_type. Must be 'beir' or 'codedataset'."

    if args.data_type == "codedataset" and args.rerank_type != "code":
        print("Warning: codedataset data_type implies code reranking. Setting rerank_type to 'code'")
        args.rerank_type = "code"

    convert_results(args.output_dir, args.data_dir, args.dataset, args.data_type, args.top_k, args.rerank_type)
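The rank_{top_k}.json file written above groups hits per query. Schematically, one entry produced by to_reranker_results looks like this (the field values are invented):

```
example_result = {
    "query": "what are the transmission routes of covid-19?",
    "hits": [
        {"qid": "q1", "docid": "doc42", "score": 17.3,
         "content": "Title. Passage text ..."},
        # ... up to top_k hits, sorted by retriever score
    ],
}
```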
--------------------------------------------------------------------------------
/scripts/distill.py:
--------------------------------------------------------------------------------
import os
import json
import pickle
import numpy as np
from copy import deepcopy
from tqdm import tqdm
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from utils.loss import loss_dict

class QueryImpModel(nn.Module):
    def __init__(self, query_rep, scaler):
        super().__init__()
        self.query_rep = nn.Parameter(torch.FloatTensor(query_rep), requires_grad=True)
        self.scaler = scaler

    def forward(self, psg_embs: Tensor, attn_mask: Tensor = None):
        pred_scores = (self.scaler / 2) * torch.matmul(self.query_rep, psg_embs.transpose(0, 1))
        if attn_mask is not None:
            extended_attention_mask = (1.0 - attn_mask) * torch.finfo(pred_scores.dtype).min
            pred_scores += extended_attention_mask
        pred_probs = nn.functional.log_softmax(pred_scores, dim=-1)
        return pred_probs


class QueryScoreModel(nn.Module):
    def __init__(self, query_rep, scaler=2.0):
        super().__init__()
        self.query_rep = nn.Parameter(torch.FloatTensor(query_rep), requires_grad=True)
        self.scaler = scaler

    def forward(self, psg_embs: Tensor, attn_mask: Tensor = None):
        pred_scores = (self.scaler / 2) * torch.matmul(self.query_rep, psg_embs.transpose(0, 1))
        if attn_mask is not None:
            extended_attention_mask = (1.0 - attn_mask) * torch.finfo(pred_scores.dtype).min
            pred_scores += extended_attention_mask
        return pred_scores.unsqueeze(0)


def load_results(inp_path, rerank_path, ce_top_k, llm_top_k, use_logits, use_alpha):
    llm_rerank = None
    ce_rerank = None

    if llm_top_k > 0:
        suffix = "_llm"
        suffix += "_FIRST" if use_logits else "_gen"
        suffix += "_alpha" if use_alpha else "_num"
        llm_rerank = json.load(open(os.path.join(rerank_path, f"rerank_{llm_top_k}{suffix}.json")))

    if ce_top_k > 0:
        ce_rerank = json.load(open(os.path.join(rerank_path, f"rerank_{ce_top_k}_ce.json")))

    examples = pickle.load(open(inp_path, "rb"))
    return examples, ce_rerank, llm_rerank


def prepare_distill_ce(data, ce_rerank, ce_top_k):
    qid = data["query_id"]
    pids = data["passage_ids"][:ce_top_k]

    data_passage_mapping = {pid: deepcopy(emb) for pid, emb in zip(data["passage_ids"], data["passage_embs"])}
    target_scores = [ce_rerank[qid][pid] for pid in pids]
    psg_embs = [data_passage_mapping[pid] for pid in pids]

    target_scores = torch.FloatTensor(target_scores)
    target_probs = nn.functional.log_softmax(target_scores, dim=-1)

    baseline_rep = torch.FloatTensor(data["query_rep"])
    passage_reps = torch.FloatTensor(np.array(psg_embs))

    init_scores = torch.matmul(baseline_rep, passage_reps.transpose(0, 1))
    scaler = (target_scores.max() - target_scores.min()) / (init_scores.max().item() - init_scores.min().item())

    return passage_reps, target_probs, scaler
def prepare_distill_llm(data, llm_rerank, query_rep, llm_top_k):
    qid = data["query_id"]
    pids = data["passage_ids"][:llm_top_k]

    data_passage_mapping = {pid: deepcopy(emb) for pid, emb in zip(data["passage_ids"], data["passage_embs"])}
    reranked_target_scores = [llm_rerank[qid][pid] for pid in pids]
    reranked_psg_embs = [data_passage_mapping[pid] for pid in pids]

    reranked_target_scores = torch.FloatTensor(reranked_target_scores)
    reranked_passage_reps = torch.FloatTensor(np.array(reranked_psg_embs))

    init_scores = torch.matmul(query_rep, reranked_passage_reps.transpose(0, 1))
    scaler = (reranked_target_scores.max() - reranked_target_scores.min()) / \
             (init_scores.max().item() - init_scores.min().item())

    return reranked_passage_reps, reranked_target_scores.unsqueeze(0), scaler


def run_query_teacher_importance_learner(inp_path, rerank_path, output_path, loss_path, ce_top_k, llm_top_k, learning_rate,
                                         num_updates, use_logits, use_alpha, llm_loss):
    assert llm_loss in loss_dict
    examples, ce_rerank, llm_rerank = load_results(inp_path, rerank_path, ce_top_k, llm_top_k, use_logits, use_alpha)

    reps = []
    ids = []

    for data in tqdm(examples):
        baseline_rep = torch.FloatTensor(data["query_rep"])

        try:
            learned_rep = baseline_rep
            if ce_top_k > 0:
                passage_reps, target_probs, scaler = prepare_distill_ce(data, ce_rerank, ce_top_k)
                ce_dstl_model = QueryImpModel(query_rep=baseline_rep.numpy(), scaler=scaler)
                loss_function = nn.KLDivLoss(reduction="batchmean", log_target=True)
                optimizer = optim.Adam(ce_dstl_model.parameters(), lr=learning_rate)

                for _ in range(num_updates):
                    optimizer.zero_grad()
                    pred_probs = ce_dstl_model(psg_embs=passage_reps)
                    loss = loss_function(pred_probs.unsqueeze(0), target_probs.unsqueeze(0))
                    loss.backward()
                    optimizer.step()

                learned_rep = ce_dstl_model.query_rep.data.cpu().detach()

            reranked_passage_reps, reranked_target_scores, scaler = prepare_distill_llm(data, llm_rerank, learned_rep,
                                                                                        llm_top_k)
            llm_dstl_model = QueryScoreModel(query_rep=learned_rep.numpy(), scaler=scaler)
            optimizer = optim.Adam(llm_dstl_model.parameters(), lr=learning_rate / 5)

            for _ in range(num_updates // 5):
                optimizer.zero_grad()
                pred_scores = llm_dstl_model(psg_embs=reranked_passage_reps)
                loss = loss_dict[llm_loss](pred_scores, reranked_target_scores, weighted=True if llm_loss == "ranknet" else False)
                loss.backward()
                optimizer.step()

            rep = llm_dstl_model.query_rep.data.cpu().detach()
            reps.append(rep.numpy())
            ids.append(data["query_id"])
        except Exception as e:
            print(f"Error for query ID {data['query_id']}: {e}")

    pickle.dump((np.array(reps), ids), open(output_path, "wb"))


if __name__ == "__main__":

    parser = ArgumentParser()
    parser.add_argument('--inp_path', required=True)
    parser.add_argument('--rerank_path', required=True)
    parser.add_argument('--output_path', required=True)
    parser.add_argument('--loss_path', required=True)
    parser.add_argument('--ce_top_k', type=int, default=100)
    parser.add_argument('--llm_top_k', type=int, default=9)
    parser.add_argument('--learning_rate', type=float, default=0.005)
    parser.add_argument('--num_updates', type=int, default=100)
    parser.add_argument('--use_logits', type=int, default=0)
    parser.add_argument('--use_alpha', type=int, default=0)
    parser.add_argument('--llm_loss', type=str, default="lambdarank")

    args = parser.parse_args()

    run_query_teacher_importance_learner(args.inp_path, args.rerank_path, args.output_path, args.loss_path, args.ce_top_k, args.llm_top_k, args.learning_rate, args.num_updates, args.use_logits, args.use_alpha, args.llm_loss)
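distill.py pickles an (embeddings, ids) tuple, matching the format of the original query representations, so the refit embeddings can be fed straight back to tevatron.faiss_retriever (see bash/run_2nd_retrieval.sh). A quick sanity check (the path here is illustrative):

```
import pickle
import numpy as np

with open("outputs/beir/trec-covid/qry_refit.pt", "rb") as f:
    reps, qids = pickle.load(f)

assert len(reps) == len(qids)
print(np.asarray(reps).shape)  # (num_queries, embedding_dim)
```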
"code": 80 | data_path = os.path.join(data_dir, "code_datasets", dataset) 81 | split = "test" 82 | else: # text reranking 83 | data_path = os.path.join(data_dir, "beir", dataset) 84 | split = "dev" if dataset == "msmarco" else "test" 85 | 86 | try: 87 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split) 88 | except Exception as e: 89 | print(f"Error loading dataset {dataset} from {data_path}: {e}") 90 | return 91 | 92 | # Load rerank results from appropriate directory 93 | if rerank_type == "code": 94 | model_output_path = os.path.join(output_path, "code_datasets", dataset, f"rerank_100_{suffix}.json") 95 | else: 96 | model_output_path = os.path.join(output_path, "beir", dataset, f"rerank_100_{suffix}.json") 97 | 98 | if not os.path.exists(model_output_path): 99 | print(f"Rerank file not found: {model_output_path}") 100 | return 101 | 102 | try: 103 | with open(model_output_path, 'r') as json_file: 104 | results = json.load(json_file) 105 | except Exception as e: 106 | print(f"Error reading rerank file {model_output_path}: {e}") 107 | return 108 | 109 | retriever = EvaluateRetrieval() 110 | metrics_to_evaluate = [1, 3, 5, 10, 20, 100] 111 | 112 | if rerank_type == "code": 113 | # For code reranking, focus on MRR and exact matches 114 | mrr = retriever.evaluate_custom(qrels, results, metrics_to_evaluate, "mrr") 115 | print(f"MRR@{metrics_to_evaluate}: {mrr}") 116 | else: 117 | # Standard text reranking evaluation 118 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, metrics_to_evaluate) 119 | mrr = retriever.evaluate_custom(qrels, results, metrics_to_evaluate, "mrr") 120 | print(ndcg, recall) 121 | 122 | except Exception as e: 123 | print(f"Error in eval_rerank: {e}") 124 | raise 125 | 126 | if __name__ == "__main__": 127 | parser = ArgumentParser() 128 | parser.add_argument('--dataset', required=True, help="Name of the dataset to evaluate.") 129 | parser.add_argument('--output_path', required=True, help="Directory where the output files are stored.") 130 | parser.add_argument('--data_dir', required=True, help="Directory where datasets are stored or will be downloaded.") 131 | parser.add_argument('--suffix', default="", type=str, help="Suffix for the evaluation files (e.g., 'ce', 'logits_alpha').") 132 | parser.add_argument('--data_type', required=True, help="Type of the dataset, must be 'beir' or 'codedataset'.") 133 | parser.add_argument('--eval_type', required=True, help="Type of evaluation: 'rank', 'rerank', or 'rank_refit'.") 134 | parser.add_argument('--rerank_type', type=str, default="text", choices=["text", "code"], 135 | help="Whether to evaluate code or text reranking results") 136 | args = parser.parse_args() 137 | 138 | assert args.data_type in ["beir", "codedataset"], "Invalid data_type. Must be 'beir' or 'codedataset'." 139 | assert args.eval_type in ["rank", "rerank", "rank_refit"], "Invalid eval_type. Must be 'rank', 'rerank', or 'rank_refit'." 140 | 141 | if args.data_type == "codedataset" and args.rerank_type != "code": 142 | print("Warning: codedataset data_type implies code reranking. 
--------------------------------------------------------------------------------
/scripts/prepare_distill.py:
--------------------------------------------------------------------------------
from copy import deepcopy
import csv
import sys
csv.field_size_limit(sys.maxsize)

from tqdm import tqdm
import pickle
import glob
from argparse import ArgumentParser
import os
import numpy as np


def get_all_examples(rank_path, psg_embs_dir, qry_embs_path, output_path):
    # Load all passage embedding shards and build a pid -> embedding lookup
    embs_path = os.path.join(psg_embs_dir, "split*.pt")
    embs_files = glob.glob(embs_path)
    embs = list()
    ids = list()
    for f in tqdm(embs_files):
        p_reps, look_up = pickle.load(open(f, "rb"))
        embs.extend(p_reps)
        ids.extend(look_up)
    embs = np.array(embs, dtype=np.float32)
    assert len(embs) == len(ids)
    embs_dict = dict()
    for emb, pid in zip(embs, ids):
        embs_dict[pid] = emb

    # Load the query embeddings
    q_reps, q_look_up = pickle.load(open(qry_embs_path, "rb"))
    q_embs = np.array(q_reps, dtype=np.float32)
    assert len(q_embs) == len(q_look_up)
    q_tuples = list()
    for q_emb, qid in zip(q_embs, q_look_up):
        q_tuples.append((str(qid), q_emb))

    # Load the retrieval scores
    results = dict()
    csv_reader = csv.reader(open(rank_path), delimiter="\t", quotechar='|')
    for row in csv_reader:
        qid = str(row[0])
        pid = str(row[1])
        score = float(row[2])
        if qid not in results:
            results[qid] = dict()
        results[qid][pid] = score

    # Pack one example per query: its embedding plus the retrieved passages
    # in retriever-score order
    examples = list()
    for qid, q_emb in tqdm(q_tuples):
        item = dict()
        passage_ids = sorted(results[qid].items(), key=lambda item: item[1], reverse=True)
        passage_ids = [i[0] for i in passage_ids]

        item["query_rep"] = q_emb
        item["passage_ids"] = deepcopy(passage_ids)
        item["passage_embs"] = list()
        for pid in item["passage_ids"]:
            item["passage_embs"].append(embs_dict[pid])
        item["passage_embs"] = np.array(item["passage_embs"])
        item["query_id"] = qid
        item["qrels"] = ""
        examples.append(deepcopy(item))
    pickle.dump(examples, open(output_path, "wb"))

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--rank_path', required=True)
    parser.add_argument('--psg_embs_dir', required=True)
    parser.add_argument('--qry_embs_path', required=True)
    parser.add_argument('--output_path', required=True)
    args = parser.parse_args()

    get_all_examples(args.rank_path, args.psg_embs_dir, args.qry_embs_path, args.output_path)
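Each entry in the distill_input.pt pickle built above bundles a query embedding with its retrieved passages, in retriever-score order. A quick look at the structure (the path is illustrative; the passage count is bounded by the retrieval depth):

```
import pickle

examples = pickle.load(open("outputs/beir/trec-covid/distill_input.pt", "rb"))
ex = examples[0]
print(ex["query_id"], ex["query_rep"].shape)              # qid, (dim,)
print(len(ex["passage_ids"]), ex["passage_embs"].shape)   # n, (n, dim)
```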
--------------------------------------------------------------------------------
/scripts/rerank_CE.py:
--------------------------------------------------------------------------------
from beir.reranking import Rerank
from beir.reranking.models import CrossEncoder
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
import csv
import os
import json
from argparse import ArgumentParser

def rerank_beir_outputs(output_path, data_dir, dataset, data_type, top_k):
    if data_type == "beir":
        url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
        out_dir = os.path.join(data_dir, "datasets")
        data_path = util.download_and_unzip(url, out_dir)
    else:
        data_path = data_dir + dataset

    if dataset == "msmarco":
        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="dev")
    else:
        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

    model_output_path = os.path.join(output_path, dataset, "rank.tsv")
    csv_reader = csv.reader(open(model_output_path), delimiter="\t", quotechar='|')
    results = dict()
    for row in csv_reader:
        qid = str(row[0])
        pid = str(row[1])
        score = float(row[2])
        if qid not in results:
            results[qid] = dict()
        results[qid][pid] = score

    retriever = EvaluateRetrieval()

    #### Evaluate the retrieval using NDCG@k, MAP@K ...
    print("Retriever evaluation")
    ndcg, _map, recall, precision = retriever.evaluate(qrels, results, [1, 3, 5, 10, 20, 100])
    print(ndcg)
    mrr = retriever.evaluate_custom(qrels, results, [1, 3, 5, 10, 20, 100], "mrr")
    print(mrr)
    if dataset == "trec-covid":
        recall_cap = retriever.evaluate_custom(qrels, results, [1, 3, 5, 10, 20, 100], "recall_cap")
        print(recall_cap)
    else:
        print(recall)

    if data_type == "beir":
        cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    else:
        cross_encoder_model = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')
    reranker = Rerank(cross_encoder_model, batch_size=64)

    rerank_results = reranker.rerank(corpus, queries, results, top_k=int(top_k))
    #### Evaluate the reranking using NDCG@k, MAP@K ...
    print("Re-ranker evaluation")
    ndcg, _map, recall, precision = retriever.evaluate(qrels, rerank_results, [1, 3, 5, 10, 20, 100])
    print(ndcg)
    mrr = retriever.evaluate_custom(qrels, rerank_results, [1, 3, 5, 10, 20, 100], "mrr")
    print(mrr)
    if dataset == "trec-covid":
        recall_cap = retriever.evaluate_custom(qrels, rerank_results, [1, 3, 5, 10, 20, 100], "recall_cap")
        print(recall_cap)
    else:
        print(recall)

    rerank_path = os.path.join(output_path, dataset, "rerank_" + str(top_k) + "_ce.json")
    json.dump(rerank_results, open(rerank_path, "w"), indent=4)

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--data_dir', required=True)
    parser.add_argument('--data_type', required=True)
    parser.add_argument('--top_k', required=True)
    args = parser.parse_args()

    assert args.data_type in ("beir", "mrtydi")

    rerank_beir_outputs(args.output_dir, args.data_dir, args.dataset, args.data_type, args.top_k)
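Unlike the listwise LLM reranker, the cross-encoder baseline scores each (query, passage) pair independently. A minimal illustration with the same checkpoint, using sentence-transformers directly (the script above goes through BEIR's Rerank wrapper instead; the query and passages here are made up):

```
from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
query = "how does covid-19 spread?"
passages = ["Airborne transmission occurs when ...",
            "The 1918 influenza pandemic began ..."]

scores = model.predict([(query, p) for p in passages])
reranked = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
print(reranked[0])
```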
--------------------------------------------------------------------------------
/scripts/rerank_llm.py:
--------------------------------------------------------------------------------
import os
from argparse import ArgumentParser
from beir.datasets.data_loader import GenericDataLoader
from utils.result import ResultsLoader
from utils.llm_util import evaluate_results, get_results_to_eval, save_rerank_results, rerank_beir_outputs_llm

def rerank_beir_outputs(model, output_path, data_dir, dataset, data_type, use_logits, use_alpha, llm_top_k, window_size, step_size, batched, context_size, rerank_type="text", code_prompt_type="docstring"):
    try:
        # Load dataset based on type
        if rerank_type == "code":
            data_path = os.path.join(data_dir, dataset)
        else:  # text reranking
            data_path = os.path.join(data_dir, "beir", dataset)

        # Handle dataset loading
        split = "dev" if dataset == "msmarco" else "test"
        try:
            corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
        except Exception as e:
            print(f"Error loading dataset {dataset} from {data_path}: {e}")
            return

        # Load converted retriever results
        try:
            if rerank_type == "code":
                results_output_path = os.path.join(output_path, dataset, 'rank_100.json')
            else:
                results_output_path = os.path.join(output_path, "beir", dataset, 'rank_100.json')

            results_loader = ResultsLoader(results_output_path)
            results_to_rerank = results_loader.get_results(with_context=True)
        except Exception as e:
            print(f"Error loading results from {results_output_path}: {e}")
            return

        # Reranking
        try:
            reranked_results = rerank_beir_outputs_llm(
                model, results_to_rerank, use_logits=use_logits, use_alpha=use_alpha,
                top_k=llm_top_k, window_size=window_size, step_size=step_size,
                batched=batched, context_size=context_size,
                rerank_type=rerank_type, code_prompt_type=code_prompt_type
            )

            # Evaluate results
            converted_results = get_results_to_eval(reranked_results)

            if rerank_type == "code":
                mrr_at_k = evaluate_results(dataset, qrels, converted_results, rerank_type="code")
                print("\nMean Reciprocal Rank (MRR) at different cutoffs:")
                for k, mrr in mrr_at_k.items():
                    print(f"MRR@{k}: {mrr:.4f}")
            else:
                ndcg, _map, recall, precision = evaluate_results(dataset, qrels, converted_results)
                print(f"\nNDCG (Normalized Discounted Cumulative Gain):\n {ndcg}")
                print(f"\nRecall:\n {recall}\n")

            # Save rerank results to the appropriate directory
            if rerank_type == "code":
                save_path = os.path.join(output_path, "code_datasets", dataset)
            else:
                save_path = os.path.join(output_path, "beir", dataset)

            # Create the directory if it doesn't exist
            os.makedirs(save_path, exist_ok=True)

            save_rerank_results(save_path, dataset, converted_results, llm_top_k,
                                use_logits, use_alpha, is_llm_result=True)
            print(f"Reranked results saved successfully for dataset {dataset}")

        except Exception as e:
            print(f"Error during reranking process: {e}")
            raise

    except Exception as e:
        print(f"Unexpected error in rerank_beir_outputs: {e}")
        raise

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--model', default="rryisthebest/First_Model")
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--data_dir', required=True)
    parser.add_argument('--data_type', required=True)
    parser.add_argument('--use_logits', default=0, type=int)
    parser.add_argument('--use_alpha', default=0, type=int)
    parser.add_argument('--context_size', default=32768, type=int)
    parser.add_argument('--llm_top_k', default=20, type=int)
    parser.add_argument('--window_size', default=9, type=int)
    parser.add_argument('--step_size', default=9, type=int)
    parser.add_argument('--do_batched', default=0, type=int)
    parser.add_argument('--rerank_type', type=str, default="text", choices=["text", "code"],
                        help="Whether to perform code or text reranking")
    parser.add_argument('--code_prompt_type', type=str, default="docstring",
                        choices=["docstring", "github_issue"],
                        help="Type of code prompt to use (only applicable when rerank_type is 'code')")
    args = parser.parse_args()

    # Validate arguments
    if args.rerank_type == "text" and args.code_prompt_type != "docstring":
        print("Warning: code_prompt_type is ignored when rerank_type is 'text'")

    if args.rerank_type == "code" and (args.use_logits or args.use_alpha):
        print("Warning: Code reranking does not support logits or alpha mode. These will be disabled.")
        args.use_logits = 0
        args.use_alpha = 0

    rerank_beir_outputs(args.model, args.output_dir, args.data_dir, args.dataset,
                        args.data_type, args.use_logits, args.use_alpha, args.llm_top_k,
                        args.window_size, args.step_size, args.do_batched, args.context_size,
                        args.rerank_type, args.code_prompt_type)
--------------------------------------------------------------------------------
/scripts/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gangiswag/llm-reranker/2d7cba423ad555064bdfc719313570b5f9525887/scripts/utils/__init__.py
--------------------------------------------------------------------------------
/scripts/utils/dataset.py:
--------------------------------------------------------------------------------
import copy
from typing import Dict, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from ftfy import fix_text

max_psg_num = 20
START_IDX = ord('A')
IGNORE_INDEX = -100

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    # Mask out the prompt tokens so the loss is computed on the target only
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return input_ids, labels, sources_tokenized["input_ids_lens"]

class RankingDataset(Dataset):
    def __init__(self, raw_data, model_tokenizer, type) -> None:
        self.raw_data = raw_data
        self.tokenizer = model_tokenizer
        self.tokenizer.padding_side = "left"
        self.type = type
        self.system_message_supported = "system" in self.tokenizer.chat_template

    def __getitem__(self, index):
        conversation = self.raw_data[index]["conversations"]
        sys_msg = conversation[0]['value']
        input_context = conversation[1]['value']
        target_generation = conversation[2]["value"]

        if self.system_message_supported:
            messages = [
                {"role": "system", "content": sys_msg},
                {"role": "user", "content": input_context}
            ]
        else:
            messages = [
                {"role": "user", "content": sys_msg + "\n " + input_context}
            ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        prompt += "["
        prompt = fix_text(prompt)
        if self.type == "train":
            label_map = {}
            label_rank = 0
            for token in target_generation:
                if token.isalpha():
                    label_map[token] = label_rank
                    label_rank += 1

            label = [label_map[chr(c)] for c in range(START_IDX, START_IDX + len(label_map))]

        elif self.type == "eval":
            label = [self.raw_data[index]["id"]] + self.raw_data[index]["docids"] + self.raw_data[index]["scores"]
        else:
            raise Exception("Invalid run type specified for Dataset. Choose from ['train', 'eval']")
        return prompt, label

    def __len__(self):
        return len(self.raw_data)

class GenerationDataset(Dataset):
    def __init__(self, raw_data, model_tokenizer, combined=False) -> None:
        self.raw_data = raw_data
        self.tokenizer = model_tokenizer
        self.combined = combined
        self.system_message_supported = "system" in self.tokenizer.chat_template

    def __getitem__(self, index):
        conversation = self.raw_data[index]["conversations"]
        sys_msg = conversation[0]['value']
        input_context = conversation[1]['value']
        label = conversation[2]["value"]
        label += self.tokenizer.eos_token

        if self.system_message_supported:
            messages = [
                {"role": "system", "content": sys_msg},
                {"role": "user", "content": input_context}
            ]
        else:
            messages = [
                {"role": "user", "content": sys_msg + "\n " + input_context}
            ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        prompt = fix_text(prompt)
        if self.combined:
            label_map = {}
            label_rank = 0
            for token in conversation[2]["value"]:
                if token.isalpha():
                    label_map[token] = label_rank
                    label_rank += 1

            rank_label = [label_map[chr(c)] for c in range(START_IDX, START_IDX + len(label_map))]
            return prompt, label, rank_label
        else:
            return prompt, label

    def __len__(self):
        return len(self.raw_data)

def ranking_collate_fn(data, tokenizer):
    prompts, labels = list(zip(*data))
    tokenized_inputs = tokenizer(prompts, padding="longest", truncation=False, return_tensors="pt")
    return tokenized_inputs, labels

def generation_collate_fn(data, tokenizer):
    prompts, labels = list(zip(*data))
    tokenized_inputs, labels, source_lens = preprocess(prompts, labels, tokenizer)
    tokenized_inputs = torch.nn.utils.rnn.pad_sequence(
        tokenized_inputs, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
    return tokenized_inputs, labels

def combined_collate_fn(data, tokenizer):
    prompts, labels, rank_labels = list(zip(*data))
    tokenized_inputs, labels, source_lens = preprocess(prompts, labels, tokenizer)
    tokenized_inputs = torch.nn.utils.rnn.pad_sequence(
        tokenized_inputs, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
    return tokenized_inputs, labels, rank_labels, source_lens
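RankingDataset above converts a generated permutation string into per-passage rank labels by scanning for alphabetic identifiers. The same mapping in isolation, as a worked example:

```
START_IDX = ord('A')

def permutation_to_ranks(target: str, num_passages: int):
    """'[C] > [A] > [B]' means C has rank 0, A rank 1, B rank 2; the
    returned list gives the rank of each identifier A, B, C, ... in order."""
    label_map, rank = {}, 0
    for ch in target:
        if ch.isalpha():
            label_map[ch] = rank
            rank += 1
    return [label_map[chr(START_IDX + i)] for i in range(num_passages)]

print(permutation_to_ranks("[C] > [A] > [B]", 3))  # [1, 2, 0]
```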
-------------------------------------------------------------------------------- /scripts/utils/llm_util.py:
1 | import csv
2 | import os
3 | import logging
4 | import json
5 | import math
6 | from beir.retrieval.evaluation import EvaluateRetrieval
7 | from utils.result import Result, ResultsLoader
8 | from utils.rankllm import PromptMode, RankLLM
9 | from utils.reranker import Reranker
10 | from utils.rank_listwise_os_llm import RankListwiseOSLLM
11 | 
12 | def evaluate_results(dataset, qrels, rerank_results, rerank_type="text"):
13 |     """
14 |     Evaluate reranking results for both text and code datasets
15 | 
16 |     Args:
17 |         dataset: Name of the dataset
18 |         qrels: Ground truth relevance judgments
19 |         rerank_results: Results to evaluate
20 |         rerank_type: Either "text" or "code" reranking
21 | 
22 |     Returns:
23 |         For text reranking: (ndcg, _map, recall, precision)
24 |         For code reranking: Dictionary of MRR values at different cutoffs
25 |     """
26 |     metrics_to_evaluate = [1, 3, 5, 10, 20, 100]
27 | 
28 |     if rerank_type == "code":
29 |         mrr_at_k = {}
30 | 
31 |         for k in metrics_to_evaluate:
32 |             mrr_sum = 0.0
33 |             num_queries = 0
34 | 
35 |             for qid in qrels:
36 |                 if qid in rerank_results:
37 |                     sorted_docs = sorted(rerank_results[qid].items(), key=lambda x: x[1], reverse=True)[:k]
38 | 
39 |                     for rank, (doc_id, _) in enumerate(sorted_docs, start=1):
40 |                         if doc_id in qrels[qid] and qrels[qid][doc_id] > 0:
41 |                             mrr_sum += 1.0 / rank
42 |                             break
43 |                     num_queries += 1
44 | 
45 |             mrr = mrr_sum / num_queries if num_queries > 0 else 0.0
46 |             mrr_at_k[k] = mrr
47 | 
48 |         return mrr_at_k
49 |     else:  # text reranking
50 |         retriever = EvaluateRetrieval()
51 | 
52 |         ndcg, _map, recall, precision = retriever.evaluate(qrels, rerank_results, metrics_to_evaluate)
53 | 
54 |         if dataset == "trec-covid":
55 |             recall_cap_metrics = metrics_to_evaluate + [125]
56 |             recall = retriever.evaluate_custom(qrels, rerank_results, recall_cap_metrics, metric="recall_cap")
57 | 
58 |         return ndcg, _map, recall, precision
59 | 
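# Toy walk-through of the MRR@k branch above (numbers are illustrative):
#
#   qrels          = {"q1": {"d1": 1, "d2": 0}}
#   rerank_results = {"q1": {"d3": 0.9, "d1": 0.7, "d2": 0.1}}
#
# The first relevant document ("d1") sits at rank 2 after sorting by score,
# so MRR@k = 1/2 for any k >= 2, while MRR@1 = 0.0 because the cutoff
# excludes it.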
60 | def get_results_to_eval(results):
61 |     eval_results = {}
62 | 
63 |     for result in results:
64 |         hits = result.hits
65 |         qid = hits[0]['qid']
66 |         eval_results[qid] = {hit['docid']: hit['score'] for hit in hits}
67 | 
68 |     return eval_results
69 | 
70 | def save_rerank_results(output_path, dataset, results, top_k, use_logits=False, use_alpha=False, is_llm_result=False):
71 |     suffix_parts = []
72 | 
73 |     if is_llm_result:
74 |         suffix_parts.append("_llm")
75 |         suffix_parts.append("_FIRST" if use_logits else "_gen")
76 |         suffix_parts.append("_alpha" if use_alpha else "_num")
77 |     else:
78 |         suffix_parts.append("_ce")
79 | 
80 |     suffix = "".join(suffix_parts)
81 |     if output_path.endswith(dataset):
82 |         rerank_path = os.path.join(output_path, f"rerank_{top_k}{suffix}.json")
83 |     else:  # nest under the dataset subdirectory when the path does not already end with it
84 |         rerank_path = os.path.join(output_path, dataset, f"rerank_{top_k}{suffix}.json")
85 | 
86 |     os.makedirs(os.path.dirname(rerank_path), exist_ok=True)
87 | 
88 |     print(f"Saved to: {rerank_path}")
89 |     with open(rerank_path, "w") as f:
90 |         json.dump(results, f, indent=4)
91 | 
92 | def rerank_beir_outputs_llm(model, results_for_rerank, use_logits, use_alpha, top_k, window_size, step_size, batched, context_size, rerank_type="text", code_prompt_type="docstring"):
93 |     """
94 |     Rerank outputs using either text or code reranking
95 | 
96 |     Args:
97 |         rerank_type (str): Whether to perform "text" or "code" reranking
98 |         code_prompt_type (str): For code reranking, whether to use "docstring" or "github_issue" prompts
99 |     """
100 |     # Validate parameters for code reranking
101 |     if rerank_type == "code":
102 |         if use_logits or use_alpha:
103 |             print("Warning: Code reranking does not support logits or alpha mode. These will be disabled.")
104 |             use_logits = False
105 |             use_alpha = False
106 | 
107 |     # Select appropriate system message based on rerank type and prompt type
108 |     if rerank_type == "code":
109 |         if code_prompt_type == "docstring":
110 |             system_message = "You are CodeRanker, an intelligent code reviewer that can analyze doc strings and rank code snippets based on their relevance to the doc string."
111 |         elif code_prompt_type == "github_issue":
112 |             system_message = "You are CodeRanker, an intelligent code reviewer that can analyze GitHub issues and rank code functions based on their relevance to contain the faults causing the GitHub issue."
113 |         else:
114 |             raise ValueError(f"Invalid code_prompt_type: {code_prompt_type}")
115 |     else:  # text reranking
116 |         system_message = "You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query"
117 | 
118 |     # Initialize the ranking model
119 |     agent = RankListwiseOSLLM(
120 |         model=model,
121 |         context_size=context_size,
122 |         prompt_mode=PromptMode.RANK_GPT,
123 |         num_few_shot_examples=0,
124 |         device="cuda",
125 |         num_gpus=1,
126 |         variable_passages=True,
127 |         window_size=window_size,
128 |         system_message=system_message,
129 |         batched=batched,
130 |         rerank_type=rerank_type,
131 |         code_prompt_type=code_prompt_type
132 |     )
133 | 
134 |     # Perform reranking
135 |     reranker = Reranker(agent=agent)
136 |     reranked_results = reranker.rerank(
137 |         retrieved_results=results_for_rerank,
138 |         use_logits=use_logits,
139 |         use_alpha=use_alpha,
140 |         rank_start=0,
141 |         rank_end=top_k,
142 |         window_size=window_size,
143 |         step=step_size,
144 |         logging=False,
145 |         batched=batched
146 |     )
147 | 
148 |     for result in reranked_results:
149 |         for rank, hit in enumerate(result.hits, start=1):
150 |             hit['rank'] = rank
151 | 
152 |     return reranked_results
153 | 
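# Minimal usage sketch (the results path and context_size are placeholder
# values, not defaults shipped with this repo; the flags mirror the FIRST
# configuration, i.e. logit-based reranking with alphabetic IDs):
#
#   results = ResultsLoader("outputs/dataset/retrieval_input.json").get_results(with_context=True)
#   reranked = rerank_beir_outputs_llm(
#       model="rryisthebest/First_Model", results_for_rerank=results,
#       use_logits=True, use_alpha=True, top_k=100, window_size=20,
#       step_size=10, batched=False, context_size=4096,
#   )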
-------------------------------------------------------------------------------- /scripts/utils/loss.py:
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch import Tensor
5 | import torch.nn.functional as F
6 | from itertools import product
7 | import numpy as np
8 | 
9 | def lambdarank(y_pred, y_true=None, eps=1e-10, padded_value_indicator=-100, weighing_scheme="ndcgLoss2_scheme", k=None,
10 |                sigma=1., mu=10., reduction="mean", reduction_log="binary"):
11 |     """
12 |     LambdaLoss framework for LTR loss implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
13 |     Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.
14 |     :param y_pred: predictions from the model, shape [batch_size, slate_length]
15 |     :param y_true: ground truth labels, shape [batch_size, slate_length]
16 |     :param eps: epsilon value, used for numerical stability
17 |     :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -100
18 |     :param weighing_scheme: a string corresponding to a name of one of the weighing schemes
19 |     :param k: rank at which the loss is truncated
20 |     :param sigma: score difference weight used in the sigmoid function
21 |     :param mu: optional weight used in NDCGLoss2++ weighing scheme
22 |     :param reduction: losses reduction method, could be either a sum or a mean
23 |     :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural
24 |     :return: loss value, a torch.Tensor
25 |     """
26 |     if y_true is None:
27 |         y_true = torch.zeros_like(y_pred).to(y_pred.device)
28 |         y_true[:, 0] = 1
29 | 
30 |     device = y_pred.device
31 | 
32 |     # sort the true and predicted relevancy scores.
33 |     y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
34 |     y_true_sorted, _ = y_true.sort(descending=True, dim=-1)
35 | 
36 |     # mask out the pairs of indices (i, j) containing index of a padded element.
37 |     true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
38 |     true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
39 |     padded_pairs_mask = torch.isfinite(true_diffs)
40 | 
41 |     if weighing_scheme != "ndcgLoss1_scheme":
42 |         padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)
43 | 
44 |     ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device)
45 |     ndcg_at_k_mask[:k, :k] = 1
46 | 
47 |     # clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
48 |     true_sorted_by_preds.clamp_(min=0.)
49 |     y_true_sorted.clamp_(min=0.)
50 | 
51 |     # find the gains, discounts and ideal DCGs per slate.
52 |     pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
53 |     D = torch.log2(1. + pos_idxs.float())[None, :]
54 |     maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps)
55 |     G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]
56 | 
57 |     # apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0)
58 |     if weighing_scheme is None:
59 |         weights = 1.
60 |     else:
61 |         weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds)  # type: ignore
62 | 
63 |     # clamping the array entries to maintain correct backprop (log(0) and division by 0)
64 |     scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
65 |     scores_diffs.masked_fill_(torch.isnan(scores_diffs), 0.)  # zero out NaN score differences in place
66 |     weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
67 |     if reduction_log == "natural":
68 |         losses = torch.log(weighted_probas)
69 |     elif reduction_log == "binary":
70 |         losses = torch.log2(weighted_probas)
71 |     else:
72 |         raise ValueError("Reduction logarithm base can be either natural or binary")
73 | 
74 |     if reduction == "sum":
75 |         loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask])
76 |     elif reduction == "mean":
77 |         loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask])
78 |     else:
79 |         raise ValueError("Reduction method can be either sum or mean")
80 | 
81 |     return loss
82 | 
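# Shape sanity-check for lambdarank (toy tensors; illustrative only):
#
#   y_pred = torch.tensor([[2.5, 1.0, 0.3]])  # model scores, [batch=1, slate=3]
#   y_true = torch.tensor([[1.0, 0.0, 0.0]])  # graded relevance, same shape
#   loss = lambdarank(y_pred, y_true, k=3)    # scalar torch.Tensor
#
# When y_true is omitted, the default above treats the first slot of each
# slate as the single relevant item.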
83 | def ndcgLoss1_scheme(G, D, *args):
84 |     return (G / D)[:, :, None]
85 | 
86 | 
87 | def ndcgLoss2_scheme(G, D, *args):
88 |     pos_idxs = torch.arange(1, G.shape[1] + 1, device=G.device)
89 |     delta_idxs = torch.abs(pos_idxs[:, None] - pos_idxs[None, :])
90 |     deltas = torch.abs(torch.pow(torch.abs(D[0, delta_idxs - 1]), -1.) - torch.pow(torch.abs(D[0, delta_idxs]), -1.))
91 |     deltas.diagonal().zero_()
92 | 
93 |     return deltas[None, :, :] * torch.abs(G[:, :, None] - G[:, None, :])
94 | 
95 | def rank_net(y_pred, y_true, weighted=False, use_rank=False, weight_by_diff=False,
96 |              weight_by_diff_powed=False):
97 |     """
98 |     RankNet loss introduced in "Learning to Rank using Gradient Descent".
99 |     :param y_pred: predictions from the model, shape [batch_size, slate_length]
100 |     :param y_true: ground truth labels, shape [batch_size, slate_length]
101 |     :param weight_by_diff: flag indicating whether to weight the score differences by ground truth differences.
102 |     :param weight_by_diff_powed: flag indicating whether to weight the score differences by the squared ground truth differences.
103 |     :return: loss value, a torch.Tensor
104 |     """
105 |     if use_rank:  # replace graded labels with reciprocal-rank targets derived from y_true
106 |         y_true = torch.tensor([[1 / (np.argsort(y_true)[::-1][i] + 1) for i in range(y_pred.size(1))]] * y_pred.size(0)).to(y_pred.device)
107 | 
108 |     # generate every pair of indices from the range of document length in the batch
109 |     document_pairs_candidates = list(product(range(y_true.shape[1]), repeat=2))
110 | 
111 |     pairs_true = y_true[:, document_pairs_candidates]
112 |     selected_pred = y_pred[:, document_pairs_candidates]
113 | 
114 |     # calculate the relative true relevance of every candidate pair
115 |     true_diffs = pairs_true[:, :, 0] - pairs_true[:, :, 1]
116 |     pred_diffs = selected_pred[:, :, 0] - selected_pred[:, :, 1]
117 | 
118 |     # filter just the pairs that are 'positive' and did not involve a padded instance
119 |     # we can do that since the candidate pairs are symmetric, so we can stick with
120 |     # positive ones for a simpler loss function formulation
121 |     the_mask = (true_diffs > 0) & (~torch.isinf(true_diffs))
122 | 
123 |     pred_diffs = pred_diffs[the_mask]
124 | 
125 |     weight = None
126 |     if weighted:
127 |         values, indices = torch.sort(y_true, descending=True)
128 |         ranks = torch.zeros_like(indices)
129 |         ranks.scatter_(1, indices, torch.arange(1, y_true.size(1) + 1, device=y_true.device).repeat(y_true.size(0), 1))  # per-slate ranks 1..slate_length
130 |         pairs_ranks = ranks[:, document_pairs_candidates]
131 |         rank_sum = pairs_ranks.sum(-1)
132 |         weight = 1 / rank_sum[the_mask]  # Relevance Feedback
133 |         # rank_prod = pairs_ranks[:, :, 0] * pairs_ranks[:, :, 1]
134 |         # weight = rank_sum[the_mask] / rank_prod[the_mask]
135 |     else:
136 |         if weight_by_diff:
137 |             abs_diff = torch.abs(true_diffs)
138 |             weight = abs_diff[the_mask]
139 |         elif weight_by_diff_powed:
140 |             true_pow_diffs = torch.pow(pairs_true[:, :, 0], 2) - torch.pow(pairs_true[:, :, 1], 2)
141 |             abs_diff = torch.abs(true_pow_diffs)
142 |             weight = abs_diff[the_mask]
143 | 
144 |     # 'binarize' true relevancy diffs since for a pairwise loss we just need to know
145 |     # whether one document is better than the other and not about the actual difference in
146 |     # their relevancy levels
147 |     true_diffs = (true_diffs > 0).type(torch.float32)
148 |     true_diffs = true_diffs[the_mask]
149 | 
150 |     return nn.BCEWithLogitsLoss(weight=weight)(pred_diffs, true_diffs)
151 | 
152 | 
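# Illustrative calls for rank_net (toy tensors):
#
#   y_pred = torch.tensor([[0.2, 1.3, 0.5]])
#   y_true = torch.tensor([[0.0, 2.0, 1.0]])
#   loss = rank_net(y_pred, y_true)                 # plain RankNet
#   loss = rank_net(y_pred, y_true, weighted=True)  # 1/rank-sum weighting used for relevance feedback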
153 | class ADRMSELoss(nn.Module):
154 |     def __init__(self, eps: float = 1e-10):
155 |         super().__init__()
156 |         self.eps = eps
157 | 
158 |     def forward(self, scores: torch.Tensor, reranked_target_loss: float = None) -> torch.Tensor:
159 |         """
160 |         Compute the Approx Discounted Rank MSE (ADR-MSE) loss.
161 |         :param scores: Tensor of shape [batch_size, slate_length] containing scores for each passage.
162 |         :param reranked_target_loss: An additional parameter that is ignored in the computation.
163 |         :return: Scalar tensor representing the ADR-MSE loss.
164 |         """
165 |         batch_size, slate_length = scores.size()
166 | 
167 |         # Compute the approximated ranks
168 |         softmax_scores = torch.softmax(scores, dim=1)
169 |         approx_ranks = torch.cumsum(softmax_scores, dim=1)
170 | 
171 |         # Compute the actual ranks
172 |         sorted_indices = torch.argsort(scores, dim=1, descending=True)
173 |         ranks = torch.argsort(sorted_indices, dim=1) + 1  # Convert to 1-based ranks
174 | 
175 |         # Compute the logarithmic discount
176 |         log_discounts = torch.log2(ranks.float() + 1)
177 | 
178 |         rank_diffs = (ranks.float() - approx_ranks) ** 2
179 | 
180 |         # Apply the logarithmic discount
181 |         discounted_diffs = rank_diffs / log_discounts
182 | 
183 |         loss = discounted_diffs.mean()
184 |         return loss
185 | 
186 | 
187 | def listNet(y_pred, y_true, eps=1e-10, padded_value_indicator=-1):
188 |     y_pred = y_pred.clone()
189 |     y_true = y_true.clone()
190 | 
191 |     mask = y_true == padded_value_indicator
192 |     y_pred[mask] = float('-inf')
193 |     y_true[mask] = float('-inf')
194 | 
195 |     preds_smax = F.softmax(y_pred, dim=1)
196 |     true_smax = F.softmax(y_true, dim=1)
197 | 
198 |     preds_smax = preds_smax + eps
199 |     preds_log = torch.log(preds_smax)
200 | 
201 |     return torch.mean(-torch.sum(true_smax * preds_log, dim=1))
202 | 
203 | loss_dict = {
204 |     "lambdarank": lambdarank,
205 |     "ranknet": rank_net,
206 |     "listnet_loss": listNet,
207 |     "adr_mse_loss": ADRMSELoss()
208 | }
-------------------------------------------------------------------------------- /scripts/utils/reranker.py:
1 | from datetime import datetime
2 | from pathlib import Path
3 | from typing import List
4 | import os
5 | import time
6 | 
7 | from tqdm import tqdm
8 | 
9 | from utils.rankllm import RankLLM
10 | from utils.result import Result
11 | 
12 | class Reranker:
13 |     def __init__(self, agent: RankLLM) -> None:
14 |         self._agent = agent
15 | 
16 |     def rerank(
17 |         self,
18 |         retrieved_results: List[Result],
19 |         use_logits: bool = False,
20 |         use_alpha: bool = False,
21 |         rank_start: int = 0,
22 |         rank_end: int = 100,
23 |         window_size: int = 20,
24 |         step: int = 10,
25 |         logging: bool = False,
26 |         batched: bool = False
27 |     ) -> List[Result]:
28 |         """
29 |         Reranks a list of retrieved results using the RankLLM agent.
30 | 
31 |         This function applies a sliding window algorithm to rerank the results.
32 |         Each window of results is processed by the RankLLM agent to obtain a new ranking.
33 | 
34 |         Args:
35 |             retrieved_results (List[Result]): The list of results to be reranked.
36 |             rank_start (int, optional): The starting rank for processing. Defaults to 0.
37 |             rank_end (int, optional): The end rank for processing. Defaults to 100.
38 |             window_size (int, optional): The size of each sliding window. Defaults to 20.
39 |             step (int, optional): The step size for moving the window. Defaults to 10.
40 |             logging (bool, optional): Enables logging of the reranking process. Defaults to False.
41 | 
42 |         Returns:
43 |             List[Result]: A list containing the reranked results.
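        Example (illustrative sketch):
            reranked = reranker.rerank(results, use_logits=True, use_alpha=True,
                                       rank_start=0, rank_end=100, window_size=20, step=10)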
44 | """ 45 | if batched: 46 | return self._agent.sliding_windows_batched( 47 | retrieved_results, 48 | use_logits=use_logits, 49 | use_alpha=use_alpha, 50 | rank_start=max(rank_start, 0), 51 | rank_end=min(rank_end, len(retrieved_results[0].hits)), #TODO: Fails arbitrary hit sizes 52 | window_size=window_size, 53 | step=step, 54 | logging=logging, 55 | ) 56 | 57 | rerank_results = [] 58 | for result in tqdm(retrieved_results): 59 | rerank_result = self._agent.sliding_windows( 60 | result, 61 | use_logits=use_logits, 62 | use_alpha=use_alpha, 63 | rank_start=max(rank_start, 0), 64 | rank_end=min(rank_end, len(result.hits)), 65 | window_size=window_size, 66 | step=step, 67 | logging=logging, 68 | ) 69 | rerank_results.append(rerank_result) 70 | return rerank_results 71 | -------------------------------------------------------------------------------- /scripts/utils/result.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, List 3 | 4 | 5 | class RankingExecInfo: 6 | def __init__( 7 | self, prompt, response: str, input_token_count: int, output_token_count: int 8 | ): 9 | self.prompt = prompt 10 | self.response = response 11 | self.input_token_count = input_token_count 12 | self.output_token_count = output_token_count 13 | 14 | def __repr__(self): 15 | return str(self.__dict__) 16 | 17 | 18 | class Result: 19 | def __init__( 20 | self, 21 | query: str, 22 | hits: List[Dict[str, Any]], 23 | ranking_exec_summary: List[RankingExecInfo] = None, 24 | ): 25 | self.query = query 26 | self.hits = hits 27 | self.ranking_exec_summary = ranking_exec_summary 28 | 29 | def __repr__(self): 30 | return str(self.__dict__) 31 | 32 | 33 | class ResultsWriter: 34 | def __init__(self, results: List[Result], append: bool = False): 35 | self._results = results 36 | self._append = append 37 | 38 | def write_in_json_format(self, filename: str): 39 | results = [] 40 | for result in self._results: 41 | results.append({"query": result.query, "hits": result.hits}) 42 | with open(filename, "a" if self._append else "w") as f: 43 | json.dump(results, f, indent=2) 44 | 45 | class ResultsLoader: 46 | def __init__(self, filename: str): 47 | data = json.load(open(filename, 'r')) 48 | self._results = [] 49 | for item in data: 50 | hits = [] 51 | for hit in item['hits']: 52 | hits.append({'qid': hit['qid'], 'docid': hit['docid'], 'score': float(hit['score']), 'content': hit['content']}) 53 | self._results.append(Result(query=item['query'], hits=hits)) 54 | 55 | def get_results(self, with_context: bool): 56 | if with_context: 57 | return self._results 58 | else: 59 | results = dict() 60 | for result in self._results: 61 | query = result.query 62 | hits = result.hits 63 | qid = hits[0]['qid'] 64 | results[qid] = dict() 65 | for hit in hits: 66 | pid = hit['docid'] 67 | score = hit['score'] 68 | results[qid][pid] = score 69 | return results -------------------------------------------------------------------------------- /scripts/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import argparse 5 | from transformers import SchedulerType, CONFIG_MAPPING, MODEL_MAPPING 6 | 7 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 8 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 9 | 10 | def load_data(data_path): 11 | data = [] 12 | with open(data_path, 'r') as f: 13 | for item in list(f): 14 | data.append(json.loads(item)) 15 | 
f.close() 16 | return data 17 | 18 | # NEFTune: Noisy Embedding 19 | def NEFTune(model, noise_alpha=5): 20 | def noised_embed(orig_embed, noise_alpha): 21 | def new_func(x): 22 | # during training, we add noise to the embedding 23 | # during generation, we don't add noise to the embedding 24 | if model.training: 25 | embed_init = orig_embed(x) 26 | dims = torch.tensor(embed_init.size(1) * embed_init.size(2)) 27 | mag_norm = noise_alpha/torch.sqrt(dims) 28 | return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm) 29 | else: 30 | return orig_embed(x) 31 | return new_func 32 | orig_forward = model.base_model.embed_tokens.forward 33 | model.base_model.embed_tokens.forward = noised_embed(orig_forward, noise_alpha) 34 | return model 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") 39 | 40 | # parser.add_argument( 41 | # "--train_file", type=str, default=None, help="A csv, txt or a json file containing the training data." 42 | # ) 43 | # parser.add_argument( 44 | # "--validation_file", type=str, default=None, help="A csv, txt or a json file containing the validation data." 45 | # ) 46 | parser.add_argument( 47 | "--model_name_or_path", 48 | type=str, 49 | help="Path to pretrained model or model identifier from huggingface.co/models.", 50 | required=False, 51 | ) 52 | parser.add_argument( 53 | "--config_name", 54 | type=str, 55 | default=None, 56 | help="Pretrained config name or path if not the same as model_name", 57 | ) 58 | parser.add_argument( 59 | "--tokenizer_name", 60 | type=str, 61 | default=None, 62 | help="Pretrained tokenizer name or path if not the same as model_name", 63 | ) 64 | parser.add_argument( 65 | "--train_dataset_path", 66 | type=str, 67 | required=True, 68 | help="Training dataset path in jsonl format" 69 | ) 70 | parser.add_argument( 71 | "--eval_dataset_path", 72 | type=str, 73 | required=True, 74 | help="Validation dataset path in jsonl format" 75 | ) 76 | parser.add_argument( 77 | "--beir_data_path", 78 | type=str, 79 | required=True, 80 | help="BEIR(MSMARCO) dataset for validation" 81 | ) 82 | parser.add_argument( 83 | "--cache_dir", 84 | type=str, 85 | help="Path to cache" 86 | ) 87 | parser.add_argument( 88 | "--per_device_train_batch_size", 89 | type=int, 90 | default=8, 91 | help="Batch size (per device) for the training dataloader.", 92 | ) 93 | parser.add_argument( 94 | "--per_device_eval_batch_size", 95 | type=int, 96 | default=8, 97 | help="Batch size (per device) for the evaluation dataloader.", 98 | ) 99 | parser.add_argument( 100 | "--learning_rate", 101 | type=float, 102 | default=5e-6, 103 | help="Initial learning rate (after the potential warmup period) to use.", 104 | ) 105 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 106 | parser.add_argument("--noisy_embedding_alpha", type=int, default=None, help="NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings") 107 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 108 | parser.add_argument( 109 | "--max_train_steps", 110 | type=int, 111 | default=None, 112 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 113 | ) 114 | parser.add_argument( 115 | "--gradient_accumulation_steps", 116 | type=int, 117 | default=1, 118 | help="Number of updates steps to accumulate before performing a backward/update pass.", 119 | ) 120 | parser.add_argument( 121 | "--ranking_loss", 122 | type=str, 123 | default="lambda", 124 | help="Ranking loss to use", 125 | choices=["lambda", "listnet", "ranknet"] 126 | ) 127 | parser.add_argument( 128 | "--weighted", action="store_true", help="Use weighting with Ranknet" 129 | ) 130 | parser.add_argument( 131 | "--objective", 132 | type=str, 133 | default="generation", 134 | help="Training objective for reranker training", 135 | choices=["ranking", "generation", "combined"] 136 | ) 137 | parser.add_argument( 138 | "--lr_scheduler_type", 139 | type=SchedulerType, 140 | default="linear", 141 | help="The scheduler type to use.", 142 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 143 | ) 144 | parser.add_argument( 145 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 146 | ) 147 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 148 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 149 | parser.add_argument( 150 | "--model_type", 151 | type=str, 152 | default=None, 153 | help="Model type to use if training from scratch.", 154 | choices=MODEL_TYPES, 155 | ) 156 | 157 | parser.add_argument( 158 | "--preprocessing_num_workers", 159 | type=int, 160 | default=None, 161 | help="The number of processes to use for the preprocessing.", 162 | ) 163 | parser.add_argument( 164 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" 165 | ) 166 | parser.add_argument( 167 | "--trust_remote_code", 168 | type=bool, 169 | default=True, 170 | help=( 171 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option " 172 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will " 173 | "execute code present on the Hub on your local machine." 174 | ), 175 | ) 176 | parser.add_argument( 177 | "--checkpointing_steps", 178 | type=str, 179 | default="epoch", 180 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 181 | ) 182 | parser.add_argument( 183 | "--eval_steps", 184 | type=int, 185 | default=100, 186 | help="Evaluate the model at the end of every n steps", 187 | ) 188 | parser.add_argument( 189 | "--resume_from_checkpoint", 190 | type=str, 191 | default=None, 192 | help="If the training should continue from a checkpoint folder.", 193 | ) 194 | parser.add_argument( 195 | "--with_tracking", 196 | action="store_true", 197 | help="Whether to enable experiment trackers for logging.", 198 | ) 199 | parser.add_argument( 200 | "--gradient_checkpointing", 201 | action="store_true", 202 | help="Whether to use gradient checkpointing.", 203 | ) 204 | parser.add_argument( 205 | "--report_to", 206 | type=str, 207 | default="all", 208 | help=( 209 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 210 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. ' 211 | "Only applicable when `--with_tracking` is passed." 
212 | ), 213 | ) 214 | parser.add_argument( 215 | "--low_cpu_mem_usage", 216 | action="store_true", 217 | help=( 218 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. " 219 | "If passed, LLM loading time and RAM consumption will be benefited." 220 | ), 221 | ) 222 | args = parser.parse_args() 223 | assert args.output_dir is not None 224 | return args -------------------------------------------------------------------------------- /tevatron/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | *.egg-info/ -------------------------------------------------------------------------------- /tevatron/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /tevatron/README.md: -------------------------------------------------------------------------------- 1 | # Tevatron 2 | Tevatron is a simple and efficient toolkit for training and running dense retrievers with deep language models. 3 | The toolkit has a modularized design for easy research; a set of command line tools are also provided for fast 4 | development and testing. 
A set of easy-to-use interfaces to Huggingface's state-of-the-art pre-trained transformers
5 | ensures Tevatron's superior performance.
6 | 
7 | *Tevatron is currently in its initial development stage. We will be actively adding new features and API changes
8 | may happen. Suggestions, feature requests and PRs are welcome.*
9 | 
10 | ## Features
11 | - Command line interface for dense retriever training/encoding and dense index search.
12 | - Flexible and extendable Pytorch retriever models.
13 | - Highly efficient Trainer, a subclass of the Huggingface Trainer, that natively supports training performance features like mixed precision and distributed data parallel.
14 | - Fast and memory-efficient train/inference data access based on memory mapping with Apache Arrow through Huggingface datasets.
15 | - Jax/Flax training/encoding on TPU
16 | 
17 | ## Installation
18 | First install neural network and similarity search backends,
19 | namely Pytorch (or Jax) and FAISS.
20 | Check out the official installation guides for [Pytorch](https://pytorch.org/get-started/locally/#start-locally)
21 | , [Jax](https://github.com/google/jax) / [Flax](https://flax.readthedocs.io/en/latest/installation.html)
22 | and [FAISS](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md) accordingly.
23 | 
24 | Then install Tevatron with pip,
25 | ```bash
26 | pip install tevatron
27 | ```
28 | 
29 | Or, typically for development and research, clone this repo and install as editable,
30 | ```
31 | git clone https://github.com/texttron/tevatron
32 | cd tevatron
33 | pip install --editable .
34 | ```
35 | 
36 | > Note: The current code base has been tested with `torch==1.10.1`, `faiss-cpu==1.7.2`, `transformers==4.15.0`, `datasets==1.17.0`
37 | 
38 | Optionally, you can also install GradCache to support our gradient cache feature during training by:
39 | ```bash
40 | git clone https://github.com/luyug/GradCache
41 | cd GradCache
42 | pip install .
43 | ```
44 | 
45 | ## Documentation
46 | - [**Please view the documentation here**](http://tevatron.ai/)
47 | 
48 | 
49 | ## Examples
50 | In the `/examples` folder, we provide full pipeline instructions for various IR/QA tasks.
51 | 
52 | ## Citation
53 | If you find Tevatron helpful, please consider citing our [paper](https://arxiv.org/abs/2203.05765).
54 | ```
55 | @article{Gao2022TevatronAE,
56 |   title={Tevatron: An Efficient and Flexible Toolkit for Dense Retrieval},
57 |   author={Luyu Gao and Xueguang Ma and Jimmy J. Lin and Jamie Callan},
58 |   journal={ArXiv},
59 |   year={2022},
60 |   volume={abs/2203.05765}
61 | }
62 | ```
63 | 
64 | ## Contacts
65 | If you have a toolkit-specific question, feel free to open an issue.
66 | 
67 | You can also reach out to us for general comments/suggestions/questions through email.
68 | - Luyu Gao luyug@cs.cmu.edu 69 | - Xueguang Ma x93ma@uwaterloo.ca 70 | -------------------------------------------------------------------------------- /tevatron/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Tevatron 2 | nav: 3 | - Home: index.md 4 | - Datasets: datasets.md 5 | - Training: training.md 6 | - Encoding: encoding.md 7 | - Retrieval: retrieval.md 8 | 9 | theme: 10 | name: 'material' 11 | palette: 12 | primary: 'green' 13 | accent: 'green' 14 | 15 | markdown_extensions: 16 | - pymdownx.highlight 17 | - pymdownx.superfences 18 | - toc: 19 | permalink: true 20 | - attr_list 21 | 22 | repo_name: 'texttron/tevatron' 23 | repo_url: 'https://github.com/texttron/tevatron' -------------------------------------------------------------------------------- /tevatron/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='tevatron', 5 | version='0.0.1', 6 | packages=find_packages("src"), 7 | package_dir={'': 'src'}, 8 | url='https://github.com/texttron/tevatron', 9 | license='Apache 2.0', 10 | author='Luyu Gao', 11 | author_email='luyug@cs.cmu.edu', 12 | description='Tevatron: A toolkit for learning and running deep dense retrieval models.', 13 | python_requires='>=3.7', 14 | install_requires=[ 15 | "torch==2.0.1", 16 | "transformers==4.30.2", 17 | "datasets==2.13.1", 18 | "accelerate==0.20.3", 19 | "faiss-cpu==1.7.2", 20 | "sentencepiece==0.1.99", 21 | "tokenizers==0.13.3" 22 | ] 23 | ) 24 | -------------------------------------------------------------------------------- /tevatron/src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gangiswag/llm-reranker/2d7cba423ad555064bdfc719313570b5f9525887/tevatron/src/.DS_Store -------------------------------------------------------------------------------- /tevatron/src/tevatron/__init__.py: -------------------------------------------------------------------------------- 1 | from .faiss_retriever import BaseFaissIPRetriever 2 | from . 
import utils -------------------------------------------------------------------------------- /tevatron/src/tevatron/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List 4 | from transformers import TrainingArguments 5 | 6 | 7 | @dataclass 8 | class ModelArguments: 9 | model_name_or_path: str = field( 10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 11 | ) 12 | target_model_path: str = field( 13 | default=None, 14 | metadata={"help": "Path to pretrained reranker target model"} 15 | ) 16 | config_name: Optional[str] = field( 17 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 18 | ) 19 | tokenizer_name: Optional[str] = field( 20 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 21 | ) 22 | cache_dir: Optional[str] = field( 23 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 24 | ) 25 | 26 | # modeling 27 | untie_encoder: bool = field( 28 | default=False, 29 | metadata={"help": "no weight sharing between qry passage encoders"} 30 | ) 31 | 32 | # out projection 33 | add_pooler: bool = field(default=False) 34 | projection_in_dim: int = field(default=768) 35 | projection_out_dim: int = field(default=768) 36 | use_mean_pooling: bool = field(default=False) 37 | 38 | # for Jax training 39 | dtype: Optional[str] = field( 40 | default="float32", 41 | metadata={ 42 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one " 43 | "of `[float32, float16, bfloat16]`. " 44 | }, 45 | ) 46 | 47 | 48 | @dataclass 49 | class DataArguments: 50 | train_dir: str = field( 51 | default=None, metadata={"help": "Path to train directory"} 52 | ) 53 | dataset_name: str = field( 54 | default=None, metadata={"help": "huggingface dataset name"} 55 | ) 56 | passage_field_separator: str = field(default=' ') 57 | dataset_proc_num: int = field( 58 | default=12, metadata={"help": "number of proc used in dataset preprocess"} 59 | ) 60 | train_n_passages: int = field(default=8) 61 | positive_passage_no_shuffle: bool = field( 62 | default=False, metadata={"help": "always use the first positive passage"}) 63 | negative_passage_no_shuffle: bool = field( 64 | default=False, metadata={"help": "always use the first negative passages"}) 65 | 66 | encode_in_path: List[str] = field(default=None, metadata={"help": "Path to data to encode"}) 67 | encoded_save_path: str = field(default=None, metadata={"help": "where to save the encode"}) 68 | encode_is_qry: bool = field(default=False) 69 | encode_num_shard: int = field(default=1) 70 | encode_shard_index: int = field(default=0) 71 | 72 | q_max_len: int = field( 73 | default=32, 74 | metadata={ 75 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer " 76 | "than this will be truncated, sequences shorter will be padded." 77 | }, 78 | ) 79 | p_max_len: int = field( 80 | default=128, 81 | metadata={ 82 | "help": "The maximum total input sequence length after tokenization for passage. Sequences longer " 83 | "than this will be truncated, sequences shorter will be padded." 
84 | }, 85 | ) 86 | data_cache_dir: Optional[str] = field( 87 | default=None, metadata={"help": "Where do you want to store the data downloaded from huggingface"} 88 | ) 89 | normalize_text: bool = field(default=False, metadata={"help": "Whether to normalize text"}) 90 | 91 | lower_case: bool = field(default=False, metadata={"help": "Whether to lower case text"}) 92 | 93 | 94 | def __post_init__(self): 95 | if self.dataset_name is not None: 96 | info = self.dataset_name.split('/') 97 | self.dataset_split = info[-1] if len(info) == 3 else 'train' 98 | self.dataset_name = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info) 99 | self.dataset_language = 'default' 100 | if ':' in self.dataset_name: 101 | self.dataset_name, self.dataset_language = self.dataset_name.split(':') 102 | else: 103 | self.dataset_name = 'json' 104 | self.dataset_split = 'train' 105 | self.dataset_language = 'default' 106 | if self.train_dir is not None: 107 | files = os.listdir(self.train_dir) 108 | self.train_path = [ 109 | os.path.join(self.train_dir, f) 110 | for f in files 111 | if f.endswith('jsonl') or f.endswith('json') 112 | ] 113 | else: 114 | self.train_path = None 115 | 116 | 117 | @dataclass 118 | class TevatronTrainingArguments(TrainingArguments): 119 | warmup_ratio: float = field(default=0.1) 120 | negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"}) 121 | do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"}) 122 | 123 | grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"}) 124 | gc_q_chunk_size: int = field(default=4) 125 | gc_p_chunk_size: int = field(default=32) 126 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/data.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import List, Tuple 4 | 5 | import datasets 6 | from torch.utils.data import Dataset 7 | from transformers import PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding 8 | 9 | 10 | from .arguments import DataArguments 11 | from .trainer import TevatronTrainer 12 | 13 | import logging 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class TrainDataset(Dataset): 18 | def __init__( 19 | self, 20 | data_args: DataArguments, 21 | dataset: datasets.Dataset, 22 | tokenizer: PreTrainedTokenizer, 23 | trainer: TevatronTrainer = None, 24 | ): 25 | self.train_data = dataset 26 | self.tok = tokenizer 27 | self.trainer = trainer 28 | 29 | self.data_args = data_args 30 | self.total_len = len(self.train_data) 31 | 32 | def create_one_example(self, text_encoding: List[int], is_query=False): 33 | item = self.tok.encode_plus( 34 | text_encoding, 35 | truncation='only_first', 36 | max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len, 37 | padding=False, 38 | return_attention_mask=False, 39 | return_token_type_ids=False, 40 | ) 41 | return item 42 | 43 | def __len__(self): 44 | return self.total_len 45 | 46 | def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]: 47 | group = self.train_data[item] 48 | epoch = int(self.trainer.state.epoch) 49 | 50 | _hashed_seed = hash(item + self.trainer.args.seed) 51 | 52 | qry = group['query'] 53 | encoded_query = self.create_one_example(qry, is_query=True) 54 | 55 | encoded_passages = [] 56 | group_positives = group['positives'] 57 | group_negatives = group['negatives'] 58 | 59 | if 
self.data_args.positive_passage_no_shuffle: 60 | pos_psg = group_positives[0] 61 | else: 62 | pos_psg = group_positives[(_hashed_seed + epoch) % len(group_positives)] 63 | encoded_passages.append(self.create_one_example(pos_psg)) 64 | 65 | negative_size = self.data_args.train_n_passages - 1 66 | if len(group_negatives) < negative_size: 67 | negs = random.choices(group_negatives, k=negative_size) 68 | elif self.data_args.train_n_passages == 1: 69 | negs = [] 70 | elif self.data_args.negative_passage_no_shuffle: 71 | negs = group_negatives[:negative_size] 72 | else: 73 | _offset = epoch * negative_size % len(group_negatives) 74 | negs = [x for x in group_negatives] 75 | random.Random(_hashed_seed).shuffle(negs) 76 | negs = negs * 2 77 | negs = negs[_offset: _offset + negative_size] 78 | 79 | for neg_psg in negs: 80 | encoded_passages.append(self.create_one_example(neg_psg)) 81 | 82 | return encoded_query, encoded_passages 83 | 84 | 85 | class EncodeDataset(Dataset): 86 | input_keys = ['text_id', 'text'] 87 | 88 | def __init__(self, dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer, max_len=128): 89 | self.encode_data = dataset 90 | self.tok = tokenizer 91 | self.max_len = max_len 92 | 93 | def __len__(self): 94 | return len(self.encode_data) 95 | 96 | def __getitem__(self, item) -> Tuple[str, BatchEncoding]: 97 | text_id, text = (self.encode_data[item][f] for f in self.input_keys) 98 | encoded_text = self.tok.encode_plus( 99 | text, 100 | max_length=self.max_len, 101 | truncation='only_first', 102 | padding=False, 103 | return_token_type_ids=False, 104 | ) 105 | return text_id, encoded_text 106 | 107 | 108 | @dataclass 109 | class QPCollator(DataCollatorWithPadding): 110 | """ 111 | Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg] 112 | and pass batch separately to the actual collator. 113 | Abstract out data detail for the model. 
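    A minimal usage sketch (tensor names are illustrative; shapes follow the
    defaults below):
        collator = QPCollator(tokenizer=tokenizer, max_q_len=32, max_p_len=128)
        q_batch, p_batch = collator(features)  # q_batch: [batch, 32], p_batch: [batch * n_passages, 128]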
114 | """ 115 | max_q_len: int = 32 116 | max_p_len: int = 128 117 | 118 | def __call__(self, features): 119 | qq = [f[0] for f in features] 120 | dd = [f[1] for f in features] 121 | 122 | if isinstance(qq[0], list): 123 | qq = sum(qq, []) 124 | if isinstance(dd[0], list): 125 | dd = sum(dd, []) 126 | 127 | q_collated = self.tokenizer.pad( 128 | qq, 129 | padding='max_length', 130 | max_length=self.max_q_len, 131 | return_tensors="pt", 132 | ) 133 | d_collated = self.tokenizer.pad( 134 | dd, 135 | padding='max_length', 136 | max_length=self.max_p_len, 137 | return_tensors="pt", 138 | ) 139 | 140 | return q_collated, d_collated 141 | 142 | 143 | @dataclass 144 | class EncodeCollator(DataCollatorWithPadding): 145 | def __call__(self, features): 146 | text_ids = [x[0] for x in features] 147 | text_features = [x[1] for x in features] 148 | collated_features = super().__call__(text_features) 149 | return text_ids, collated_features -------------------------------------------------------------------------------- /tevatron/src/tevatron/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import HFTrainDataset, HFQueryDataset, HFCorpusDataset 2 | from .preprocessor import TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor 3 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import PreTrainedTokenizer 3 | from .preprocessor import TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor 4 | from ..arguments import DataArguments 5 | 6 | DEFAULT_PROCESSORS = [TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor] 7 | PROCESSOR_INFO = { 8 | 'Tevatron/wikipedia-nq': DEFAULT_PROCESSORS, 9 | 'Tevatron/wikipedia-trivia': DEFAULT_PROCESSORS, 10 | 'Tevatron/wikipedia-curated': DEFAULT_PROCESSORS, 11 | 'Tevatron/wikipedia-wq': DEFAULT_PROCESSORS, 12 | 'Tevatron/wikipedia-squad': DEFAULT_PROCESSORS, 13 | 'Tevatron/scifact': DEFAULT_PROCESSORS, 14 | 'Tevatron/msmarco-passage': DEFAULT_PROCESSORS, 15 | 'json': [None, None, None] 16 | } 17 | 18 | 19 | class HFTrainDataset: 20 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str): 21 | data_files = data_args.train_path 22 | if data_files: 23 | data_files = {data_args.dataset_split: data_files} 24 | self.dataset = load_dataset(data_args.dataset_name, 25 | data_args.dataset_language, 26 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split] 27 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][0] if data_args.dataset_name in PROCESSOR_INFO\ 28 | else DEFAULT_PROCESSORS[0] 29 | self.tokenizer = tokenizer 30 | self.q_max_len = data_args.q_max_len 31 | self.p_max_len = data_args.p_max_len 32 | self.proc_num = data_args.dataset_proc_num 33 | self.neg_num = data_args.train_n_passages - 1 34 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator) 35 | 36 | def process(self, shard_num=1, shard_idx=0): 37 | self.dataset = self.dataset.shard(shard_num, shard_idx) 38 | print("Processing Dataset") 39 | if self.preprocessor is not None: 40 | self.dataset = self.dataset.map( 41 | self.preprocessor(self.tokenizer, self.q_max_len, self.p_max_len, self.separator), 42 | batched=False, 43 | num_proc=self.proc_num, 44 | remove_columns=self.dataset.column_names, 45 | 
desc="Running tokenizer on train dataset", 46 | ) 47 | return self.dataset 48 | 49 | 50 | class HFQueryDataset: 51 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str): 52 | data_files = data_args.encode_in_path 53 | if data_files: 54 | data_files = {data_args.dataset_split: data_files} 55 | self.dataset = load_dataset(data_args.dataset_name, 56 | data_args.dataset_language, 57 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split] 58 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][1] if data_args.dataset_name in PROCESSOR_INFO \ 59 | else DEFAULT_PROCESSORS[1] 60 | self.tokenizer = tokenizer 61 | self.q_max_len = data_args.q_max_len 62 | self.proc_num = data_args.dataset_proc_num 63 | 64 | def process(self, shard_num=1, shard_idx=0): 65 | self.dataset = self.dataset.shard(shard_num, shard_idx) 66 | if self.preprocessor is not None: 67 | self.dataset = self.dataset.map( 68 | self.preprocessor(self.tokenizer, self.q_max_len), 69 | batched=False, 70 | num_proc=self.proc_num, 71 | remove_columns=self.dataset.column_names, 72 | desc="Running tokenization", 73 | ) 74 | return self.dataset 75 | 76 | 77 | class HFCorpusDataset: 78 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str): 79 | data_files = data_args.encode_in_path 80 | if data_files: 81 | data_files = {data_args.dataset_split: data_files} 82 | self.dataset = load_dataset(data_args.dataset_name, 83 | data_args.dataset_language, 84 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split] 85 | script_prefix = data_args.dataset_name 86 | if script_prefix.endswith('-corpus'): 87 | script_prefix = script_prefix[:-7] 88 | self.preprocessor = PROCESSOR_INFO[script_prefix][2] \ 89 | if script_prefix in PROCESSOR_INFO else DEFAULT_PROCESSORS[2] 90 | self.tokenizer = tokenizer 91 | self.p_max_len = data_args.p_max_len 92 | self.proc_num = data_args.dataset_proc_num 93 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator) 94 | 95 | def process(self, shard_num=1, shard_idx=0): 96 | self.dataset = self.dataset.shard(shard_num, shard_idx) 97 | if self.preprocessor is not None: 98 | self.dataset = self.dataset.map( 99 | self.preprocessor(self.tokenizer, self.p_max_len, self.separator), 100 | batched=False, 101 | num_proc=self.proc_num, 102 | remove_columns=self.dataset.column_names, 103 | desc="Running tokenization", 104 | ) 105 | return self.dataset 106 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/datasets/preprocessor.py: -------------------------------------------------------------------------------- 1 | class TrainPreProcessor: 2 | def __init__(self, tokenizer, query_max_length=32, text_max_length=256, separator=' '): 3 | self.tokenizer = tokenizer 4 | self.query_max_length = query_max_length 5 | self.text_max_length = text_max_length 6 | self.separator = separator 7 | 8 | def __call__(self, example): 9 | query = self.tokenizer.encode(example['query'], 10 | add_special_tokens=False, 11 | max_length=self.query_max_length, 12 | truncation=True) 13 | positives = [] 14 | for pos in example['positive_passages']: 15 | text = pos['title'] + self.separator + pos['text'] if 'title' in pos else pos['text'] 16 | positives.append(self.tokenizer.encode(text, 17 | add_special_tokens=False, 18 | max_length=self.text_max_length, 19 | truncation=True)) 20 | negatives = [] 21 | for neg in example['negative_passages']: 22 
| text = neg['title'] + self.separator + neg['text'] if 'title' in neg else neg['text'] 23 | negatives.append(self.tokenizer.encode(text, 24 | add_special_tokens=False, 25 | max_length=self.text_max_length, 26 | truncation=True)) 27 | return {'query': query, 'positives': positives, 'negatives': negatives} 28 | 29 | 30 | class QueryPreProcessor: 31 | def __init__(self, tokenizer, query_max_length=32): 32 | self.tokenizer = tokenizer 33 | self.query_max_length = query_max_length 34 | 35 | def __call__(self, example): 36 | query_id = example['query_id'] 37 | query = self.tokenizer.encode(example['query'], 38 | add_special_tokens=False, 39 | max_length=self.query_max_length, 40 | truncation=True) 41 | return {'text_id': query_id, 'text': query} 42 | 43 | 44 | class CorpusPreProcessor: 45 | def __init__(self, tokenizer, text_max_length=256, separator=' '): 46 | self.tokenizer = tokenizer 47 | self.text_max_length = text_max_length 48 | self.separator = separator 49 | 50 | def __call__(self, example): 51 | docid = example['docid'] 52 | text = example['title'] + self.separator + example['text'] if 'title' in example else example['text'] 53 | text = self.tokenizer.encode(text, 54 | add_special_tokens=False, 55 | max_length=self.text_max_length, 56 | truncation=True) 57 | return {'text_id': docid, 'text': text} 58 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/driver/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gangiswag/llm-reranker/2d7cba423ad555064bdfc719313570b5f9525887/tevatron/src/tevatron/driver/__init__.py -------------------------------------------------------------------------------- /tevatron/src/tevatron/driver/encode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import sys 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import torch 11 | 12 | from torch.utils.data import DataLoader 13 | from transformers import AutoConfig, AutoTokenizer 14 | from transformers import ( 15 | HfArgumentParser, 16 | ) 17 | 18 | from tevatron.arguments import ModelArguments, DataArguments, \ 19 | TevatronTrainingArguments as TrainingArguments 20 | from tevatron.data import EncodeDataset, EncodeCollator 21 | from tevatron.modeling import EncoderOutput, DenseModel 22 | from tevatron.datasets import HFQueryDataset, HFCorpusDataset 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def main(): 28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 31 | else: 32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 33 | model_args: ModelArguments 34 | data_args: DataArguments 35 | training_args: TrainingArguments 36 | 37 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 38 | raise NotImplementedError('Multi-GPU encoding is not supported.') 39 | 40 | # Setup logging 41 | logging.basicConfig( 42 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 43 | datefmt="%m/%d/%Y %H:%M:%S", 44 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 45 | ) 46 | 47 | num_labels = 1 48 | config = AutoConfig.from_pretrained( 49 | model_args.config_name if 
model_args.config_name else model_args.model_name_or_path, 50 | num_labels=num_labels, 51 | cache_dir=model_args.cache_dir, 52 | ) 53 | tokenizer = AutoTokenizer.from_pretrained( 54 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 55 | cache_dir=model_args.cache_dir, 56 | use_fast=False, 57 | ) 58 | 59 | model = DenseModel.load( 60 | model_name_or_path=model_args.model_name_or_path, 61 | config=config, 62 | cache_dir=model_args.cache_dir, 63 | ) 64 | 65 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len 66 | if data_args.encode_is_qry: 67 | encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args, 68 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 69 | else: 70 | encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args, 71 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 72 | encode_dataset = EncodeDataset(encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index), 73 | tokenizer, max_len=text_max_length) 74 | 75 | encode_loader = DataLoader( 76 | encode_dataset, 77 | batch_size=training_args.per_device_eval_batch_size, 78 | collate_fn=EncodeCollator( 79 | tokenizer, 80 | max_length=text_max_length, 81 | padding='max_length' 82 | ), 83 | shuffle=False, 84 | drop_last=False, 85 | num_workers=training_args.dataloader_num_workers, 86 | ) 87 | encoded = [] 88 | lookup_indices = [] 89 | model = model.to(training_args.device) 90 | model.eval() 91 | 92 | for (batch_ids, batch) in tqdm(encode_loader): 93 | lookup_indices.extend(batch_ids) 94 | with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext(): 95 | with torch.no_grad(): 96 | for k, v in batch.items(): 97 | batch[k] = v.to(training_args.device) 98 | if data_args.encode_is_qry: 99 | model_output: EncoderOutput = model(query=batch) 100 | encoded.append(model_output.q_reps.cpu().detach().numpy()) 101 | else: 102 | model_output: EncoderOutput = model(passage=batch) 103 | encoded.append(model_output.p_reps.cpu().detach().numpy()) 104 | 105 | encoded = np.concatenate(encoded) 106 | 107 | with open(data_args.encoded_save_path, 'wb') as f: 108 | pickle.dump((encoded, lookup_indices), f) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/driver/encode_new.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import sys 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | from copy import deepcopy 10 | import torch 11 | import json 12 | 13 | from torch.utils.data import DataLoader 14 | from transformers import AutoConfig, AutoTokenizer 15 | from transformers import ( 16 | HfArgumentParser, 17 | ) 18 | import csv 19 | 20 | from tevatron.arguments import ModelArguments, DataArguments, \ 21 | TevatronTrainingArguments as TrainingArguments 22 | from tevatron.data import EncodeDataset, EncodeCollator 23 | from tevatron.modeling import EncoderOutput, DenseModel 24 | from tevatron.utils.normalize_text import normalize 25 | logger = logging.getLogger(__name__) 26 | 27 | def main(): 28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 31 | else: 
32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 33 | model_args: ModelArguments 34 | data_args: DataArguments 35 | training_args: TrainingArguments 36 | 37 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 38 | raise NotImplementedError('Multi-GPU encoding is not supported.') 39 | 40 | # Setup logging 41 | logging.basicConfig( 42 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 43 | datefmt="%m/%d/%Y %H:%M:%S", 44 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 45 | ) 46 | 47 | num_labels = 1 48 | config = AutoConfig.from_pretrained( 49 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 50 | num_labels=num_labels, 51 | cache_dir=model_args.cache_dir, 52 | ) 53 | tokenizer = AutoTokenizer.from_pretrained( 54 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 55 | cache_dir=model_args.cache_dir, 56 | use_fast=False, 57 | ) 58 | model = DenseModel.load( 59 | model_name_or_path=model_args.model_name_or_path, 60 | config=config, 61 | cache_dir=model_args.cache_dir, 62 | ) 63 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len 64 | 65 | encoded = [] 66 | lookup_indices = [] 67 | model = model.to(training_args.device) 68 | model.eval() 69 | corpus = list() 70 | 71 | if data_args.encode_is_qry: 72 | tsv_reader = csv.reader(open(data_args.encode_in_path[0]), delimiter="\t") 73 | for row in tsv_reader: 74 | corpus.append({"id": row[0], "text": row[1]}) 75 | corpus_items = [(c["id"],c["text"]) for c in corpus] 76 | 77 | else: 78 | corpus_lines = open(data_args.encode_in_path[0]).readlines() 79 | for line in corpus_lines: 80 | corpus.append(deepcopy(json.loads(line))) 81 | 82 | corpus_items = [(c["id"], c["title"] + " " +c["text"]) for c in corpus] 83 | 84 | if data_args.normalize_text: 85 | corpus_items = [(c[0], normalize(c[1])) for c in corpus_items] 86 | 87 | if data_args.lower_case: 88 | corpus_items = [(c[0], c[1].lower()) for c in corpus_items] 89 | 90 | corpus_items = sorted(corpus_items, key=lambda item: len(item[1]), reverse=True) 91 | 92 | batch_size = training_args.per_device_eval_batch_size 93 | nbatch = int(len(corpus_items) / batch_size) + 1 94 | 95 | with torch.no_grad(): 96 | for k in tqdm(range(nbatch)): 97 | try: 98 | start_idx = k * batch_size 99 | end_idx = min((k + 1) * batch_size, len(corpus)) 100 | 101 | batch_items = corpus_items[start_idx:end_idx] 102 | batch_ids = [item[0] for item in batch_items] 103 | batch_text = [item[1] for item in batch_items] 104 | 105 | cencode = tokenizer( 106 | batch_text, 107 | padding=True, 108 | truncation='longest_first', 109 | return_tensors="pt", 110 | ) 111 | 112 | cencode = {key: value.cuda() for key, value in cencode.items()} 113 | 114 | if data_args.encode_is_qry: 115 | model_output: EncoderOutput = model(query=cencode) 116 | encoded.append(model_output.q_reps.cpu().detach().numpy()) 117 | else: 118 | model_output: EncoderOutput = model(passage=cencode) 119 | encoded.append(model_output.p_reps.cpu().detach().numpy()) 120 | 121 | lookup_indices.extend(batch_ids) 122 | except Exception as e: 123 | print(e) 124 | continue 125 | 126 | 127 | encoded = np.concatenate(encoded) 128 | 129 | with open(data_args.encoded_save_path, 'wb') as f: 130 | pickle.dump((encoded, lookup_indices), f) 131 | 132 | 133 | 134 | 135 | if __name__ == "__main__": 136 | main() -------------------------------------------------------------------------------- 
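Both `encode.py` and `encode_new.py` above persist their output in the same format: a pickle holding a float32 embedding matrix and the parallel list of query/passage ids. As a minimal sketch of how such a file can be consumed downstream (the file paths here are hypothetical, chosen for illustration), the snippet below mirrors the exact inner-product search that `BaseFaissIPRetriever` performs in `faiss_retriever/retriever.py` further down:

```python
import pickle

import faiss
import numpy as np


def load_reps(path):
    # Each encode driver ends with pickle.dump((encoded, lookup_indices), f).
    with open(path, "rb") as f:
        reps, lookup = pickle.load(f)
    return np.asarray(reps, dtype=np.float32), lookup


# Hypothetical paths passed earlier via --encoded_save_path.
q_reps, q_ids = load_reps("embeddings/queries.pkl")
p_reps, p_ids = load_reps("embeddings/corpus.pkl")

# Exact (brute-force) inner-product search over the corpus embeddings.
index = faiss.IndexFlatIP(p_reps.shape[1])
index.add(p_reps)
scores, indices = index.search(q_reps, 10)

# Emit qid \t docid \t score lines, as write_ranking does later in this dump.
for qid, doc_idxs, doc_scores in zip(q_ids, indices, scores):
    for i, s in zip(doc_idxs, doc_scores):
        print(f"{qid}\t{p_ids[i]}\t{s}")
```

For large corpora, the HNSW variant in `retriever.py` trades this exact search for an approximate graph index with the same search interface.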
/tevatron/src/tevatron/driver/jax_encode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import sys 5 | 6 | import datasets 7 | import jax 8 | import numpy as np 9 | from flax.training.common_utils import shard 10 | from jax import pmap 11 | from tevatron.arguments import DataArguments 12 | from tevatron.arguments import TevatronTrainingArguments as TrainingArguments 13 | from tevatron.arguments import ModelArguments 14 | from tevatron.data import EncodeCollator, EncodeDataset 15 | from tevatron.datasets import HFQueryDataset, HFCorpusDataset 16 | from torch.utils.data import DataLoader 17 | from tqdm import tqdm 18 | from flax.training.train_state import TrainState 19 | from flax import jax_utils 20 | import optax 21 | from transformers import (AutoConfig, AutoTokenizer, FlaxAutoModel, 22 | HfArgumentParser, TensorType) 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def main(): 28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 31 | else: 32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 33 | model_args: ModelArguments 34 | data_args: DataArguments 35 | training_args: TrainingArguments 36 | 37 | # Setup logging 38 | logging.basicConfig( 39 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 40 | datefmt="%m/%d/%Y %H:%M:%S", 41 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 42 | ) 43 | 44 | num_labels = 1 45 | config = AutoConfig.from_pretrained( 46 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 47 | num_labels=num_labels, 48 | cache_dir=model_args.cache_dir, 49 | ) 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 52 | cache_dir=model_args.cache_dir, 53 | use_fast=False, 54 | ) 55 | 56 | model = FlaxAutoModel.from_pretrained(model_args.model_name_or_path, config=config, from_pt=False) 57 | 58 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len 59 | if data_args.encode_is_qry: 60 | encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args, 61 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 62 | else: 63 | encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args, 64 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 65 | encode_dataset = EncodeDataset(encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index), 66 | tokenizer, max_len=text_max_length) 67 | 68 | # prepare padding batch (for last nonfull batch) 69 | dataset_size = len(encode_dataset) 70 | padding_prefix = "padding_" 71 | total_batch_size = len(jax.devices()) * training_args.per_device_eval_batch_size 72 | features = list(encode_dataset.encode_data.features.keys()) 73 | padding_batch = {features[0]: [], features[1]: []} 74 | for i in range(total_batch_size - (dataset_size % total_batch_size)): 75 | padding_batch["text_id"].append(f"{padding_prefix}{i}") 76 | padding_batch["text"].append([0]) 77 | padding_batch = datasets.Dataset.from_dict(padding_batch) 78 | encode_dataset.encode_data = datasets.concatenate_datasets([encode_dataset.encode_data, padding_batch]) 79 | 80 | encode_loader = DataLoader( 81 | 
encode_dataset, 82 | batch_size=training_args.per_device_eval_batch_size * len(jax.devices()), 83 | collate_fn=EncodeCollator( 84 | tokenizer, 85 | max_length=text_max_length, 86 | padding='max_length', 87 | pad_to_multiple_of=16, 88 | return_tensors=TensorType.NUMPY, 89 | ), 90 | shuffle=False, 91 | drop_last=False, 92 | num_workers=training_args.dataloader_num_workers, 93 | ) 94 | 95 | # craft a fake state for now to replicate on devices 96 | adamw = optax.adamw(0.0001) 97 | state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) 98 | 99 | def encode_step(batch, state): 100 | embedding = state.apply_fn(**batch, params=state.params, train=False)[0] 101 | return embedding[:, 0] 102 | 103 | p_encode_step = pmap(encode_step) 104 | state = jax_utils.replicate(state) 105 | 106 | encoded = [] 107 | lookup_indices = [] 108 | 109 | for (batch_ids, batch) in tqdm(encode_loader): 110 | lookup_indices.extend(batch_ids) 111 | batch_embeddings = p_encode_step(shard(batch.data), state) 112 | encoded.extend(np.concatenate(batch_embeddings, axis=0)) 113 | with open(data_args.encoded_save_path, 'wb') as f: 114 | pickle.dump((encoded[:dataset_size], lookup_indices[:dataset_size]), f) 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/driver/jax_train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from functools import partial 5 | 6 | import datasets 7 | import jax 8 | import jax.numpy as jnp 9 | import optax 10 | from flax import jax_utils, traverse_util 11 | from flax.jax_utils import prefetch_to_device 12 | from flax.training.common_utils import get_metrics, shard 13 | from torch.utils.data import DataLoader, IterableDataset 14 | from tqdm import tqdm 15 | from transformers import AutoConfig, AutoTokenizer, FlaxAutoModel 16 | from transformers import ( 17 | HfArgumentParser, 18 | set_seed, 19 | ) 20 | 21 | from tevatron.arguments import ModelArguments, DataArguments, TevatronTrainingArguments 22 | from tevatron.tevax.training import TiedParams, RetrieverTrainState, retriever_train_step, grad_cache_train_step, \ 23 | DualParams 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def main(): 29 | parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments)) 30 | 31 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 32 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 33 | else: 34 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 35 | model_args: ModelArguments 36 | data_args: DataArguments 37 | training_args: TevatronTrainingArguments 38 | 39 | if ( 40 | os.path.exists(training_args.output_dir) 41 | and os.listdir(training_args.output_dir) 42 | and training_args.do_train 43 | and not training_args.overwrite_output_dir 44 | ): 45 | raise ValueError( 46 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
47 | ) 48 | 49 | # Setup logging 50 | logging.basicConfig( 51 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 52 | datefmt="%m/%d/%Y %H:%M:%S", 53 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 54 | ) 55 | logger.warning( 56 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 57 | training_args.local_rank, 58 | training_args.device, 59 | training_args.n_gpu, 60 | bool(training_args.local_rank != -1), 61 | training_args.fp16, 62 | ) 63 | logger.info("Training/evaluation parameters %s", training_args) 64 | logger.info("MODEL parameters %s", model_args) 65 | 66 | set_seed(training_args.seed) 67 | 68 | config = AutoConfig.from_pretrained( 69 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 70 | cache_dir=model_args.cache_dir, 71 | ) 72 | tokenizer = AutoTokenizer.from_pretrained( 73 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 74 | cache_dir=model_args.cache_dir, 75 | ) 76 | try: 77 | model = FlaxAutoModel.from_pretrained( 78 | model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) 79 | ) 80 | except: 81 | model = FlaxAutoModel.from_pretrained( 82 | model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), 83 | from_pt=True 84 | ) 85 | 86 | if data_args.train_dir: 87 | data_files = { 88 | 'train': data_args.train_path 89 | } 90 | else: 91 | data_files = None 92 | 93 | train_dataset = \ 94 | datasets.load_dataset(data_args.dataset_name, data_args.dataset_language, cache_dir=model_args.cache_dir, 95 | data_files=data_files)[data_args.dataset_split] 96 | 97 | def tokenize_train(example): 98 | tokenize = partial(tokenizer, return_attention_mask=False, return_token_type_ids=False, padding=False, 99 | truncation=True) 100 | query = example['query'] 101 | pos_psgs = [p['title'] + " " + p['text'] for p in example['positive_passages']] 102 | neg_psgs = [p['title'] + " " + p['text'] for p in example['negative_passages']] 103 | 104 | example['query_input_ids'] = dict(tokenize(query, max_length=32)) 105 | example['pos_psgs_input_ids'] = [dict(tokenize(x, max_length=data_args.p_max_len)) for x in pos_psgs] 106 | example['neg_psgs_input_ids'] = [dict(tokenize(x, max_length=data_args.p_max_len)) for x in neg_psgs] 107 | 108 | return example 109 | 110 | train_data = train_dataset.map( 111 | tokenize_train, 112 | batched=False, 113 | num_proc=data_args.dataset_proc_num, 114 | desc="Running tokenizer on train dataset", 115 | ) 116 | train_data = train_data.filter( 117 | function=lambda data: len(data["pos_psgs_input_ids"]) >= 1 and \ 118 | len(data["neg_psgs_input_ids"]) >= data_args.train_n_passages-1, num_proc=64 119 | ) 120 | 121 | class TrainDataset: 122 | def __init__(self, train_data, group_size, tokenizer): 123 | self.group_size = group_size 124 | self.data = train_data 125 | self.tokenizer = tokenizer 126 | 127 | def __len__(self): 128 | return len(self.data) 129 | 130 | def get_example(self, i, epoch): 131 | example = self.data[i] 132 | q = example['query_input_ids'] 133 | 134 | pp = example['pos_psgs_input_ids'] 135 | p = pp[0] 136 | 137 | nn = example['neg_psgs_input_ids'] 138 | off = epoch * (self.group_size - 1) % len(nn) 139 | nn = nn * 2 140 | nn = nn[off: off + self.group_size - 1] 141 | 142 | return q, [p] + nn 143 | 144 | def get_batch(self, indices, epoch): 145 | qq, dd = zip(*[self.get_example(i, epoch) for i in 
map(int, indices)]) 146 | dd = sum(dd, []) 147 | return dict(tokenizer.pad(qq, max_length=32, padding='max_length', return_tensors='np')), dict( 148 | tokenizer.pad(dd, max_length=data_args.p_max_len, padding='max_length', return_tensors='np')) 149 | 150 | train_dataset = TrainDataset(train_data, data_args.train_n_passages, tokenizer) 151 | 152 | def create_learning_rate_fn( 153 | train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, 154 | learning_rate: float 155 | ): 156 | """Returns a linear warmup, linear_decay learning rate function.""" 157 | steps_per_epoch = train_ds_size // train_batch_size 158 | num_train_steps = steps_per_epoch * num_train_epochs 159 | warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) 160 | decay_fn = optax.linear_schedule( 161 | init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps 162 | ) 163 | schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) 164 | return schedule_fn 165 | 166 | def _decay_mask_fn(params): 167 | flat_params = traverse_util.flatten_dict(params) 168 | layer_norm_params = [ 169 | (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] 170 | ] 171 | flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} 172 | return traverse_util.unflatten_dict(flat_mask) 173 | 174 | def decay_mask_fn(params): 175 | param_nodes, treedef = jax.tree_flatten(params, lambda v: isinstance(v, dict)) 176 | masks = [_decay_mask_fn(param_node) for param_node in param_nodes] 177 | return jax.tree_unflatten(treedef, masks) 178 | 179 | num_epochs = int(training_args.num_train_epochs) 180 | train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() 181 | steps_per_epoch = len(train_dataset) // train_batch_size 182 | total_train_steps = steps_per_epoch * num_epochs 183 | 184 | linear_decay_lr_schedule_fn = create_learning_rate_fn( 185 | len(train_dataset), 186 | train_batch_size, 187 | int(training_args.num_train_epochs), 188 | int(total_train_steps * 0.1), 189 | training_args.learning_rate, 190 | ) 191 | 192 | adamw = optax.adamw( 193 | learning_rate=linear_decay_lr_schedule_fn, 194 | b1=training_args.adam_beta1, 195 | b2=training_args.adam_beta2, 196 | eps=training_args.adam_epsilon, 197 | weight_decay=training_args.weight_decay, 198 | mask=decay_mask_fn, 199 | ) 200 | 201 | if model_args.untie_encoder: 202 | params = DualParams.create(model.params) 203 | else: 204 | params = TiedParams.create(model.params) 205 | state = RetrieverTrainState.create(apply_fn=model.__call__, params=params, tx=adamw) 206 | 207 | if training_args.grad_cache: 208 | q_n_subbatch = train_batch_size // training_args.gc_q_chunk_size 209 | p_n_subbatch = train_batch_size * data_args.train_n_passages // training_args.gc_p_chunk_size 210 | p_train_step = jax.pmap( 211 | partial(grad_cache_train_step, q_n_subbatch=q_n_subbatch, p_n_subbatch=p_n_subbatch), 212 | "device" 213 | ) 214 | else: 215 | p_train_step = jax.pmap( 216 | retriever_train_step, 217 | "device" 218 | ) 219 | 220 | state = jax_utils.replicate(state) 221 | rng = jax.random.PRNGKey(training_args.seed) 222 | dropout_rngs = jax.random.split(rng, jax.local_device_count()) 223 | 224 | class IterableTrain(IterableDataset): 225 | def __init__(self, dataset, batch_idx, epoch): 226 | super(IterableTrain).__init__() 227 | self.dataset = dataset 228 | 
self.batch_idx = batch_idx 229 | self.epoch = epoch 230 | 231 | def __iter__(self): 232 | for idx in self.batch_idx: 233 | batch = self.dataset.get_batch(idx, self.epoch) 234 | batch = shard(batch) 235 | yield batch 236 | 237 | logger.info("***** Running training *****") 238 | logger.info(f" Num examples = {len(train_dataset)}") 239 | logger.info(f" Num Epochs = {num_epochs}") 240 | logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") 241 | logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") 242 | logger.info(f" Total optimization steps = {total_train_steps}") 243 | 244 | train_metrics = [] 245 | for epoch in tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0): 246 | # ======================== Training ================================ 247 | # Create sampling rng 248 | rng, input_rng = jax.random.split(rng) 249 | 250 | steps_per_epoch = len(train_dataset) // train_batch_size 251 | 252 | batch_idx = jax.random.permutation(input_rng, len(train_dataset)) 253 | batch_idx = batch_idx[: steps_per_epoch * train_batch_size] 254 | batch_idx = batch_idx.reshape((steps_per_epoch, train_batch_size)).tolist() 255 | 256 | train_loader = prefetch_to_device( 257 | iter(DataLoader( 258 | IterableTrain(train_dataset, batch_idx, epoch), 259 | num_workers=16, prefetch_factor=256, batch_size=None, collate_fn=lambda v: v) 260 | ), 2) 261 | 262 | # train 263 | epochs = tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False) 264 | for step in epochs: 265 | cur_step = epoch * (len(train_dataset) // train_batch_size) + step 266 | batch = next(train_loader) 267 | 268 | loss, state, dropout_rngs = p_train_step(state, *batch, dropout_rngs) 269 | train_metrics.append({'loss': loss}) 270 | 271 | if cur_step % training_args.logging_steps == 0 and cur_step > 0: 272 | train_metrics = get_metrics(train_metrics) 273 | print( 274 | f"Step... ({cur_step} | Loss: {train_metrics['loss'].mean()}," 275 | f" Learning Rate: {linear_decay_lr_schedule_fn(cur_step)})", 276 | flush=True, 277 | ) 278 | train_metrics = [] 279 | 280 | epochs.write( 281 | f"Epoch... 
({epoch + 1}/{num_epochs})" 282 | ) 283 | 284 | params = jax_utils.unreplicate(state.params) 285 | 286 | if model_args.untie_encoder: 287 | os.makedirs(training_args.output_dir, exist_ok=True) 288 | model.save_pretrained(os.path.join(training_args.output_dir, 'query_encoder'), params=params.q_params) 289 | model.save_pretrained(os.path.join(training_args.output_dir, 'passage_encoder'), params=params.p_params) 290 | else: 291 | model.save_pretrained(training_args.output_dir, params=params.p_params) 292 | tokenizer.save_pretrained(training_args.output_dir) 293 | 294 | 295 | if __name__ == "__main__": 296 | main() 297 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/driver/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | from transformers import AutoConfig, AutoTokenizer 6 | from transformers import ( 7 | HfArgumentParser, 8 | set_seed, 9 | ) 10 | 11 | from tevatron.arguments import ModelArguments, DataArguments, \ 12 | TevatronTrainingArguments as TrainingArguments 13 | from tevatron.data import TrainDataset, QPCollator 14 | from tevatron.modeling import DenseModel 15 | from tevatron.trainer import TevatronTrainer as Trainer, GCTrainer 16 | from tevatron.datasets import HFTrainDataset 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def main(): 22 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 23 | 24 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 25 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 26 | else: 27 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 28 | model_args: ModelArguments 29 | data_args: DataArguments 30 | training_args: TrainingArguments 31 | 32 | if ( 33 | os.path.exists(training_args.output_dir) 34 | and os.listdir(training_args.output_dir) 35 | and training_args.do_train 36 | and not training_args.overwrite_output_dir 37 | ): 38 | raise ValueError( 39 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
40 | ) 41 | 42 | # Setup logging 43 | logging.basicConfig( 44 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 45 | datefmt="%m/%d/%Y %H:%M:%S", 46 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 47 | ) 48 | logger.warning( 49 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 50 | training_args.local_rank, 51 | training_args.device, 52 | training_args.n_gpu, 53 | bool(training_args.local_rank != -1), 54 | training_args.fp16, 55 | ) 56 | logger.info("Training/evaluation parameters %s", training_args) 57 | logger.info("MODEL parameters %s", model_args) 58 | 59 | set_seed(training_args.seed) 60 | 61 | num_labels = 1 62 | config = AutoConfig.from_pretrained( 63 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 64 | num_labels=num_labels, 65 | cache_dir=model_args.cache_dir, 66 | ) 67 | tokenizer = AutoTokenizer.from_pretrained( 68 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 69 | cache_dir=model_args.cache_dir, 70 | use_fast=False, 71 | ) 72 | model = DenseModel.build( 73 | model_args, 74 | training_args, 75 | config=config, 76 | cache_dir=model_args.cache_dir, 77 | ) 78 | 79 | train_dataset = HFTrainDataset(tokenizer=tokenizer, data_args=data_args, 80 | cache_dir=data_args.data_cache_dir or model_args.cache_dir) 81 | train_dataset = TrainDataset(data_args, train_dataset.process(), tokenizer) 82 | 83 | trainer_cls = GCTrainer if training_args.grad_cache else Trainer 84 | trainer = trainer_cls( 85 | model=model, 86 | args=training_args, 87 | train_dataset=train_dataset, 88 | data_collator=QPCollator( 89 | tokenizer, 90 | max_p_len=data_args.p_max_len, 91 | max_q_len=data_args.q_max_len 92 | ), 93 | ) 94 | train_dataset.trainer = trainer 95 | 96 | trainer.train() # TODO: resume training 97 | trainer.save_model() 98 | if trainer.is_world_process_zero(): 99 | tokenizer.save_pretrained(training_args.output_dir) 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/faiss_retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from .retriever import BaseFaissIPRetriever, BaseFaissHNSWRetriever 2 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/faiss_retriever/__main__.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | import glob 5 | from argparse import ArgumentParser 6 | from itertools import chain 7 | from tqdm import tqdm 8 | import os 9 | import json 10 | from .retriever import BaseFaissIPRetriever, BaseFaissHNSWRetriever 11 | 12 | import logging 13 | logger = logging.getLogger(__name__) 14 | logging.basicConfig( 15 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 16 | datefmt="%m/%d/%Y %H:%M:%S", 17 | level=logging.INFO, 18 | ) 19 | 20 | 21 | def search_queries(retriever, q_reps, p_lookup, args): 22 | if args.batch_size > 0: 23 | all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size) 24 | else: 25 | all_scores, all_indices = retriever.search(q_reps, args.depth) 26 | 27 | psg_indices = [[str(p_lookup[x]) for x in q_dd] for q_dd in all_indices] 28 | psg_indices = np.array(psg_indices) 29 | return all_scores, psg_indices 30 | 31 | 32 | def write_ranking(corpus_indices, 
corpus_scores, q_lookup, ranking_save_file): 33 | with open(ranking_save_file, 'w') as f: 34 | for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices): 35 | score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)] 36 | score_list = sorted(score_list, key=lambda x: x[0], reverse=True) 37 | for s, idx in score_list: 38 | f.write(f'{qid}\t{idx}\t{s}\n') 39 | 40 | def pickle_load(path): 41 | with open(path, 'rb') as f: 42 | reps, lookup = pickle.load(f) 43 | return np.array(reps, dtype=np.float32), lookup 44 | 45 | def pickle_save(obj, path): 46 | with open(path, 'wb') as f: 47 | pickle.dump(obj, f) 48 | 49 | def main(): 50 | parser = ArgumentParser() 51 | parser.add_argument('--query_reps', required=True) 52 | parser.add_argument('--passage_reps', required=True) 53 | parser.add_argument('--batch_size', type=int, default=128) 54 | parser.add_argument('--depth', type=int, default=1000) 55 | parser.add_argument('--save_ranking_to', required=True) 56 | parser.add_argument('--save_text', action='store_true') 57 | parser.add_argument('--use_hnsw', action='store_true') 58 | parser.add_argument('--save_index_path') 59 | parser.add_argument('--quiet', action='store_true') 60 | 61 | args = parser.parse_args() 62 | 63 | index_files = glob.glob(args.passage_reps) 64 | logger.info(f'Pattern match found {len(index_files)} files; loading them into index.') 65 | 66 | if args.use_hnsw: 67 | p_reps_0, p_lookup_0 = pickle_load(index_files[0]) 68 | retriever = BaseFaissHNSWRetriever(p_reps_0) 69 | 70 | if os.path.exists(os.path.join(args.save_index_path, "index.faiss")): 71 | logger.info("Loading from saved index") 72 | retriever.load(os.path.join(args.save_index_path, "index.faiss")) 73 | look_up = json.load(open(os.path.join(args.save_index_path, "index.json"))) 74 | 75 | else: 76 | shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:])) 77 | if len(index_files) > 1: 78 | shards = tqdm(shards, desc='Loading shards into index', total=len(index_files)) 79 | look_up = [] 80 | all_p_reps = list() 81 | for p_reps, p_lookup in shards: 82 | all_p_reps.extend(p_reps) 83 | look_up += p_lookup 84 | 85 | all_p_reps = np.array(all_p_reps, dtype=np.float32) 86 | logger.info("Building index") 87 | retriever.build(all_p_reps) 88 | logger.info("Saving index") 89 | retriever.save(os.path.join(args.save_index_path, "index.faiss")) 90 | json.dump(look_up, open(os.path.join(args.save_index_path, "index.json"), "w")) 91 | 92 | else: 93 | p_reps_0, p_lookup_0 = pickle_load(index_files[0]) 94 | retriever = BaseFaissIPRetriever(p_reps_0) 95 | 96 | shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:])) 97 | if len(index_files) > 1: 98 | shards = tqdm(shards, desc='Loading shards into index', total=len(index_files)) 99 | look_up = [] 100 | for p_reps, p_lookup in shards: 101 | retriever.add(p_reps) 102 | look_up += p_lookup 103 | 104 | q_reps, q_lookup = pickle_load(args.query_reps) 105 | q_reps = q_reps 106 | 107 | logger.info('Index Search Start') 108 | all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args) 109 | logger.info('Index Search Finished') 110 | 111 | if args.save_text: 112 | write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to) 113 | else: 114 | pickle_save((all_scores, psg_indices), args.save_ranking_to) 115 | 116 | 117 | if __name__ == '__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/faiss_retriever/reducer.py: 
-------------------------------------------------------------------------------- 1 | import glob 2 | import faiss 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | from typing import Iterable, Tuple 6 | from numpy import ndarray 7 | from .__main__ import pickle_load, write_ranking 8 | 9 | 10 | def combine_faiss_results(results: Iterable[Tuple[ndarray, ndarray]]): 11 | rh = None 12 | for scores, indices in results: 13 | if rh is None: 14 | print(f'Initializing Heap. Assuming {scores.shape[0]} queries.') 15 | rh = faiss.ResultHeap(scores.shape[0], scores.shape[1]) 16 | rh.add_result(-scores, indices) 17 | rh.finalize() 18 | corpus_scores, corpus_indices = -rh.D, rh.I 19 | 20 | return corpus_scores, corpus_indices 21 | 22 | 23 | def main(): 24 | parser = ArgumentParser() 25 | parser.add_argument('--score_dir', required=True) 26 | parser.add_argument('--query', required=True) 27 | parser.add_argument('--save_ranking_to', required=True) 28 | args = parser.parse_args() 29 | 30 | partitions = glob.glob(f'{args.score_dir}/*') 31 | 32 | corpus_scores, corpus_indices = combine_faiss_results(map(pickle_load, tqdm(partitions))) 33 | 34 | _, q_lookup = pickle_load(args.query) 35 | write_ranking(corpus_indices, corpus_scores, q_lookup, args.save_ranking_to) 36 | 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/faiss_retriever/retriever.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import faiss 3 | from tqdm.autonotebook import trange 4 | import time 5 | 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class BaseFaissHNSWRetriever: 11 | def __init__(self, init_reps: np.ndarray, hnsw_store_n: int = 512, hnsw_ef_search: int = 128, hnsw_ef_construction: int = 64, similarity_metric=faiss.METRIC_INNER_PRODUCT): 12 | self.index = faiss.IndexHNSWFlat(init_reps.shape[1], hnsw_store_n, similarity_metric) 13 | self.index.hnsw.efSearch = hnsw_ef_search 14 | self.index.hnsw.efConstruction = hnsw_ef_construction 15 | # self.index = faiss.index_factory(init_reps.shape[1], "HNSW" + str(hnsw_store_n)) 16 | 17 | def load(self, fname: str): 18 | self.index = faiss.read_index(fname) 19 | 20 | def save(self, fname: str): 21 | faiss.write_index(self.index, fname) 22 | 23 | def build(self, p_reps: np.ndarray, buffer_size: int = 1000): 24 | # sq_norms = (p_reps ** 2).sum(1) 25 | # max_sq_norm = float(sq_norms.max()) 26 | # aux_dims = np.sqrt(max_sq_norm - sq_norms) 27 | # p_reps = np.hstack((p_reps, aux_dims.reshape(-1, 1))) 28 | for start in trange(0, p_reps.shape[0], buffer_size): 29 | self.index.add(p_reps[start : start + buffer_size]) 30 | 31 | def search(self, q_reps: np.ndarray, k: int): 32 | # q_reps = np.hstack((q_reps, np.zeros((q_reps.shape[0], 1), dtype=np.float32))) 33 | return self.index.search(q_reps, k) 34 | 35 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int): 36 | num_query = q_reps.shape[0] 37 | # q_reps = np.hstack((q_reps, np.zeros((q_reps.shape[0], 1), dtype=np.float32))) 38 | all_scores = [] 39 | all_indices = [] 40 | for start_idx in trange(0, num_query, batch_size): 41 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k) 42 | all_scores.append(nn_scores) 43 | all_indices.append(nn_indices) 44 | all_scores = np.concatenate(all_scores, axis=0) 45 | all_indices = np.concatenate(all_indices, axis=0) 46 | 47 | return all_scores, all_indices 48 
| 49 | class BaseFaissIPRetriever: 50 | def __init__(self, init_reps: np.ndarray): 51 | index = faiss.IndexFlatIP(init_reps.shape[1]) 52 | self.index = index 53 | 54 | def add(self, p_reps: np.ndarray): 55 | self.index.add(p_reps) 56 | 57 | def search(self, q_reps: np.ndarray, k: int): 58 | return self.index.search(q_reps, k) 59 | 60 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int): 61 | num_query = q_reps.shape[0] 62 | all_scores = [] 63 | all_indices = [] 64 | for start_idx in trange(0, num_query, batch_size): 65 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k) 66 | all_scores.append(nn_scores) 67 | all_indices.append(nn_indices) 68 | all_scores = np.concatenate(all_scores, axis=0) 69 | all_indices = np.concatenate(all_indices, axis=0) 70 | 71 | return all_scores, all_indices 72 | 73 | 74 | class FaissRetriever(BaseFaissIPRetriever): 75 | 76 | def __init__(self, init_reps: np.ndarray, factory_str: str): 77 | index = faiss.index_factory(init_reps.shape[1], factory_str) 78 | self.index = index 79 | self.index.verbose = True 80 | if not self.index.is_trained: 81 | self.index.train(init_reps) 82 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from torch import distributed as dist 5 | 6 | 7 | class SimpleContrastiveLoss: 8 | 9 | def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'): 10 | if target is None: 11 | target_per_qry = y.size(0) // x.size(0) 12 | target = torch.arange( 13 | 0, x.size(0) * target_per_qry, target_per_qry, device=x.device, dtype=torch.long) 14 | logits = torch.matmul(x, y.transpose(0, 1)) 15 | return F.cross_entropy(logits, target, reduction=reduction) 16 | 17 | 18 | class DistributedContrastiveLoss(SimpleContrastiveLoss): 19 | def __init__(self, n_target: int = 0, scale_loss: bool = True): 20 | assert dist.is_initialized(), "Distributed training has not been properly initialized." 
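        # Editor's note (added comment, not in the original source): this loss
        # all-gathers query and passage representations from every rank before
        # the in-batch softmax, so each GPU sees the other GPUs' passages as
        # extra negatives. scale_loss multiplies the final loss by the world
        # size to counter DDP's mean-reduction of gradients across ranks.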
21 | super().__init__() 22 | self.word_size = dist.get_world_size() 23 | self.rank = dist.get_rank() 24 | self.scale_loss = scale_loss 25 | 26 | def __call__(self, x: Tensor, y: Tensor, **kwargs): 27 | dist_x = self.gather_tensor(x) 28 | dist_y = self.gather_tensor(y) 29 | loss = super().__call__(dist_x, dist_y, **kwargs) 30 | if self.scale_loss: 31 | loss = loss * self.word_size 32 | return loss 33 | 34 | def gather_tensor(self, t): 35 | gathered = [torch.empty_like(t) for _ in range(self.word_size)] 36 | dist.all_gather(gathered, t) 37 | gathered[self.rank] = t 38 | return torch.cat(gathered, dim=0) -------------------------------------------------------------------------------- /tevatron/src/tevatron/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import EncoderModel, EncoderPooler, EncoderOutput 2 | from .dense import DenseModel 3 | from .unicoil import UniCoilModel 4 | from .splade import SpladeModel 5 | from .colbert import ColbertModel 6 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/modeling/colbert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import logging 5 | from .encoder import EncoderPooler, EncoderModel 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class ColbertPooler(EncoderPooler): 11 | def __init__(self, input_dim: int = 768, output_dim: int = 32, tied=True): 12 | super(ColbertPooler, self).__init__() 13 | self.linear_q = nn.Linear(input_dim, output_dim) 14 | if tied: 15 | self.linear_p = self.linear_q 16 | else: 17 | self.linear_p = nn.Linear(input_dim, output_dim) 18 | self._config = {'input_dim': input_dim, 'output_dim': output_dim, 'tied': tied} 19 | 20 | def forward(self, q: Tensor = None, p: Tensor = None, **kwargs): 21 | if q is not None: 22 | return self.linear_q(q) 23 | elif p is not None: 24 | return self.linear_p(p) 25 | else: 26 | raise ValueError 27 | 28 | 29 | class ColbertModel(EncoderModel): 30 | def encode_passage(self, psg): 31 | if psg is None: 32 | return None 33 | psg_out = self.lm_p(**psg, return_dict=True) 34 | p_hidden = psg_out.last_hidden_state 35 | p_reps = self.pooler(p=p_hidden) 36 | p_reps *= psg['attention_mask'][:, :, None].float() 37 | return p_reps 38 | 39 | def encode_query(self, qry): 40 | if qry is None: 41 | return None 42 | qry_out = self.lm_q(**qry, return_dict=True) 43 | q_hidden = qry_out.last_hidden_state 44 | q_reps = self.pooler(q=q_hidden) 45 | q_reps *= qry['attention_mask'][:, :, None].float() 46 | return q_reps 47 | 48 | def compute_similarity(self, q_reps, p_reps): 49 | token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps) 50 | scores, _ = token_scores.max(-1) 51 | scores = scores.sum(1) 52 | return scores 53 | 54 | @staticmethod 55 | def load_pooler(model_weights_file, **config): 56 | pooler = ColbertPooler(**config) 57 | pooler.load(model_weights_file) 58 | return pooler 59 | 60 | @staticmethod 61 | def build_pooler(model_args): 62 | pooler = ColbertPooler( 63 | model_args.projection_in_dim, 64 | model_args.projection_out_dim, 65 | tied=not model_args.untie_encoder 66 | ) 67 | pooler.load(model_args.model_name_or_path) 68 | return pooler 69 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/modeling/dense.py: -------------------------------------------------------------------------------- 1 | import torch 
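# Editor's note (added comment, not in the original source): DenseModel below
# is the bi-encoder used by the train/encode drivers earlier in this dump.
# With a DensePooler it projects the [CLS] hidden state through a linear layer
# (optionally tied between query and passage sides); without one it falls back
# to the raw [CLS] vector or an attention-masked mean over tokens, depending
# on use_mean_pooling.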
2 | import torch.nn as nn 3 | from torch import Tensor 4 | import logging 5 | from .encoder import EncoderPooler, EncoderModel 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class DensePooler(EncoderPooler): 11 | def __init__(self, input_dim: int = 768, output_dim: int = 768, tied=True): 12 | super(DensePooler, self).__init__() 13 | self.linear_q = nn.Linear(input_dim, output_dim) 14 | if tied: 15 | self.linear_p = self.linear_q 16 | else: 17 | self.linear_p = nn.Linear(input_dim, output_dim) 18 | self._config = {'input_dim': input_dim, 'output_dim': output_dim, 'tied': tied} 19 | 20 | def forward(self, q: Tensor = None, p: Tensor = None, **kwargs): 21 | if q is not None: 22 | return self.linear_q(q[:, 0]) 23 | elif p is not None: 24 | return self.linear_p(p[:, 0]) 25 | else: 26 | raise ValueError 27 | 28 | 29 | class DenseModel(EncoderModel): 30 | def encode_passage(self, psg): 31 | if psg is None: 32 | return None 33 | psg_out = self.lm_p(**psg, return_dict=True) 34 | p_hidden = psg_out.last_hidden_state 35 | if self.pooler is not None: 36 | p_reps = self.pooler(p=p_hidden) # D * d 37 | else: 38 | if self.use_mean_pooling: 39 | p_reps = p_hidden.masked_fill(~psg['attention_mask'][..., None].bool(), 0.) 40 | p_reps = p_reps.sum(dim=1) / psg['attention_mask'].sum(dim=1)[..., None] 41 | else: 42 | p_reps = p_hidden[:, 0] 43 | return p_reps 44 | 45 | def encode_query(self, qry): 46 | if qry is None: 47 | return None 48 | qry_out = self.lm_q(**qry, return_dict=True) 49 | q_hidden = qry_out.last_hidden_state 50 | if self.pooler is not None: 51 | q_reps = self.pooler(q=q_hidden) 52 | else: 53 | if self.use_mean_pooling: 54 | q_reps = q_hidden.masked_fill(~qry['attention_mask'][..., None].bool(), 0.) 55 | q_reps = q_reps.sum(dim=1) / qry['attention_mask'].sum(dim=1)[..., None] 56 | else: 57 | q_reps = q_hidden[:, 0] 58 | return q_reps 59 | 60 | def compute_similarity(self, q_reps, p_reps): 61 | return torch.matmul(q_reps, p_reps.transpose(0, 1)) 62 | 63 | @staticmethod 64 | def load_pooler(model_weights_file, **config): 65 | pooler = DensePooler(**config) 66 | pooler.load(model_weights_file) 67 | return pooler 68 | 69 | @staticmethod 70 | def build_pooler(model_args): 71 | pooler = DensePooler( 72 | model_args.projection_in_dim, 73 | model_args.projection_out_dim, 74 | tied=not model_args.untie_encoder 75 | ) 76 | pooler.load(model_args.model_name_or_path) 77 | return pooler -------------------------------------------------------------------------------- /tevatron/src/tevatron/modeling/encoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from dataclasses import dataclass 5 | from typing import Dict, Optional 6 | 7 | import torch 8 | from torch import nn, Tensor 9 | import torch.distributed as dist 10 | from transformers import PreTrainedModel, AutoModel 11 | from transformers.file_utils import ModelOutput 12 | 13 | from tevatron.arguments import ModelArguments, \ 14 | TevatronTrainingArguments as TrainingArguments 15 | 16 | import logging 17 | import pickle 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | @dataclass 23 | class EncoderOutput(ModelOutput): 24 | q_reps: Optional[Tensor] = None 25 | p_reps: Optional[Tensor] = None 26 | loss: Optional[Tensor] = None 27 | scores: Optional[Tensor] = None 28 | 29 | 30 | class EncoderPooler(nn.Module): 31 | def __init__(self, **kwargs): 32 | super(EncoderPooler, self).__init__() 33 | self._config = {} 34 | 35 | def forward(self, q_reps, 
p_reps): 36 | raise NotImplementedError('EncoderPooler is an abstract class') 37 | 38 | def load(self, model_dir: str): 39 | pooler_path = os.path.join(model_dir, 'pooler.pt') 40 | if pooler_path is not None: 41 | if os.path.exists(pooler_path): 42 | logger.info(f'Loading Pooler from {pooler_path}') 43 | state_dict = torch.load(pooler_path, map_location='cpu') 44 | self.load_state_dict(state_dict) 45 | return 46 | logger.info("Training Pooler from scratch") 47 | return 48 | 49 | def save_pooler(self, save_path): 50 | torch.save(self.state_dict(), os.path.join(save_path, 'pooler.pt')) 51 | with open(os.path.join(save_path, 'pooler_config.json'), 'w') as f: 52 | json.dump(self._config, f) 53 | 54 | 55 | class EncoderModel(nn.Module): 56 | TRANSFORMER_CLS = AutoModel 57 | 58 | def __init__(self, 59 | lm_q: PreTrainedModel, 60 | lm_p: PreTrainedModel, 61 | pooler: nn.Module = None, 62 | untie_encoder: bool = False, 63 | negatives_x_device: bool = False, 64 | model_args: ModelArguments = None 65 | ): 66 | super().__init__() 67 | self.lm_q = lm_q 68 | self.lm_p = lm_p 69 | self.pooler = pooler 70 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') 71 | self.negatives_x_device = negatives_x_device 72 | self.untie_encoder = untie_encoder 73 | if model_args: 74 | self.use_mean_pooling = model_args.use_mean_pooling 75 | else: 76 | logger.info("Mean Pooling Enabled") 77 | self.use_mean_pooling = True 78 | if self.negatives_x_device: 79 | if not dist.is_initialized(): 80 | raise ValueError('Distributed training has not been initialized for representation all gather.') 81 | self.process_rank = dist.get_rank() 82 | self.world_size = dist.get_world_size() 83 | 84 | def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None): 85 | q_reps = self.encode_query(query) 86 | p_reps = self.encode_passage(passage) 87 | 88 | # for inference 89 | if q_reps is None or p_reps is None: 90 | return EncoderOutput( 91 | q_reps=q_reps, 92 | p_reps=p_reps 93 | ) 94 | 95 | # for training 96 | if self.training: 97 | if self.negatives_x_device: 98 | q_reps = self._dist_gather_tensor(q_reps) 99 | p_reps = self._dist_gather_tensor(p_reps) 100 | 101 | scores = self.compute_similarity(q_reps, p_reps) 102 | scores = scores.view(q_reps.size(0), -1) 103 | 104 | target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long) 105 | target = target * (p_reps.size(0) // q_reps.size(0)) 106 | 107 | loss = self.compute_loss(scores, target) 108 | if self.negatives_x_device: 109 | loss = loss * self.world_size # counter average weight reduction 110 | # for eval 111 | else: 112 | scores = self.compute_similarity(q_reps, p_reps) 113 | loss = None 114 | return EncoderOutput( 115 | loss=loss, 116 | scores=scores, 117 | q_reps=q_reps, 118 | p_reps=p_reps, 119 | ) 120 | 121 | @staticmethod 122 | def build_pooler(model_args): 123 | return None 124 | 125 | @staticmethod 126 | def load_pooler(weights, **config): 127 | return None 128 | 129 | def encode_passage(self, psg): 130 | raise NotImplementedError('EncoderModel is an abstract class') 131 | 132 | def encode_query(self, qry): 133 | raise NotImplementedError('EncoderModel is an abstract class') 134 | 135 | def compute_similarity(self, q_reps, p_reps): 136 | return torch.matmul(q_reps, p_reps.transpose(0, 1)) 137 | 138 | def compute_loss(self, scores, target): 139 | return self.cross_entropy(scores, target) 140 | 141 | def _dist_gather_tensor(self, t: Optional[torch.Tensor]): 142 | if t is None: 143 | return None 144 | t = t.contiguous() 145 | 146 | 
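        # Editor's note (added comment, not in the original source):
        # dist.all_gather fills the list with detached copies, so gradients
        # cannot flow through the gathered tensors; re-inserting this rank's
        # original tensor right after the gather keeps the local activations
        # in the autograd graph.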
all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] 147 | dist.all_gather(all_tensors, t) 148 | 149 | all_tensors[self.process_rank] = t 150 | all_tensors = torch.cat(all_tensors, dim=0) 151 | 152 | return all_tensors 153 | 154 | @classmethod 155 | def build( 156 | cls, 157 | model_args: ModelArguments, 158 | train_args: TrainingArguments, 159 | **hf_kwargs, 160 | ): 161 | # load local 162 | if os.path.isdir(model_args.model_name_or_path): 163 | if model_args.untie_encoder: 164 | _qry_model_path = os.path.join(model_args.model_name_or_path, 'query_model') 165 | _psg_model_path = os.path.join(model_args.model_name_or_path, 'passage_model') 166 | if not os.path.exists(_qry_model_path): 167 | _qry_model_path = model_args.model_name_or_path 168 | _psg_model_path = model_args.model_name_or_path 169 | logger.info(f'loading query model weight from {_qry_model_path}') 170 | lm_q = cls.TRANSFORMER_CLS.from_pretrained( 171 | _qry_model_path, 172 | **hf_kwargs 173 | ) 174 | logger.info(f'loading passage model weight from {_psg_model_path}') 175 | lm_p = cls.TRANSFORMER_CLS.from_pretrained( 176 | _psg_model_path, 177 | **hf_kwargs 178 | ) 179 | else: 180 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_args.model_name_or_path, **hf_kwargs) 181 | lm_p = lm_q 182 | # load pre-trained 183 | else: 184 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_args.model_name_or_path, **hf_kwargs) 185 | lm_p = copy.deepcopy(lm_q) if model_args.untie_encoder else lm_q 186 | 187 | if model_args.add_pooler: 188 | pooler = cls.build_pooler(model_args) 189 | else: 190 | pooler = None 191 | 192 | model = cls( 193 | lm_q=lm_q, 194 | lm_p=lm_p, 195 | pooler=pooler, 196 | negatives_x_device=train_args.negatives_x_device, 197 | untie_encoder=model_args.untie_encoder, 198 | model_args=model_args 199 | ) 200 | return model 201 | 202 | @classmethod 203 | def load( 204 | cls, 205 | model_name_or_path, 206 | **hf_kwargs, 207 | ): 208 | # load local 209 | untie_encoder = True 210 | if os.path.isdir(model_name_or_path): 211 | _qry_model_path = os.path.join(model_name_or_path, 'query_model') 212 | _psg_model_path = os.path.join(model_name_or_path, 'passage_model') 213 | if os.path.exists(_qry_model_path): 214 | logger.info(f'found separate weight for query/passage encoders') 215 | logger.info(f'loading query model weight from {_qry_model_path}') 216 | lm_q = cls.TRANSFORMER_CLS.from_pretrained( 217 | _qry_model_path, 218 | **hf_kwargs 219 | ) 220 | logger.info(f'loading passage model weight from {_psg_model_path}') 221 | lm_p = cls.TRANSFORMER_CLS.from_pretrained( 222 | _psg_model_path, 223 | **hf_kwargs 224 | ) 225 | untie_encoder = False 226 | else: 227 | logger.info(f'try loading tied weight') 228 | logger.info(f'loading model weight from {model_name_or_path}') 229 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs) 230 | lm_p = lm_q 231 | else: 232 | logger.info(f'try loading tied weight') 233 | logger.info(f'loading model weight from {model_name_or_path}') 234 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs) 235 | lm_p = lm_q 236 | 237 | pooler_weights = os.path.join(model_name_or_path, 'pooler.pt') 238 | pooler_config = os.path.join(model_name_or_path, 'pooler_config.json') 239 | if os.path.exists(pooler_weights) and os.path.exists(pooler_config): 240 | logger.info(f'found pooler weight and configuration') 241 | with open(pooler_config) as f: 242 | pooler_config_dict = json.load(f) 243 | pooler = cls.load_pooler(model_name_or_path, **pooler_config_dict) 
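        # Editor's note (added comment, not in the original source): pooler
        # weights are optional. Checkpoints trained without add_pooler ship no
        # pooler.pt, and subclasses such as DenseModel then fall back to the
        # raw [CLS] or mean-pooled representation.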
244 | else: 245 | pooler = None 246 | 247 | model_args_path = os.path.join(model_name_or_path, 'model_args.bin') 248 | model_args = None 249 | if os.path.exists(model_args_path): 250 | model_args = pickle.load(open(model_args_path, "rb")) 251 | 252 | model = cls( 253 | lm_q=lm_q, 254 | lm_p=lm_p, 255 | pooler=pooler, 256 | untie_encoder=untie_encoder, 257 | model_args = model_args 258 | ) 259 | return model 260 | 261 | def save(self, output_dir: str): 262 | if self.untie_encoder: 263 | os.makedirs(os.path.join(output_dir, 'query_model')) 264 | os.makedirs(os.path.join(output_dir, 'passage_model')) 265 | self.lm_q.save_pretrained(os.path.join(output_dir, 'query_model')) 266 | self.lm_p.save_pretrained(os.path.join(output_dir, 'passage_model')) 267 | else: 268 | self.lm_q.save_pretrained(output_dir) 269 | if self.pooler: 270 | self.pooler.save_pooler(output_dir) -------------------------------------------------------------------------------- /tevatron/src/tevatron/modeling/splade.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from transformers import AutoModelForMaskedLM 4 | from .encoder import EncoderModel 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class SpladeModel(EncoderModel): 10 | TRANSFORMER_CLS = AutoModelForMaskedLM 11 | 12 | def encode_passage(self, psg): 13 | if psg is None: 14 | return None 15 | psg_out = self.lm_p(**psg, return_dict=True).logits 16 | aggregated_psg_out, _ = torch.max(torch.log(1 + torch.relu(psg_out)) * psg['attention_mask'].unsqueeze(-1), dim=1) 17 | return aggregated_psg_out 18 | 19 | def encode_query(self, qry): 20 | if qry is None: 21 | return None 22 | qry_out = self.lm_q(**qry, return_dict=True).logits 23 | aggregated_psg_out, _ = torch.max(torch.log(1 + torch.relu(qry_out)) * qry['attention_mask'].unsqueeze(-1), dim=1) 24 | return aggregated_psg_out 25 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/modeling/unicoil.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | import logging 4 | 5 | from .encoder import EncoderPooler, EncoderModel 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class UniCoilPooler(EncoderPooler): 11 | def __init__(self, input_dim: int = 768, tied=True): 12 | super(UniCoilPooler, self).__init__() 13 | self.linear_q = nn.Linear(input_dim, 1) 14 | if tied: 15 | self.linear_p = self.linear_q 16 | else: 17 | self.linear_p = nn.Linear(input_dim, 1) 18 | self._config = {'input_dim': input_dim, 'tied': tied} 19 | 20 | def forward(self, q: Tensor = None, p: Tensor = None): 21 | if q is not None: 22 | return self.linear_q(q) 23 | elif p is not None: 24 | return self.linear_p(p) 25 | else: 26 | raise ValueError 27 | 28 | 29 | class UniCoilModel(EncoderModel): 30 | def encode_passage(self, psg): 31 | if psg is None: 32 | return None 33 | psg_out = self.lm_p(**psg, return_dict=True) 34 | p_hidden = psg_out.last_hidden_state 35 | p_reps = self.pooler(p=p_hidden) 36 | return self._weights_to_vec(psg['input_ids'], p_reps) 37 | 38 | def encode_query(self, qry): 39 | if qry is None: 40 | return None 41 | qry_out = self.lm_q(**qry, return_dict=True) 42 | q_hidden = qry_out.last_hidden_state 43 | q_reps = self.pooler(q=q_hidden) 44 | return self._weights_to_vec(qry['input_ids'], q_reps) 45 | 46 | def compute_similarity(self, q_reps, p_reps): 47 | return torch.matmul(q_reps, p_reps.transpose(0, 1)) 48 | 49 | 
def _weights_to_vec(self, input_ids, tok_weights): 50 | input_shape = input_ids.size() 51 | tok_weights = torch.relu(tok_weights) 52 | tok_emb = torch.zeros(input_shape[0], input_shape[1], self.lm_p.config.vocab_size, dtype=tok_weights.dtype, 53 | device=input_ids.device) 54 | tok_emb = torch.scatter(tok_emb, dim=-1, index=input_ids.unsqueeze(-1), src=tok_weights) 55 | disabled_token_ids = [0, 101, 102, 103] # hard code for bert for now, can pass in a tokenizer in the future 56 | tok_emb = torch.max(tok_emb, dim=1).values 57 | tok_emb[:, disabled_token_ids] *= 0 58 | return tok_emb 59 | 60 | @staticmethod 61 | def build_pooler(model_args): 62 | pooler = UniCoilPooler( 63 | model_args.projection_in_dim, 64 | tied=not model_args.untie_encoder 65 | ) 66 | pooler.load(model_args.model_name_or_path) 67 | return pooler 68 | 69 | @staticmethod 70 | def load_pooler(model_weights_file, **config): 71 | pooler = UniCoilPooler(**config) 72 | pooler.load(model_weights_file) 73 | return pooler 74 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/preprocessor/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessor_tsv import SimpleTrainPreProcessor as MarcoPassageTrainPreProcessor, \ 2 | SimpleCollectionPreProcessor as MarcoPassageCollectionPreProcessor -------------------------------------------------------------------------------- /tevatron/src/tevatron/preprocessor/normalize_text.py: -------------------------------------------------------------------------------- 1 | #: Control characters. 2 | CONTROLS = { 3 | '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011', 4 | '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b', 5 | } 6 | # There are further control characters, but they are instead replaced with a space by unicode normalization 7 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f' 8 | 9 | 10 | #: Hyphen and dash characters. 11 | HYPHENS = { 12 | '-', # \u002d Hyphen-minus 13 | '‐', # \u2010 Hyphen 14 | '‑', # \u2011 Non-breaking hyphen 15 | '⁃', # \u2043 Hyphen bullet 16 | '‒', # \u2012 figure dash 17 | '–', # \u2013 en dash 18 | '—', # \u2014 em dash 19 | '―', # \u2015 horizontal bar 20 | } 21 | 22 | #: Minus characters. 23 | MINUSES = { 24 | '-', # \u002d Hyphen-minus 25 | '−', # \u2212 Minus 26 | '-', # \uff0d Full-width Hyphen-minus 27 | '⁻', # \u207b Superscript minus 28 | } 29 | 30 | #: Plus characters. 31 | PLUSES = { 32 | '+', # \u002b Plus 33 | '+', # \uff0b Full-width Plus 34 | '⁺', # \u207a Superscript plus 35 | } 36 | 37 | #: Slash characters. 38 | SLASHES = { 39 | '/', # \u002f Solidus 40 | '⁄', # \u2044 Fraction slash 41 | '∕', # \u2215 Division slash 42 | } 43 | 44 | #: Tilde characters. 45 | TILDES = { 46 | '~', # \u007e Tilde 47 | '˜', # \u02dc Small tilde 48 | '⁓', # \u2053 Swung dash 49 | '∼', # \u223c Tilde operator #in mbert vocab 50 | '∽', # \u223d Reversed tilde 51 | '∿', # \u223f Sine wave 52 | '〜', # \u301c Wave dash #in mbert vocab 53 | '~', # \uff5e Full-width tilde #in mbert vocab 54 | } 55 | 56 | #: Apostrophe characters. 57 | APOSTROPHES = { 58 | "'", # \u0027 59 | '’', # \u2019 60 | '՚', # \u055a 61 | 'Ꞌ', # \ua78b 62 | 'ꞌ', # \ua78c 63 | ''', # \uff07 64 | } 65 | 66 | #: Single quote characters. 
67 | SINGLE_QUOTES = { 68 | "'", # \u0027 69 | '‘', # \u2018 70 | '’', # \u2019 71 | '‚', # \u201a 72 | '‛', # \u201b 73 | 74 | } 75 | 76 | #: Double quote characters. 77 | DOUBLE_QUOTES = { 78 | '"', # \u0022 79 | '“', # \u201c 80 | '”', # \u201d 81 | '„', # \u201e 82 | '‟', # \u201f 83 | } 84 | 85 | #: Accent characters. 86 | ACCENTS = { 87 | '`', # \u0060 88 | '´', # \u00b4 89 | } 90 | 91 | #: Prime characters. 92 | PRIMES = { 93 | '′', # \u2032 94 | '″', # \u2033 95 | '‴', # \u2034 96 | '‵', # \u2035 97 | '‶', # \u2036 98 | '‷', # \u2037 99 | '⁗', # \u2057 100 | } 101 | 102 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes. 103 | QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES 104 | 105 | def normalize(text): 106 | for control in CONTROLS: 107 | text = text.replace(control, '') 108 | text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ') 109 | 110 | for hyphen in HYPHENS | MINUSES: 111 | text = text.replace(hyphen, '-') 112 | text = text.replace('\u00ad', '') 113 | 114 | for double_quote in DOUBLE_QUOTES: 115 | text = text.replace(double_quote, '"') # \u0022 116 | for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS): 117 | text = text.replace(single_quote, "'") # \u0027 118 | text = text.replace('′', "'") # \u2032 prime 119 | text = text.replace('‵', "'") # \u2035 reversed prime 120 | text = text.replace('″', "''") # \u2033 double prime 121 | text = text.replace('‶', "''") # \u2036 reversed double prime 122 | text = text.replace('‴', "'''") # \u2034 triple prime 123 | text = text.replace('‷', "'''") # \u2037 reversed triple prime 124 | text = text.replace('⁗', "''''") # \u2057 quadruple prime 125 | 126 | text = text.replace('…', '...').replace(' . . . ', ' ... 
') # \u2026 127 | 128 | for slash in SLASHES: 129 | text = text.replace(slash, '/') 130 | 131 | #for tilde in TILDES: 132 | # text = text.replace(tilde, '~') 133 | 134 | return text -------------------------------------------------------------------------------- /tevatron/src/tevatron/preprocessor/preprocessor_tsv.py: -------------------------------------------------------------------------------- 1 | import json 2 | import csv 3 | import datasets 4 | from transformers import PreTrainedTokenizer 5 | from dataclasses import dataclass 6 | from .normalize_text import normalize 7 | 8 | @dataclass 9 | class SimpleTrainPreProcessor: 10 | query_file: str 11 | collection_file: str 12 | tokenizer: PreTrainedTokenizer 13 | 14 | max_length: int = 128 15 | columns = ['text_id', 'title', 'text'] 16 | title_field = 'title' 17 | text_field = 'text' 18 | 19 | def __post_init__(self): 20 | self.queries = self.read_queries(self.query_file) 21 | self.collection = datasets.load_dataset( 22 | 'csv', 23 | data_files=self.collection_file, 24 | column_names=self.columns, 25 | delimiter='\t', 26 | )['train'] 27 | 28 | @staticmethod 29 | def read_queries(queries): 30 | qmap = {} 31 | with open(queries) as f: 32 | for l in f: 33 | qid, qry = l.strip().split('\t') 34 | qmap[qid] = qry 35 | return qmap 36 | 37 | @staticmethod 38 | def read_qrel(relevance_file): 39 | qrel = {} 40 | with open(relevance_file, encoding='utf8') as f: 41 | tsvreader = csv.reader(f, delimiter="\t") 42 | for [topicid, _, docid, rel] in tsvreader: 43 | assert rel == "1" 44 | if topicid in qrel: 45 | qrel[topicid].append(docid) 46 | else: 47 | qrel[topicid] = [docid] 48 | return qrel 49 | 50 | def get_query(self, q): 51 | query_encoded = self.tokenizer.encode( 52 | self.queries[q], 53 | add_special_tokens=False, 54 | max_length=self.max_length, 55 | truncation=True 56 | ) 57 | return query_encoded 58 | 59 | def get_passage(self, p): 60 | entry = self.collection[int(p)] 61 | title = entry[self.title_field] 62 | title = "" if title is None else title 63 | body = entry[self.text_field] 64 | content = title + self.tokenizer.sep_token + body 65 | 66 | passage_encoded = self.tokenizer.encode( 67 | content, 68 | add_special_tokens=False, 69 | max_length=self.max_length, 70 | truncation=True 71 | ) 72 | 73 | return passage_encoded 74 | 75 | def process_one(self, train): 76 | q, pp, nn = train 77 | train_example = { 78 | 'query': self.get_query(q), 79 | 'positives': [self.get_passage(p) for p in pp], 80 | 'negatives': [self.get_passage(n) for n in nn], 81 | } 82 | 83 | return json.dumps(train_example) 84 | 85 | 86 | @dataclass 87 | class SimpleCollectionPreProcessor: 88 | tokenizer: PreTrainedTokenizer 89 | separator: str = '\t' 90 | max_length: int = 128 91 | is_query: bool = False 92 | text_first: bool = False 93 | lower_case: bool = False 94 | normalize_text: bool = False 95 | 96 | def process_line(self, line: str): 97 | xx = line.strip().split(self.separator) 98 | text_id, text = xx[0], xx[1:] 99 | if len(text) == 1 and (not self.is_query): 100 | text.append("") 101 | if self.text_first: 102 | text = text[::-1] 103 | for i in range(0, len(text)): 104 | if self.normalize_text: 105 | text[i] = normalize(text[i]) 106 | if self.lower_case: 107 | text[i] = text[i].lower() 108 | text_encoded = self.tokenizer.encode( 109 | self.tokenizer.sep_token.join(text), 110 | add_special_tokens=False, 111 | max_length=self.max_length, 112 | truncation=True 113 | ) 114 | encoded = { 115 | 'text_id': text_id, 116 | 'text': text_encoded 117 | } 118 | return 
json.dumps(encoded) 119 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/tevax/__init__.py: -------------------------------------------------------------------------------- 1 | from .training import TiedParams, DualParams, RetrieverTrainState, retriever_train_step 2 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/tevax/loss.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import lax 3 | import optax 4 | import chex 5 | 6 | 7 | def _onehot(labels: chex.Array, num_classes: int) -> chex.Array: 8 | x = labels[..., None] == jnp.arange(num_classes).reshape((1,) * labels.ndim + (-1,)) 9 | x = lax.select(x, jnp.ones(x.shape), jnp.zeros(x.shape)) 10 | return x.astype(jnp.float32) 11 | 12 | 13 | def p_contrastive_loss(ss: chex.Array, tt: chex.Array, axis: str = 'device') -> chex.Array: 14 | per_shard_targets = tt.shape[0] 15 | per_sample_targets = int(tt.shape[0] / ss.shape[0]) 16 | labels = jnp.arange(0, per_shard_targets, per_sample_targets) + per_shard_targets * lax.axis_index(axis) # index of each local query's positive within the globally gathered targets 17 | 18 | tt = lax.all_gather(tt, axis).reshape((-1, ss.shape[-1])) 19 | scores = jnp.dot(ss, jnp.transpose(tt)) 20 | 21 | return optax.softmax_cross_entropy(scores, _onehot(labels, scores.shape[-1])) 22 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/tevax/training.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Tuple, Any, Union 3 | 4 | import jax 5 | from jax import numpy as jnp 6 | 7 | from flax.training.train_state import TrainState 8 | from flax.core import FrozenDict 9 | from flax.struct import PyTreeNode 10 | 11 | from .loss import p_contrastive_loss 12 | 13 | 14 | class TiedParams(PyTreeNode): 15 | params: FrozenDict[str, Any] 16 | 17 | @property 18 | def q_params(self): 19 | return self.params 20 | 21 | @property 22 | def p_params(self): 23 | return self.params 24 | 25 | @classmethod 26 | def create(cls, params): 27 | return cls(params=params) 28 | 29 | 30 | class DualParams(PyTreeNode): 31 | params: Tuple[FrozenDict[str, Any], FrozenDict[str, Any]] 32 | 33 | @property 34 | def q_params(self): 35 | return self.params[0] 36 | 37 | @property 38 | def p_params(self): 39 | return self.params[1] 40 | 41 | @classmethod 42 | def create(cls, *ps): 43 | if len(ps) == 1: 44 | return cls(params=ps*2) 45 | else: 46 | q_params, p_params = ps # unpack in the same order as the q_params/p_params properties above 47 | return cls(params=[q_params, p_params]) 48 | 49 | 50 | class RetrieverTrainState(TrainState): 51 | params: Union[TiedParams, DualParams] 52 | 53 | 54 | def retriever_train_step(state, queries, passages, dropout_rng, axis='device'): 55 | q_dropout_rng, p_dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, 3) 56 | 57 | def compute_loss(params): 58 | q_reps = state.apply_fn(**queries, params=params.q_params, dropout_rng=q_dropout_rng, train=True)[0][:, 0, :] 59 | p_reps = state.apply_fn(**passages, params=params.p_params, dropout_rng=p_dropout_rng, train=True)[0][:, 0, :] 60 | return jnp.mean(p_contrastive_loss(q_reps, p_reps, axis=axis)) 61 | 62 | loss, grad = jax.value_and_grad(compute_loss)(state.params) 63 | loss, grad = jax.lax.pmean([loss, grad], axis) 64 | 65 | new_state = state.apply_gradients(grads=grad) 66 | 67 | return loss, new_state, new_dropout_rng 68 | 69 | 
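# How the gradient-cache step below works (a sketch of the technique; the details
# live in the grad_cache package): queries and passages are split into sub-batches
# and encoded in a first pass; the contrastive loss and its gradients are computed
# with respect to the concatenated representations only; each sub-batch is then
# re-encoded chunk by chunk to back-propagate those cached representation gradients
# into the encoder parameters, so peak memory is bounded by the chunk size rather
# than the full contrastive batch.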
70 | def grad_cache_train_step(state, queries, passages, dropout_rng, axis='device', q_n_subbatch=1, p_n_subbatch=1): 71 | try: 72 | from grad_cache import cachex 73 | except ImportError: 74 | raise ModuleNotFoundError('GradCache package needs to be installed for running grad_cache_train_step') 75 | 76 | def encode_query(params, **kwargs): 77 | return state.apply_fn(**kwargs, params=params.q_params, train=True)[0][:, 0, :] 78 | 79 | def encode_passage(params, **kwargs): 80 | return state.apply_fn(**kwargs, params=params.p_params, train=True)[0][:, 0, :] 81 | 82 | queries, passages = cachex.tree_chunk(queries, q_n_subbatch), cachex.tree_chunk(passages, p_n_subbatch) 83 | q_rngs, p_rngs, new_rng = jax.random.split(dropout_rng, 3) 84 | q_rngs = jax.random.split(q_rngs, q_n_subbatch) 85 | p_rngs = jax.random.split(p_rngs, p_n_subbatch) 86 | 87 | q_reps = cachex.chunk_encode(partial(encode_query, state.params))(**queries, dropout_rng=q_rngs) 88 | p_reps = cachex.chunk_encode(partial(encode_passage, state.params))(**passages, dropout_rng=p_rngs) 89 | 90 | @cachex.unchunk_args(axis=0, argnums=(0, 1)) 91 | def compute_loss(xx, yy): 92 | return jnp.mean(p_contrastive_loss(xx, yy, axis=axis)) 93 | 94 | loss, (q_grads, p_grads) = jax.value_and_grad(compute_loss, argnums=(0, 1))(q_reps, p_reps) 95 | 96 | grads = jax.tree_map(lambda v: jnp.zeros_like(v), state.params) # zero-initialized accumulator with the same pytree structure as the params 97 | grads = cachex.cache_grad(encode_query)(state.params, grads, q_grads, **queries, dropout_rng=q_rngs) 98 | grads = cachex.cache_grad(encode_passage)(state.params, grads, p_grads, **passages, dropout_rng=p_rngs) 99 | 100 | loss, grads = jax.lax.pmean([loss, grads], axis) 101 | new_state = state.apply_gradients(grads=grads) 102 | return loss, new_state, new_rng 103 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import repeat 3 | from typing import Dict, List, Tuple, Optional, Any, Union 4 | 5 | from transformers.trainer import Trainer 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torch.distributed as dist 10 | 11 | from .loss import SimpleContrastiveLoss, DistributedContrastiveLoss 12 | 13 | import logging 14 | logger = logging.getLogger(__name__) 15 | 16 | try: 17 | from grad_cache import GradCache 18 | _grad_cache_available = True 19 | except ModuleNotFoundError: 20 | _grad_cache_available = False 21 | 22 | 23 | class TevatronTrainer(Trainer): 24 | def __init__(self, *args, **kwargs): 25 | super(TevatronTrainer, self).__init__(*args, **kwargs) 26 | self._dist_loss_scale_factor = dist.get_world_size() if self.args.negatives_x_device else 1 27 | 28 | def _save(self, output_dir: Optional[str] = None): 29 | output_dir = output_dir if output_dir is not None else self.args.output_dir 30 | os.makedirs(output_dir, exist_ok=True) 31 | logger.info("Saving model checkpoint to %s", output_dir) 32 | self.model.save(output_dir) 33 | 34 | def _prepare_inputs( 35 | self, 36 | inputs: Tuple[Dict[str, Union[torch.Tensor, Any]], ...] 
37 | ) -> List[Dict[str, Union[torch.Tensor, Any]]]: 38 | prepared = [] 39 | for x in inputs: 40 | if isinstance(x, torch.Tensor): 41 | prepared.append(x.to(self.args.device)) 42 | else: 43 | prepared.append(super()._prepare_inputs(x)) 44 | return prepared 45 | 46 | def get_train_dataloader(self) -> DataLoader: 47 | if self.train_dataset is None: 48 | raise ValueError("Trainer: training requires a train_dataset.") 49 | train_sampler = self._get_train_sampler() 50 | 51 | return DataLoader( 52 | self.train_dataset, 53 | batch_size=self.args.train_batch_size, 54 | sampler=train_sampler, 55 | collate_fn=self.data_collator, 56 | drop_last=True, 57 | num_workers=self.args.dataloader_num_workers, 58 | ) 59 | 60 | def compute_loss(self, model, inputs): 61 | query, passage = inputs 62 | return model(query=query, passage=passage).loss 63 | 64 | def training_step(self, *args): 65 | return super(TevatronTrainer, self).training_step(*args) / self._dist_loss_scale_factor 66 | 67 | 68 | def split_dense_inputs(model_input: dict, chunk_size: int): 69 | assert len(model_input) == 1 70 | arg_key = list(model_input.keys())[0] 71 | arg_val = model_input[arg_key] 72 | 73 | keys = list(arg_val.keys()) 74 | chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys] 75 | chunked_arg_val = [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))] 76 | 77 | return [{arg_key: c} for c in chunked_arg_val] 78 | 79 | 80 | def get_dense_rep(x): 81 | if x.q_reps is None: 82 | return x.p_reps 83 | else: 84 | return x.q_reps 85 | 86 | 87 | class GCTrainer(TevatronTrainer): 88 | def __init__(self, *args, **kwargs): 89 | logger.info('Initializing Gradient Cache Trainer') 90 | if not _grad_cache_available: 91 | raise ValueError( 92 | 'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.') 93 | super(GCTrainer, self).__init__(*args, **kwargs) 94 | 95 | loss_fn_cls = DistributedContrastiveLoss if self.args.negatives_x_device else SimpleContrastiveLoss 96 | loss_fn = loss_fn_cls() 97 | 98 | self.gc = GradCache( 99 | models=[self.model, self.model], 100 | chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size], 101 | loss_fn=loss_fn, 102 | split_input_fn=split_dense_inputs, 103 | get_rep_fn=get_dense_rep, 104 | fp16=self.args.fp16, 105 | scaler=self.scaler if self.args.fp16 else None 106 | ) 107 | 108 | def training_step(self, model, inputs) -> torch.Tensor: 109 | model.train() 110 | queries, passages = self._prepare_inputs(inputs) 111 | queries, passages = {'query': queries}, {'passage': passages} 112 | 113 | _distributed = self.args.local_rank > -1 114 | self.gc.models = [model, model] 115 | loss = self.gc(queries, passages, no_sync_except_last=_distributed) 116 | 117 | return loss / self._dist_loss_scale_factor 118 | -------------------------------------------------------------------------------- /tevatron/src/tevatron/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import evaluation -------------------------------------------------------------------------------- /tevatron/src/tevatron/utils/convert_from_dpr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | 5 | from transformers import AutoConfig, AutoTokenizer 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--dpr_model', required=True) 10 | parser.add_argument('--save_to', required=True) 11 | args = parser.parse_args() 12 | 13 | dpr_model_ckpt = torch.load(args.dpr_model, map_location='cpu') 14 | config_name = dpr_model_ckpt['encoder_params']['pretrained_model_cfg'] 15 | dpr_model_dict = dpr_model_ckpt['model_dict'] 16 | 17 | AutoConfig.from_pretrained(config_name).save_pretrained(args.save_to) 18 | AutoTokenizer.from_pretrained(config_name).save_pretrained(args.save_to) 19 | 20 | question_keys = [k for k in dpr_model_dict.keys() if k.startswith('question_model')] 21 | ctx_keys = [k for k in dpr_model_dict.keys() if k.startswith('ctx_model')] 22 | 23 | question_dict = dict([(k[len('question_model')+1:], dpr_model_dict[k]) for k in question_keys]) 24 | ctx_dict = dict([(k[len('ctx_model')+1:], dpr_model_dict[k]) for k in ctx_keys]) 25 | 26 | os.makedirs(os.path.join(args.save_to, 'query_model'), exist_ok=True) 27 | os.makedirs(os.path.join(args.save_to, 'passage_model'), exist_ok=True) 28 | torch.save(question_dict, os.path.join(args.save_to, 'query_model', 'pytorch_model.bin')) 29 | torch.save(ctx_dict, os.path.join(args.save_to, 'passage_model', 'pytorch_model.bin')) 30 | 31 | 32 | if __name__ == '__main__': 33 | main() -------------------------------------------------------------------------------- /tevatron/src/tevatron/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import collections 9 | import logging 10 | import regex 11 | import string 12 | import unicodedata 13 | from functools import partial 14 | from multiprocessing import Pool as ProcessPool 15 | from typing import Tuple, List, Dict 16 | import numpy as np 17 | 18 | """ 19 | Evaluation code from DPR: https://github.com/facebookresearch/DPR 20 | """ 21 | 22 | class SimpleTokenizer(object): 23 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 24 | NON_WS = r'[^\p{Z}\p{C}]' 25 | 26 | def __init__(self): 27 | """ 28 | Lightweight regex tokenizer that splits text into alphanumeric and 29 | non-whitespace tokens; no other annotation is performed. 30 | """ 31 | self._regexp = regex.compile( 32 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 33 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 34 | ) 35 | 36 | def tokenize(self, text, uncased=False): 37 | matches = [m for m in self._regexp.finditer(text)] 38 | if uncased: 39 | tokens = [m.group().lower() for m in matches] 40 | else: 41 | tokens = [m.group() for m in matches] 42 | return tokens 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits']) 47 | 
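# A minimal, hypothetical usage sketch for calculate_matches below; the record
# layout is inferred from check_answer further down in this file:
#
#   data = [{'answers': ['Paris'],
#            'ctxs': [{'text': 'Paris is the capital of France.'},
#                     {'text': 'Lyon is a city in France.'}]}]
#   stats = calculate_matches(data, workers_num=4)
#   stats.top_k_hits[k - 1]  # questions with a matching answer in the top-k docs
#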
48 | def calculate_matches(data: List, workers_num: int): 49 | """ 50 | Evaluates answer presence in the set of documents. This function is supposed to be used with a large collection of 51 | documents and results. It internally forks multiple sub-processes for evaluation and then merges the results. 52 | :param data: list of question records, one per question; each record is a dict with an 'answers' entry 53 | (a list of gold answer strings) and a 'ctxs' entry (the retrieved documents, in rank order, 54 | each a dict holding the document 'text') 55 | :param workers_num: number of parallel worker processes used for answer matching 56 | Answer matching is string-based; refer to the has_answer code for the exact procedure. 57 | :return: matching information tuple. 58 | top_k_hits - a list where the index is the number of top documents retrieved and the value is the total number of 59 | valid matches across the entire dataset. 60 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 61 | """ 62 | 63 | logger.info('Matching answers in top docs...') 64 | 65 | tokenizer = SimpleTokenizer() 66 | get_score_partial = partial(check_answer, tokenizer=tokenizer) 67 | 68 | processes = ProcessPool(processes=workers_num) 69 | scores = processes.map(get_score_partial, data) 70 | 71 | logger.info('Per question validation results len=%d', len(scores)) 72 | 73 | n_docs = len(data[0]['ctxs']) 74 | top_k_hits = [0] * n_docs 75 | for question_hits in scores: 76 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 77 | if best_hit is not None: 78 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] # a hit at rank r counts toward every cutoff k >= r 79 | 80 | return QAMatchStats(top_k_hits, scores) 81 | 82 | def check_answer(example, tokenizer) -> List[bool]: 83 | """Search through all the top docs to see if they have any of the answers.""" 84 | answers = example['answers'] 85 | ctxs = example['ctxs'] 86 | 87 | hits = [] 88 | 89 | for i, doc in enumerate(ctxs): 90 | text = doc['text'] 91 | 92 | if text is None: # cannot find the document for some reason 93 | logger.warning("no doc in db") 94 | hits.append(False) 95 | continue 96 | 97 | hits.append(has_answer(answers, text, tokenizer)) 98 | 99 | return hits 100 | 101 | def has_answer(answers, text, tokenizer) -> bool: 102 | """Check if a document contains an answer string.""" 103 | text = _normalize(text) 104 | text = tokenizer.tokenize(text, uncased=True) 105 | 106 | for answer in answers: 107 | answer = _normalize(answer) 108 | answer = tokenizer.tokenize(answer, uncased=True) 109 | for i in range(0, len(text) - len(answer) + 1): 110 | if answer == text[i: i + len(answer)]: 111 | return True 112 | return False 113 | 114 | ################################################# 115 | ######## READER EVALUATION ######## 116 | ################################################# 117 | 118 | def _normalize(text): 119 | return unicodedata.normalize('NFD', text) 120 | 121 | # Normalization from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ 122 | def normalize_answer(s): 123 | def remove_articles(text): 124 | return regex.sub(r'\b(a|an|the)\b', ' ', text) 125 | 126 | def white_space_fix(text): 127 | return ' '.join(text.split()) 128 | 129 | def remove_punc(text): 130 | exclude = set(string.punctuation) 131 | return ''.join(ch for ch in text if ch not in exclude) 132 | 133 | def lower(text): 134 | return text.lower() 135 | 136 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 137 | 138 | def exact_match_score(prediction, ground_truth): 139 | return normalize_answer(prediction) == normalize_answer(ground_truth) 140 | 141 | def ems(prediction, ground_truths): 142 | return 
max([exact_match_score(prediction, gt) for gt in ground_truths]) 143 | 144 | #################################################### 145 | ######## RETRIEVER EVALUATION ######## 146 | #################################################### 147 | 148 | def eval_batch(scores, inversions, avg_topk, idx_topk): 149 | for k, s in enumerate(scores): 150 | s = s.cpu().numpy() 151 | sorted_idx = np.argsort(-s) 152 | score(sorted_idx, inversions, avg_topk, idx_topk) 153 | 154 | def count_inversions(arr): 155 | inv_count = 0 156 | lenarr = len(arr) 157 | for i in range(lenarr): 158 | for j in range(i + 1, lenarr): 159 | if (arr[i] > arr[j]): 160 | inv_count += 1 161 | return inv_count 162 | 163 | def score(x, inversions, avg_topk, idx_topk): 164 | x = np.array(x) 165 | inversions.append(count_inversions(x)) 166 | for k in avg_topk: 167 | # ratio of passages in the predicted top-k that are 168 | # also in the topk given by gold score 169 | avg_pred_topk = (x[:k]