├── .gitignore
├── FIRST.png
├── README.md
├── bash
│   ├── run_1st_retrieval.sh
│   ├── run_2nd_retrieval.sh
│   ├── run_convert_results.sh
│   ├── run_distill.sh
│   ├── run_eval.sh
│   ├── run_prepare_distill.sh
│   ├── run_rerank_CE.sh
│   ├── run_rerank_llm.sh
│   └── run_train.sh
├── scripts
│   ├── convert_results.py
│   ├── distill.py
│   ├── eval.py
│   ├── prepare_distill.py
│   ├── rerank_CE.py
│   ├── rerank_llm.py
│   ├── train_ranking.py
│   └── utils
│       ├── __init__.py
│       ├── dataset.py
│       ├── llm_util.py
│       ├── loss.py
│       ├── rank_listwise_os_llm.py
│       ├── rankllm.py
│       ├── reranker.py
│       ├── result.py
│       └── train_utils.py
├── tevatron
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── mkdocs.yml
│   ├── setup.py
│   └── src
│       ├── .DS_Store
│       └── tevatron
│           ├── __init__.py
│           ├── arguments.py
│           ├── data.py
│           ├── datasets
│           │   ├── __init__.py
│           │   ├── dataset.py
│           │   └── preprocessor.py
│           ├── driver
│           │   ├── __init__.py
│           │   ├── encode.py
│           │   ├── encode_new.py
│           │   ├── jax_encode.py
│           │   ├── jax_train.py
│           │   └── train.py
│           ├── faiss_retriever
│           │   ├── __init__.py
│           │   ├── __main__.py
│           │   ├── reducer.py
│           │   └── retriever.py
│           ├── loss.py
│           ├── modeling
│           │   ├── __init__.py
│           │   ├── colbert.py
│           │   ├── dense.py
│           │   ├── encoder.py
│           │   ├── splade.py
│           │   └── unicoil.py
│           ├── preprocessor
│           │   ├── __init__.py
│           │   ├── normalize_text.py
│           │   └── preprocessor_tsv.py
│           ├── tevax
│           │   ├── __init__.py
│           │   ├── loss.py
│           │   └── training.py
│           ├── trainer.py
│           └── utils
│               ├── __init__.py
│               ├── convert_from_dpr.py
│               ├── evaluation.py
│               ├── format
│               │   ├── __init__.py
│               │   ├── convert_result_to_marco.py
│               │   └── convert_result_to_trec.py
│               └── normalize_text.py
└── train_configs
    ├── accel_config_deepspeed.yaml
    └── zero3_bf16.json
/.gitignore:
--------------------------------------------------------------------------------
1 | datasets/*
2 | data/*
3 | scripts/__pycache__/*
4 | scripts/utils/__pycache__/*
5 | models/*
6 | scripts/*.ipynb
7 | wandb/*
8 | logs/*
9 | qrels/*
10 | outputs/*
11 | scripts/latency_test.py
12 | scripts/logits_reranking_test.py
13 | temp/*
--------------------------------------------------------------------------------
/FIRST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gangiswag/llm-reranker/2d7cba423ad555064bdfc719313570b5f9525887/FIRST.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FIRST: Faster Improved Listwise Reranking with Single Token Decoding
2 |
3 | This repository contains the code for the paper [FIRST: Faster Improved Listwise Reranking with Single Token Decoding](https://arxiv.org/pdf/2406.15657).
4 |
5 | FIRST is a novel listwise LLM reranking approach that leverages the output logits of the first generated identifier to directly obtain a ranked ordering of the input candidates. FIRST incorporates a learning-to-rank loss during training, prioritizing ranking accuracy for the more relevant passages.
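As a rough illustration, here is a minimal sketch of single-token logit decoding. It is not the implementation in `scripts/utils/rank_listwise_os_llm.py`; it assumes a Hugging Face causal LM and that the identifiers "A", "B", ... map to single tokens:

```
import torch

def first_token_ranking(model, tokenizer, prompt, num_candidates):
    # One forward pass: the logits at the last input position are the logits
    # of the first identifier token the model would generate.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1]
    # Score each candidate by the logit of its identifier and sort descending:
    # the single set of first-token logits already induces a full ordering.
    identifiers = [chr(ord("A") + i) for i in range(num_candidates)]
    scores = {c: logits[tokenizer.convert_tokens_to_ids(c)].item() for c in identifiers}
    return sorted(identifiers, key=scores.get, reverse=True)
```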
6 |
7 |
8 | ![FIRST](FIRST.png)
9 |
10 | ## Installation
11 | You need to install the tevatron library (original source [here](https://github.com/texttron/tevatron)), which provides the framework for retrieval.
12 |
13 | ```
14 | git clone https://github.com/gangiswag/llm-reranker.git
15 | cd llm-reranker
16 | conda create --name reranker python=3.9.18
17 | conda activate reranker
18 | cd tevatron
19 | pip install --editable . && pip install beir
20 | ```
21 | **Note:** You need to install the vLLM library (instructions [here](https://docs.vllm.ai/en/latest/getting_started/installation.html)) which provides optimization for LLM inference.
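On most setups this is a single command (see the vLLM docs for CUDA and Python version caveats):
```
pip install vllm
```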
22 |
23 | Before running the scripts below, do
24 | ```
25 | export REPO_DIR=<path to the llm-reranker repo>
26 | ```
27 |
28 | ## 1. Retrieval
29 |
30 | First-stage retrieval uses pre-computed BEIR query and corpus embeddings (see `bash/run_1st_retrieval.sh`).
31 |
32 | To run the first-stage retrieval, run:
33 |
34 | ```
35 | bash bash/run_1st_retrieval.sh <path to the encoded BEIR embeddings>
36 | ```
37 |
38 |
39 | To get the retrieval scores, run:
40 |
41 | ```
42 | bash bash/run_eval.sh rank
43 | ```
44 |
45 | ## 2. Reranking
46 | ### 2a. Baseline Cross-encoder reranking
47 |
48 | To run the baseline cross-encoder reranking, run:
49 | ```
50 | bash bash/run_rerank_CE.sh
51 | ```
52 | ### 2b. FIRST LLM Reranking
53 |
54 | To convert the retrieval results to input for LLM reranking, run:
55 |
56 | ```
57 | bash bash/run_convert_results.sh
58 | ```
59 |
60 | We provide the trained FIRST reranker [here](https://huggingface.co/rryisthebest/First_Model).
61 |
62 | To run the FIRST reranking, run:
63 |
64 | ```
65 | bash bash/run_rerank_llm.sh
66 | ```
67 |
68 | To evaluate the reranking performance, run:
69 |
70 | ```
71 | bash bash/run_eval.sh rerank
72 | ```
74 | **Note:** Set the `--suffix` flag to "llm_FIRST_alpha" to evaluate the FIRST reranker, or to "ce" for the cross-encoder reranker.
75 |
76 | ## 3. Model Training
77 | We also provide the data and scripts to train the LLM reranker yourself.
78 | ### 3a. Training Dataset
79 | The converted training dataset (alphabetic IDs) is available on [HF](https://huggingface.co/datasets/rryisthebest/rank_zephyr_training_data_alpha). The standard numeric training dataset can be found [here](https://huggingface.co/datasets/castorini/rank_zephyr_training_data).
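As loaded by `scripts/utils/dataset.py`, each record provides a three-turn `conversations` list: a system message, the user input with the query and identifier-tagged candidate passages, and the target ranking. A schematic example (values abbreviated and illustrative, not copied from the dataset):

```
{
  "conversations": [
    {"value": "<system message instructing the model to rank the passages>"},
    {"value": "<query and candidate passages tagged [A], [B], [C], ...>"},
    {"value": "[B] > [A] > [C] > ..."}
  ]
}
```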
80 |
81 | ### 3b. Training
82 | We support three training objectives:
83 |
84 | - **Ranking**: The Ranking objective uses a learning-to-rank algorithm to output the logits for the highest-ranked passage ID.
85 | - **Generation**: The Generation objective follows the principles of Causal Language Modeling, focusing on permutation generation.
86 | - **Combined**: The Combined objective, introduced in our paper, is a weighted approach that integrates the ranking and generation objectives; this is the setting used for the FIRST model (a minimal sketch of the weighting follows below).
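A minimal sketch of this weighting (illustrative only; the weight and exact loss choices here are assumptions, see `scripts/train_ranking.py` and `scripts/utils/loss.py` for the actual objective):

```
def combined_loss(lm_loss, rank_loss, rank_weight=0.5):
    # lm_loss: causal-LM cross-entropy on the target permutation (generation)
    # rank_loss: learning-to-rank loss on the first-token logits (ranking)
    # rank_weight=0.5 is a placeholder value, not the paper's setting
    return rank_weight * rank_loss + (1.0 - rank_weight) * lm_loss
```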
87 |
88 |
89 | To train the model, run:
90 | ```
91 | bash bash/run_train.sh
92 | ```
93 |
94 | To train from a gated base model, log in to Hugging Face with an access token (create one at huggingface.co/settings/tokens):
95 | ```
96 | huggingface-cli login
97 | ```
98 | ## 4. Relevance Feedback
99 | We also provide scripts here to use the LLM reranker for a downstream task, such as relevance feedback. [Inference-time relevance feedback](https://arxiv.org/pdf/2305.11744) uses the reranker's output to distill the retriever's query embedding to improve recall.
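Conceptually, the distillation step treats the query embedding as a learnable parameter and updates it so that the retriever's dot-product scores move toward the reranker's scores. A simplified sketch of what `scripts/distill.py` does (`query_emb`, `passage_embs`, `reranker_scores`, and `ranking_loss` are placeholders):

```
import torch

query = torch.nn.Parameter(query_emb.clone())  # retriever's original query embedding
optimizer = torch.optim.Adam([query], lr=5e-3)
for _ in range(100):
    optimizer.zero_grad()
    pred_scores = query @ passage_embs.T  # retriever scores for the top-k passages
    loss = ranking_loss(pred_scores, reranker_scores)  # e.g., KL divergence or RankNet
    loss.backward()
    optimizer.step()
# the updated `query` is the distilled embedding used for the 2nd retrieval
```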
100 | ### 4a. Dataset preparation for relevance feedback
101 | To prepare dataset(s) for relevance feedback, run:
102 | ```
103 | bash bash/run_prepare_distill.sh
104 | ```
105 | ### 4b. Distillation (Relevance Feedback Step)
106 | You can run distillation with the cross-encoder reranker, the LLM reranker, or both sequentially.
107 | To perform the relevance feedback distillation step, run:
108 | ```
109 | bash bash/run_distill.sh
110 | ```
111 | This step creates new query embeddings after distillation.
112 |
113 | ### 4c. 2nd Retrieval
114 | To perform the retrieval step with the new query embedding after distillation, run:
115 | ```
116 | bash bash/run_2nd_retrieval.sh
117 | ```
118 |
119 | ### 4d. Relevance feedback evaluation
120 | To evaluate the 2nd retrieval step, run:
121 | ```
122 | bash bash/run_eval.sh rank_refit
123 | ```
124 |
125 | ## Citation
126 |
127 | If you find this repo useful for your work, please consider citing our papers:
128 |
129 | ```
130 | @article{reddy2024first,
131 | title={FIRST: Faster Improved Listwise Reranking with Single Token Decoding},
132 | author={Reddy, Revanth Gangi and Doo, JaeHyeok and Xu, Yifei and Sultan, Md Arafat and Swain, Deevya and Sil, Avirup and Ji, Heng},
133 | journal={arXiv preprint arXiv:2406.15657},
134 | year={2024}
135 | }
136 | ```
137 |
138 | ```
139 | @article{reddy2023inference,
140 | title={Inference-time Re-ranker Relevance Feedback for Neural Information Retrieval},
141 | author={Reddy, Revanth Gangi and Dasigi, Pradeep and Sultan, Md Arafat and Cohan, Arman and Sil, Avirup and Ji, Heng and Hajishirzi, Hannaneh},
142 | journal={arXiv preprint arXiv:2305.11744},
143 | year={2023}
144 | }
145 | ```
146 |
147 | We also acknowledge the following open-source repos, which were instrumental to this work:
148 | - [Tevatron](https://github.com/texttron/tevatron) for the retrieval framework
149 | - [RankLLM](https://github.com/castorini/rank_llm/) for the LLM reranking inference backbone
--------------------------------------------------------------------------------
/bash/run_1st_retrieval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -z "$1" ]; then
4 | echo "Usage: $0 "
5 | exit 1
6 | fi
7 |
8 | input_dir="$1"
9 |
10 | output_dir="${REPO_DIR}/outputs/beir"
11 | data_dir="${REPO_DIR}/datasets/beir"
12 |
13 | mkdir -p "$output_dir" "$data_dir"
14 |
15 | # Datasets to process
16 | datasets=('trec-covid') # 'climate-fever' 'dbpedia-entity' 'fever' 'fiqa' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'scidocs' 'scifact'
17 |
18 | # Iterate over datasets
19 | for dataset in "${datasets[@]}"; do
20 | echo "Processing dataset: ${dataset}"
21 |
22 | dataset_output_dir="${output_dir}/${dataset}"
23 | mkdir -p "$dataset_output_dir"
24 |
25 | python -m tevatron.faiss_retriever \
26 | --query_reps "${input_dir}/${dataset}/original_query/qry.pt" \
27 | --passage_reps "${input_dir}/${dataset}/original_corpus/*.pt" \
28 | --depth 1000 \
29 | --batch_size -1 \
30 | --save_text \
31 | --save_ranking_to "${dataset_output_dir}/rank.tsv"
32 | done
33 |
--------------------------------------------------------------------------------
/bash/run_2nd_retrieval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -z "$1" ]; then
4 | echo "Usage: $0 "
5 | exit 1
6 | fi
7 |
8 | input_dir=$1
9 |
10 | # Check if output directory exists
11 | mkdir -p "${REPO_DIR}/outputs/beir"
12 |
13 | # List of datasets to process
14 | datasets=('trec-covid' 'dbpedia-entity') #'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact'
15 |
16 | # Process each dataset
17 | for dataset in "${datasets[@]}"; do
18 | echo "Processing dataset: ${dataset}"
19 |
20 | dataset_output_dir="${REPO_DIR}/outputs/beir/${dataset}"
21 | mkdir -p "$dataset_output_dir"
22 |
23 | python -m tevatron.faiss_retriever \
24 | --query_reps "${dataset_output_dir}/qry_refit.pt" \
25 | --passage_reps "${input_dir}/${dataset}/original_corpus/*.pt" \
26 | --depth 1000 \
27 | --batch_size -1 \
28 | --save_text \
29 | --save_ranking_to "${dataset_output_dir}/rank_refit.tsv"
30 |
31 | if [ $? -ne 0 ]; then
32 | echo "Error processing dataset: ${dataset}"
33 | exit 1
34 | fi
35 |
36 | echo "Finished processing dataset: ${dataset}"
37 | done
38 |
39 | echo "All datasets processed successfully."
40 |
--------------------------------------------------------------------------------
/bash/run_convert_results.sh:
--------------------------------------------------------------------------------
1 | data_dir=${REPO_DIR}/datasets/beir/
2 | output_dir=${REPO_DIR}/outputs/beir/
3 |
4 | # List of datasets to process
5 | datasets=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity'
6 |
7 | # Iterate over datasets and process each one
8 | for dataset in "${datasets[@]}"; do
9 | echo "Processing dataset: ${dataset}"
10 |
11 | if python "${REPO_DIR}/scripts/convert_results.py" \
12 | --dataset "${dataset}" \
13 | --output_dir "${output_dir}" \
14 | --data_type "beir" \
15 | --data_dir "${data_dir}" \
16 | --top_k 100; then
17 | echo "Successfully processed ${dataset}"
18 | else
19 | echo "Failed to process ${dataset}" >&2
20 | exit 1
21 | fi
22 | done
--------------------------------------------------------------------------------
/bash/run_distill.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check if output directory exists
4 | mkdir -p "${REPO_DIR}/outputs/beir"
5 | output_dir="${REPO_DIR}/outputs/beir/"
6 | data_dir="${REPO_DIR}/datasets/beir/"
7 |
8 | # Configuration flags
9 | use_logits=1 # Whether to use FIRST single token logit decoding
10 | use_alpha=1 # Whether to use Alphabetic Identifiers
11 |
12 | # List of datasets to process
13 | datasets=('trec-covid' 'dbpedia-entity') #'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact'
14 |
15 | # Process each dataset
16 | for dataset in "${datasets[@]}"; do
17 | echo "Processing dataset: ${dataset}"
18 |
19 | python ${REPO_DIR}/scripts/distill.py \
20 | --inp_path ${output_dir}/${dataset}/distill_input.pt \
21 | --rerank_path ${output_dir}/${dataset} \
22 | --output_path ${output_dir}/${dataset}/qry_refit.pt \
23 | --ce_top_k 100 \
24 | --llm_top_k 100 \
25 | --use_logits ${use_logits} \
26 | --use_alpha ${use_alpha} \
27 | --loss_path ${output_dir}/${dataset} \
28 | --llm_loss ranknet
29 |
30 | if [ $? -ne 0 ]; then
31 | echo "Error processing dataset: ${dataset}"
32 | exit 1
33 | fi
34 |
35 | echo "Finished processing dataset: ${dataset}"
36 | done
37 |
38 | echo "All datasets processed successfully."
39 |
--------------------------------------------------------------------------------
/bash/run_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check if eval_type argument is provided
4 | if [ -z "$1" ]; then
5 | echo "Usage: $0 "
6 | exit 1
7 | fi
8 |
9 | EVAL_TYPE=$1
10 | DATA_DIR="${REPO_DIR}/datasets/beir/"
11 | OUTPUT_DIR="${REPO_DIR}/outputs/beir/"
12 |
13 | # List of datasets to process
14 | DATASETS=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity'
15 |
16 | # Iterate over datasets and process each one
17 | for DATASET in "${DATASETS[@]}"; do
18 | echo "Evaluating dataset: ${DATASET}"
19 |
20 | # suffix: ce -> cross encoder reranker | llm_FIRST_alpha -> FIRST Model
21 | if python "${REPO_DIR}/scripts/eval.py" \
22 | --dataset "${DATASET}" \
23 | --output_path "${OUTPUT_DIR}" \
24 | --data_type "beir" \
25 | --suffix "llm_FIRST_alpha" \
26 | --eval_type "${EVAL_TYPE}" \
27 | --data_dir "${DATA_DIR}"; then
28 | echo "Successfully evaluated ${DATASET}"
29 | else
30 | echo "Failed to evaluate ${DATASET}" >&2
31 | exit 1
32 | fi
33 | done
--------------------------------------------------------------------------------
/bash/run_prepare_distill.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check if input directory is provided
4 | if [ -z "$1" ]; then
5 | echo "Usage: $0 "
6 | exit 1
7 | fi
8 |
9 | input_dir=$1
10 | output_dir="${REPO_DIR}/outputs/beir/"
11 |
12 | # List of datasets to process
13 | datasets=('trec-covid' 'dbpedia-entity') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact'
14 |
15 | # Process each dataset
16 | for dataset in "${datasets[@]}"; do
17 | echo "Processing dataset: ${dataset}"
18 |
19 | python ${REPO_DIR}/scripts/prepare_distill.py \
20 | --output_path ${output_dir}/${dataset}/distill_input.pt \
21 | --rank_path ${output_dir}/${dataset}/rank.tsv \
22 | --psg_embs_dir ${input_dir}/${dataset}/original_corpus/ \
23 | --qry_embs_path ${input_dir}/${dataset}/original_query/qry.pt
24 |
25 | if [ $? -ne 0 ]; then
26 | echo "Error processing dataset: ${dataset}"
27 | exit 1
28 | fi
29 |
30 | echo "Finished processing dataset: ${dataset}"
31 | done
32 |
--------------------------------------------------------------------------------
/bash/run_rerank_CE.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set directories
4 | DATA_DIR="${REPO_DIR}/datasets/beir/"
5 | OUTPUT_DIR="${REPO_DIR}/outputs/beir/"
6 |
7 | # List of datasets to rerank
8 | DATASETS=('trec-covid') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'dbpedia-entity'
9 |
10 | # Iterate over datasets and rerank each one
11 | for DATASET in "${DATASETS[@]}"; do
12 | echo "Reranking dataset: ${DATASET}"
13 |
14 | if python "${REPO_DIR}/scripts/rerank_CE.py" \
15 | --dataset "${DATASET}" \
16 | --output_dir "${OUTPUT_DIR}" \
17 | --data_dir "${DATA_DIR}" \
18 | --data_type "beir" \
19 | --top_k 100; then
20 | echo "Successfully reranked ${DATASET} with CE reranker"
21 | else
22 | echo "Failed to rerank ${DATASET}" >&2
23 | exit 1
24 | fi
25 | done
26 |
--------------------------------------------------------------------------------
/bash/run_rerank_llm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set directories and model
4 | DATA_DIR="${REPO_DIR}/datasets/beir/"
5 | OUTPUT_DIR="${REPO_DIR}/outputs/beir/"
6 | MODEL_IN_USE="rryisthebest/First_Model"
7 |
8 | # Configuration flags
9 | USE_LOGITS=1 # Whether to use FIRST single token logit decoding
10 | USE_ALPHA=1 # Whether to use Alphabetic Identifiers
11 |
12 | # List of datasets to rerank
13 | DATASETS=('dbpedia-entity') # 'climate-fever' 'fever' 'hotpotqa' 'msmarco' 'nfcorpus' 'nq' 'fiqa' 'scidocs' 'scifact' 'trec-covid'
14 |
15 | # Iterate over datasets and rerank each one
16 | for DATASET in "${DATASETS[@]}"; do
17 | echo "Reranking dataset: ${DATASET}"
18 |
19 | if python "${REPO_DIR}/scripts/rerank_llm.py" \
20 | --model "${MODEL_IN_USE}" \
21 | --dataset "${DATASET}" \
22 | --output_dir "${OUTPUT_DIR}" \
23 | --data_type "beir" \
24 | --data_dir "${DATA_DIR}" \
25 | --use_logits "${USE_LOGITS}" \
26 | --use_alpha "${USE_ALPHA}" \
27 | --llm_top_k 100 \
28 | --window_size 20 \
29 | --step_size 10 \
30 | --do_batched 1; then
31 | echo "Successfully reranked ${DATASET} with LLM reranker"
32 | else
33 | echo "Failed to rerank ${DATASET} with LLM reranker" >&2
34 | exit 1
35 | fi
36 | done
--------------------------------------------------------------------------------
/bash/run_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define model, dataset paths, and output directory
4 | BASE_MODEL="HuggingFaceH4/zephyr-7b-beta"
5 | TRAIN_DATA_PATH="rryisthebest/rank_zephyr_training_data_alpha" # Train Dataset --> Hugging Face dataset or Local dataset
6 | EVAL_DATA_PATH="rryisthebest/evaluation_data_alpha" # Eval Dataset --> Hugging Face dataset or Local dataset
7 | OUTPUT_DIR="${REPO_DIR}/models/ranking/FIRST_Model" # Directory to save the trained model
8 | BEIR_DATA_DIR="${REPO_DIR}/datasets/beir/"
9 |
10 | # Launch training with DeepSpeed configuration
11 | accelerate launch --config_file "${REPO_DIR}/train_configs/accel_config_deepspeed.yaml" "${REPO_DIR}/scripts/train_ranking.py" \
12 | --model_name_or_path "${BASE_MODEL}" \
13 | --train_dataset_path "${TRAIN_DATA_PATH}" \
14 | --eval_dataset_path "${EVAL_DATA_PATH}" \
15 | --beir_data_path "${BEIR_DATA_DIR}" \
16 | --per_device_eval_batch_size 1 \
17 | --num_train_epochs 3 \
18 | --seed 42 \
19 | --per_device_train_batch_size 2 \
20 | --eval_steps 400 \
21 | --gradient_checkpointing \
22 | --gradient_accumulation_steps 16 \
23 | --lr_scheduler_type cosine \
24 | --num_warmup_steps 50 \
25 | --output_dir "${OUTPUT_DIR}" \
26 | --noisy_embedding_alpha 5 \
27 | --objective combined
28 |
--------------------------------------------------------------------------------
/scripts/convert_results.py:
--------------------------------------------------------------------------------
1 | from beir.datasets.data_loader import GenericDataLoader
2 | import csv
3 | import os
4 | import json
5 | from collections import defaultdict
6 | from argparse import ArgumentParser
7 | from utils.result import Result, ResultsWriter
8 |
9 | def convert_results(output_path, data_dir, dataset, data_type, top_k, rerank_type="text"):
10 | """Convert ranking results to format suitable for reranking
11 |
12 | Args:
13 | rerank_type (str): Whether this is for "text" or "code" reranking
14 | """
15 | print(f"Loading {dataset} dataset")
16 |
17 | try:
18 | # Load datasets based on type
19 | if rerank_type == "code":
20 | if dataset in ('swebench_function', 'swebench_file'):
21 | return convert_results_swebench(output_path, data_dir, dataset, data_type, top_k)
22 |
23 | data_path = os.path.join(data_dir, dataset)
24 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
25 |
26 | # Load rank data
27 | rank_path = os.path.join(output_path, dataset, "rank.tsv")
28 | dataset_output_path = os.path.join(output_path, dataset)
29 |
30 | else: # text reranking
31 | out_dir = os.path.join(data_dir, "beir")
32 | data_path = os.path.join(out_dir, dataset)
33 |
34 | split = "dev" if dataset == "msmarco" else "test"
35 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
36 |
37 | # Load rank data
38 | dataset_output_path = os.path.join(output_path, "beir", dataset)
39 | rank_path = os.path.join(dataset_output_path, "rank.tsv")
40 |
41 | print("Loading rank data")
42 | if not os.path.exists(rank_path):
43 | print(f"Rank file not found: {rank_path}")
44 | return
45 |
46 | results = {}
47 | with open(rank_path, 'r') as rank_file:
48 | csv_reader = csv.reader(rank_file, delimiter="\t", quotechar='|')
49 | if rerank_type == "code":
50 | next(csv_reader) # Skip header for code files
51 | for row in csv_reader:
52 | qid = str(row[0])
53 | pid = str(row[1])
54 | score = float(row[2])
55 | if qid not in results:
56 | results[qid] = {}
57 | results[qid][pid] = score
58 |
59 | print("Converting to reranker results")
60 | # Remove dummy entries if present (for code reranking)
61 | if 'dummy' in results:
62 | results.pop('dummy')
63 |
64 | results_to_rerank = to_reranker_results(results, queries, corpus, top_k)
65 |
66 | # Ensure output directory exists
67 | os.makedirs(dataset_output_path, exist_ok=True)
68 |
69 | results_output_path = os.path.join(dataset_output_path, f'rank_{top_k}.json')
70 | results_writer = ResultsWriter(results_to_rerank)
71 | results_writer.write_in_json_format(results_output_path)
72 | print(f"Results saved to {results_output_path}")
73 |
74 | except Exception as e:
75 | print(f"Error in convert_results: {e}")
76 | raise
77 |
78 | def convert_results_swebench(output_path, data_dir, dataset, data_type, top_k):
79 | """Special handling for swebench datasets"""
80 | prefix = f"csn_{dataset.split('_')[1]}"
81 | instance_list = [instance for instance in os.listdir(data_dir) if instance.startswith(prefix)]
82 |
83 | for dataset_instance in instance_list:
84 | print(f"Loading {dataset_instance} dataset")
85 | data_path = os.path.join(data_dir, "code_datasets", dataset_instance)
86 |
87 | try:
88 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
89 |
90 | print("Loading rank data")
91 | rank_path = os.path.join(output_path, "code_datasets", dataset_instance, "rank.tsv")
92 |
93 | results = {}
94 | with open(rank_path, 'r') as rank_file:
95 | csv_reader = csv.reader(rank_file, delimiter="\t", quotechar='|')
96 | next(csv_reader) # Skip header
97 | for row in csv_reader:
98 | qid = str(row[0])
99 | pid = str(row[1])
100 | score = float(row[2])
101 | if qid not in results:
102 | results[qid] = {}
103 | results[qid][pid] = score
104 |
105 | print("Converting to reranker results")
106 | results_to_rerank = to_reranker_results(results, queries, corpus, top_k)
107 |
108 | dataset_output_path = os.path.join(output_path, "code_datasets", dataset_instance)
109 | os.makedirs(dataset_output_path, exist_ok=True)
110 |
111 | results_output_path = os.path.join(dataset_output_path, f'rank_{top_k}.json')
112 | results_writer = ResultsWriter(results_to_rerank)
113 | results_writer.write_in_json_format(results_output_path)
114 | print(f"Results saved to {results_output_path}")
115 |
116 | except Exception as e:
117 | print(f"Error processing {dataset_instance}: {e}")
118 | continue
119 |
120 | def to_reranker_results(results, queries, corpus, top_k):
121 | """Convert results to format needed by reranker"""
122 | retrieved_results_with_text = []
123 | for qid, docs_scores in results.items():
124 | query_text = queries[qid]
125 | for doc_id, score in docs_scores.items():
126 | doc_text = corpus[doc_id]
127 | result_with_text = {
128 | 'qid': qid,
129 | 'query_text': query_text,
130 | 'doc_id': doc_id,
131 | 'doc_text': doc_text,
132 | 'score': score
133 | }
134 | retrieved_results_with_text.append(result_with_text)
135 |
136 | hits_by_query = defaultdict(list)
137 | for result in retrieved_results_with_text:
138 | content_string = ''
139 | if isinstance(result['doc_text'], dict):
140 | if result['doc_text'].get('title'):
141 | content_string += result['doc_text']['title'] + ". "
142 | content_string += result['doc_text']['text']
143 | else:
144 | content_string = result['doc_text']
145 |
146 | hits_by_query[result['query_text']].append({
147 | 'qid': result['qid'],
148 | 'docid': result['doc_id'],
149 | 'score': result['score'],
150 | 'content': content_string
151 | })
152 |
153 | results_to_rerank = []
154 | for query_text, hits in hits_by_query.items():
155 | sorted_hits = sorted(hits, reverse=True, key=lambda x: x['score'])[:top_k]
156 | result = Result(query=query_text, hits=sorted_hits)
157 | results_to_rerank.append(result)
158 |
159 | return results_to_rerank
160 |
161 | if __name__ == '__main__':
162 | parser = ArgumentParser()
163 | parser.add_argument('--dataset', required=True)
164 | parser.add_argument('--output_dir', required=True)
165 | parser.add_argument('--data_dir', required=True)
166 | parser.add_argument('--data_type', required=True)
167 | parser.add_argument('--top_k', required=True, type=int)
168 | parser.add_argument('--rerank_type', type=str, default="text", choices=["text", "code"],
169 | help="Whether to convert for code or text reranking")
170 | args = parser.parse_args()
171 |
172 | assert args.data_type in ["beir", "codedataset"], "Invalid data_type. Must be 'beir' or 'codedataset'."
173 |
174 | if args.data_type == "codedataset" and args.rerank_type != "code":
175 | print("Warning: codedataset data_type implies code reranking. Setting rerank_type to 'code'")
176 | args.rerank_type = "code"
177 |
178 | convert_results(args.output_dir, args.data_dir, args.dataset, args.data_type, args.top_k, args.rerank_type)
179 |
--------------------------------------------------------------------------------
/scripts/distill.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import numpy as np
5 | from copy import deepcopy
6 | from tqdm import tqdm
7 | from itertools import product
8 | from argparse import ArgumentParser
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | from torch import Tensor
14 | from utils.loss import loss_dict
15 |
16 | class QueryImpModel(nn.Module):
17 | def __init__(self, query_rep, scaler):
18 | super().__init__()
19 | self.query_rep = nn.Parameter(torch.FloatTensor(query_rep), requires_grad=True)
20 | self.scaler = scaler
21 |
22 | def forward(self, psg_embs: Tensor, attn_mask: Tensor = None):
23 | pred_scores = (self.scaler / 2) * torch.matmul(self.query_rep, psg_embs.transpose(0, 1))  # scaled query-passage dot-product scores
24 | if attn_mask is not None:
25 | extended_attention_mask = (1.0 - attn_mask) * torch.finfo(pred_scores.dtype).min
26 | pred_scores += extended_attention_mask
27 | pred_probs = nn.functional.log_softmax(pred_scores, dim=-1)
28 | return pred_probs
29 |
30 |
31 | class QueryScoreModel(nn.Module):
32 | def __init__(self, query_rep, scaler=2.0):
33 | super().__init__()
34 | self.query_rep = nn.Parameter(torch.FloatTensor(query_rep), requires_grad=True)
35 | self.scaler = scaler
36 |
37 | def forward(self, psg_embs: Tensor, attn_mask: Tensor = None):
38 | pred_scores = (self.scaler / 2) * torch.matmul(self.query_rep, psg_embs.transpose(0, 1))  # scaled query-passage dot-product scores
39 | if attn_mask is not None:
40 | extended_attention_mask = (1.0 - attn_mask) * torch.finfo(pred_scores.dtype).min
41 | pred_scores += extended_attention_mask
42 | return pred_scores.unsqueeze(0)
43 |
44 |
45 | def load_results(inp_path, rerank_path, ce_top_k, llm_top_k, use_logits, use_alpha):
46 | llm_rerank = None
47 | ce_rerank = None
48 |
49 | if llm_top_k > 0:
50 | suffix = "_llm"
51 | suffix += "_FIRST" if use_logits else "_gen"
52 | suffix += "_alpha" if use_alpha else "_num"
53 | llm_rerank = json.load(open(os.path.join(rerank_path, f"rerank_{llm_top_k}{suffix}.json")))
54 |
55 | if ce_top_k > 0:
56 | ce_rerank = json.load(open(os.path.join(rerank_path, f"rerank_{ce_top_k}_ce.json")))
57 |
58 | examples = pickle.load(open(inp_path, "rb"))
59 | return examples, ce_rerank, llm_rerank
60 |
61 |
62 | def prepare_distill_ce(data, ce_rerank, ce_top_k):
63 | qid = data["query_id"]
64 | pids = data["passage_ids"][:ce_top_k]
65 |
66 | data_passage_mapping = {pid: deepcopy(emb) for pid, emb in zip(data["passage_ids"], data["passage_embs"])}
67 | target_scores = [ce_rerank[qid][pid] for pid in pids]
68 | psg_embs = [data_passage_mapping[pid] for pid in pids]
69 |
70 | target_scores = torch.FloatTensor(target_scores)
71 | target_probs = nn.functional.log_softmax(target_scores, dim=-1)
72 |
73 | baseline_rep = torch.FloatTensor(data["query_rep"])
74 | passage_reps = torch.FloatTensor(np.array(psg_embs))
75 |
76 | init_scores = torch.matmul(baseline_rep, passage_reps.transpose(0, 1))
77 | scaler = (target_scores.max() - target_scores.min()) / (init_scores.max().item() - init_scores.min().item())  # rescale retriever scores to the reranker's score range
78 |
79 | return passage_reps, target_probs, scaler
80 |
81 |
82 | def prepare_distill_llm(data, llm_rerank, query_rep, llm_top_k):
83 | qid = data["query_id"]
84 | pids = data["passage_ids"][:llm_top_k]
85 |
86 | data_passage_mapping = {pid: deepcopy(emb) for pid, emb in zip(data["passage_ids"], data["passage_embs"])}
87 | reranked_target_scores = [llm_rerank[qid][pid] for pid in pids]
88 | reranked_psg_embs = [data_passage_mapping[pid] for pid in pids]
89 |
90 | reranked_target_scores = torch.FloatTensor(reranked_target_scores)
91 | reranked_passage_reps = torch.FloatTensor(np.array(reranked_psg_embs))
92 |
93 | init_scores = torch.matmul(query_rep, reranked_passage_reps.transpose(0, 1))
94 | scaler = (reranked_target_scores.max() - reranked_target_scores.min()) / \
95 | (init_scores.max().item() - init_scores.min().item())
96 |
97 | return reranked_passage_reps, reranked_target_scores.unsqueeze(0), scaler
98 |
99 |
100 | def run_query_teacher_importance_learner(inp_path, rerank_path, output_path, loss_path, ce_top_k, llm_top_k, learning_rate,
101 | num_updates, use_logits, use_alpha, llm_loss):
102 | assert llm_loss in loss_dict
103 | examples, ce_rerank, llm_rerank = load_results(inp_path, rerank_path, ce_top_k, llm_top_k, use_logits, use_alpha)
104 |
105 | reps = []
106 | ids = []
107 |
108 | for data in tqdm(examples):
109 | baseline_rep = torch.FloatTensor(data["query_rep"])
110 |
111 | try:
112 | learned_rep = baseline_rep
113 | if ce_top_k > 0:
114 | passage_reps, target_probs, scaler = prepare_distill_ce(data, ce_rerank, ce_top_k)
115 | ce_dstl_model = QueryImpModel(query_rep=baseline_rep.numpy(), scaler=scaler)
116 | loss_function = nn.KLDivLoss(reduction="batchmean", log_target=True)
117 | optimizer = optim.Adam(ce_dstl_model.parameters(), lr=learning_rate)
118 |
119 | for _ in range(num_updates):
120 | optimizer.zero_grad()
121 | pred_probs = ce_dstl_model(psg_embs=passage_reps)
122 | loss = loss_function(pred_probs.unsqueeze(0), target_probs.unsqueeze(0))
123 | loss.backward()
124 | optimizer.step()
125 |
126 | learned_rep = ce_dstl_model.query_rep.data.cpu().detach()
127 |
128 | reranked_passage_reps, reranked_target_scores, scaler = prepare_distill_llm(data, llm_rerank, learned_rep,
129 | llm_top_k)
130 | llm_dstl_model = QueryScoreModel(query_rep=learned_rep.numpy(), scaler=scaler)
131 | optimizer = optim.Adam(llm_dstl_model.parameters(), lr=learning_rate / 5)
132 |
133 | for _ in range(num_updates // 5):
134 | optimizer.zero_grad()
135 | pred_scores = llm_dstl_model(psg_embs=reranked_passage_reps)
136 | loss = loss_dict[llm_loss](pred_scores, reranked_target_scores, weighted=True if llm_loss == "ranknet" else False)
137 | loss.backward()
138 | optimizer.step()
139 |
140 | rep = llm_dstl_model.query_rep.data.cpu().detach()
141 | reps.append(rep.numpy())
142 | ids.append(data["query_id"])
143 | except Exception as e:
144 | print(f"Error for query ID {data['query_id']}: {e}")
145 |
146 | pickle.dump((np.array(reps), ids), open(output_path, "wb"))
147 |
148 |
149 | if __name__ == "__main__":
150 |
151 | parser = ArgumentParser()
152 | parser.add_argument('--inp_path', required=True)
153 | parser.add_argument('--rerank_path', required=True)
154 | parser.add_argument('--output_path', required=True)
155 | parser.add_argument('--loss_path', required=True)
156 | parser.add_argument('--ce_top_k', type=int, default=100)
157 | parser.add_argument('--llm_top_k', type=int, default=9)
158 | parser.add_argument('--learning_rate', type=float, default=0.005)
159 | parser.add_argument('--num_updates', type=int, default=100)
160 | parser.add_argument('--use_logits', type=int, default=0)
161 | parser.add_argument('--use_alpha', type=int, default=0)
162 | parser.add_argument('--llm_loss', type=str, default="lambdarank")
163 |
164 | args = parser.parse_args()
165 |
166 | run_query_teacher_importance_learner(args.inp_path, args.rerank_path, args.output_path, args.loss_path, args.ce_top_k, args.llm_top_k, args.learning_rate, args.num_updates, args.use_logits, args.use_alpha, args.llm_loss)
167 |
--------------------------------------------------------------------------------
/scripts/eval.py:
--------------------------------------------------------------------------------
1 | from beir import util, LoggingHandler
2 | from beir.datasets.data_loader import GenericDataLoader
3 | from beir.retrieval.evaluation import EvaluateRetrieval
4 | import csv
5 | import os
6 | import logging
7 | import json
8 | from argparse import ArgumentParser
9 |
10 | def write_results_to_text(out_path, ndcg, recall):
11 | with open(out_path, 'w') as f:
12 | f.write(f"{ndcg}\n{recall}\n")
13 |
14 | def eval_rank(output_path, data_dir, dataset, data_type, suffix="", eval_type="rank", rerank_type="text"):
15 | try:
16 | # Load datasets based on type
17 | if rerank_type == "code":
18 | data_path = os.path.join(data_dir, "code_datasets", dataset)
19 | split = "test"
20 | else: # text reranking
21 | data_path = os.path.join(data_dir, "beir", dataset)
22 | split = "dev" if dataset == "msmarco" else "test"
23 |
24 | try:
25 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
26 | except Exception as e:
27 | print(f"Error loading dataset {dataset} from {data_path}: {e}")
28 | return
29 |
30 | # Set file name by evaluation type
31 | fname = "rank.tsv" if eval_type == "rank" else "rank_refit.tsv"
32 |
33 | # Load rank data from appropriate directory
34 | if rerank_type == "code":
35 | model_output_path = os.path.join(output_path, "code_datasets", dataset, fname)
36 | else:
37 | model_output_path = os.path.join(output_path, "beir", dataset, fname)
38 |
39 | if not os.path.exists(model_output_path):
40 | print(f"Rank file not found: {model_output_path}")
41 | return
42 |
43 | results = {}
44 | try:
45 | with open(model_output_path, 'r') as rank_file:
46 | csv_reader = csv.reader(rank_file, delimiter="\t", quotechar='|')
47 | if rerank_type == "code":
48 | next(csv_reader) # Skip header for code files
49 | for row in csv_reader:
50 | qid, pid, score = row[0], row[1], float(row[2])
51 | if qid not in results:
52 | results[qid] = {}
53 | results[qid][pid] = score
54 | except Exception as e:
55 | print(f"Error reading rank file {model_output_path}: {e}")
56 | return
57 |
58 | retriever = EvaluateRetrieval()
59 | metrics_to_evaluate = [1, 3, 5, 10, 20, 100, 125]
60 |
61 | # Evaluate based on rerank type
62 | if rerank_type == "code":
63 | mrr = retriever.evaluate_custom(qrels, results, metrics_to_evaluate, "mrr")
64 | print(f"MRR@{metrics_to_evaluate}: {mrr}")
65 | else:
66 | # Standard text reranking evaluation
67 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, metrics_to_evaluate)
68 | if dataset == "trec-covid":
69 | recall = retriever.evaluate_custom(qrels, results, metrics_to_evaluate, "recall_cap")
70 | print(ndcg, recall)
71 |
72 | except Exception as e:
73 | print(f"Error in eval_rank: {e}")
74 | raise
75 |
76 | def eval_rerank(output_path, data_dir, dataset, data_type, suffix="", rerank_type="text"):
77 | try:
78 | # Load datasets based on type
79 | if rerank_type == "code":
80 | data_path = os.path.join(data_dir, "code_datasets", dataset)
81 | split = "test"
82 | else: # text reranking
83 | data_path = os.path.join(data_dir, "beir", dataset)
84 | split = "dev" if dataset == "msmarco" else "test"
85 |
86 | try:
87 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
88 | except Exception as e:
89 | print(f"Error loading dataset {dataset} from {data_path}: {e}")
90 | return
91 |
92 | # Load rerank results from appropriate directory
93 | if rerank_type == "code":
94 | model_output_path = os.path.join(output_path, "code_datasets", dataset, f"rerank_100_{suffix}.json")
95 | else:
96 | model_output_path = os.path.join(output_path, "beir", dataset, f"rerank_100_{suffix}.json")
97 |
98 | if not os.path.exists(model_output_path):
99 | print(f"Rerank file not found: {model_output_path}")
100 | return
101 |
102 | try:
103 | with open(model_output_path, 'r') as json_file:
104 | results = json.load(json_file)
105 | except Exception as e:
106 | print(f"Error reading rerank file {model_output_path}: {e}")
107 | return
108 |
109 | retriever = EvaluateRetrieval()
110 | metrics_to_evaluate = [1, 3, 5, 10, 20, 100]
111 |
112 | if rerank_type == "code":
113 | # For code reranking, focus on MRR and exact matches
114 | mrr = retriever.evaluate_custom(qrels, results, metrics_to_evaluate, "mrr")
115 | print(f"MRR@{metrics_to_evaluate}: {mrr}")
116 | else:
117 | # Standard text reranking evaluation
118 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, metrics_to_evaluate)
119 | mrr = retriever.evaluate_custom(qrels, results, metrics_to_evaluate, "mrr")
120 | print(ndcg, recall)
121 |
122 | except Exception as e:
123 | print(f"Error in eval_rerank: {e}")
124 | raise
125 |
126 | if __name__ == "__main__":
127 | parser = ArgumentParser()
128 | parser.add_argument('--dataset', required=True, help="Name of the dataset to evaluate.")
129 | parser.add_argument('--output_path', required=True, help="Directory where the output files are stored.")
130 | parser.add_argument('--data_dir', required=True, help="Directory where datasets are stored or will be downloaded.")
131 | parser.add_argument('--suffix', default="", type=str, help="Suffix for the evaluation files (e.g., 'ce', 'logits_alpha').")
132 | parser.add_argument('--data_type', required=True, help="Type of the dataset, must be 'beir' or 'codedataset'.")
133 | parser.add_argument('--eval_type', required=True, help="Type of evaluation: 'rank', 'rerank', or 'rank_refit'.")
134 | parser.add_argument('--rerank_type', type=str, default="text", choices=["text", "code"],
135 | help="Whether to evaluate code or text reranking results")
136 | args = parser.parse_args()
137 |
138 | assert args.data_type in ["beir", "codedataset"], "Invalid data_type. Must be 'beir' or 'codedataset'."
139 | assert args.eval_type in ["rank", "rerank", "rank_refit"], "Invalid eval_type. Must be 'rank', 'rerank', or 'rank_refit'."
140 |
141 | if args.data_type == "codedataset" and args.rerank_type != "code":
142 | print("Warning: codedataset data_type implies code reranking. Setting rerank_type to 'code'")
143 | args.rerank_type = "code"
144 |
145 | if args.eval_type in ["rank", "rank_refit"]:
146 | eval_rank(args.output_path, args.data_dir, args.dataset, args.data_type,
147 | args.suffix, args.eval_type, args.rerank_type)
148 | else:
149 | eval_rerank(args.output_path, args.data_dir, args.dataset, args.data_type,
150 | args.suffix, args.rerank_type)
151 |
--------------------------------------------------------------------------------
/scripts/prepare_distill.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import csv
3 | import sys
4 | csv.field_size_limit(sys.maxsize)
5 |
6 | from tqdm import tqdm
7 | import pickle
8 | import glob
9 | from argparse import ArgumentParser
10 | import os
11 | import numpy as np
19 |
20 |
21 | def get_all_examples(rank_path, psg_embs_dir, qry_embs_path, output_path):
27 |
28 | embs_path = os.path.join(psg_embs_dir, "split*.pt")
29 | embs_files = glob.glob(embs_path)
30 | embs = list()
31 | ids = list()
32 | p_reps_0, look_up_0 = pickle.load(open(embs_files[0], "rb"))
33 | embs.extend(p_reps_0)
34 | ids.extend(look_up_0)
35 | for f in tqdm(embs_files[1:]):
36 | p_reps, look_up = pickle.load(open(f, "rb"))
37 | embs.extend(p_reps)
38 | ids.extend(look_up)
39 | embs = np.array(embs, dtype=np.float32)
40 | embs_dict = dict()
41 | assert len(embs) == len(ids)
42 | for emb, pid in zip(embs, ids):
43 | embs_dict[pid] = emb
44 |
45 | q_reps, q_look_up = pickle.load(open(qry_embs_path, "rb"))
46 | q_embs = np.array(q_reps, dtype=np.float32)
47 | assert len(q_embs) == len(q_look_up)
48 | q_tuples = list()
49 | for q_emb, qid in zip(q_embs, q_look_up):
50 | q_tuples.append((str(qid), q_emb))
51 |
52 | results = dict()
53 | csv_reader = csv.reader(open(rank_path), delimiter="\t", quotechar='|')
54 | for row in csv_reader:
55 | qid = str(row[0])
56 | pid = str(row[1])
57 | score = float(row[2])
58 | if qid not in results:
59 | results[qid] = dict()
60 | results[qid][pid] = score
61 |
62 | examples = list()
63 | for qid, q_emb in tqdm(q_tuples):
64 | item = dict()
65 | passage_ids = sorted(results[qid].items(), key=lambda item: item[1], reverse=True)
66 | passage_ids = [i[0] for i in passage_ids]
67 |
68 | item["query_rep"] = q_emb
69 | item["passage_ids"] = deepcopy(passage_ids)
70 | item["passage_embs"] = list()
71 | for pid in item["passage_ids"]:
72 | item["passage_embs"].append(embs_dict[pid])
73 | item["passage_embs"] = np.array(item["passage_embs"])
74 | item["query_id"] = qid
75 | item["qrels"] = ""
76 | examples.append(deepcopy(item))
77 | pickle.dump(examples, open(output_path, "wb"))
78 |
79 | if __name__ == "__main__":
80 | parser = ArgumentParser()
81 | parser.add_argument('--rank_path', required=True)
82 | parser.add_argument('--psg_embs_dir', required=True)
83 | parser.add_argument('--qry_embs_path', required=True)
84 | parser.add_argument('--output_path', required=True)
85 | args = parser.parse_args()
86 |
87 | get_all_examples(args.rank_path, args.psg_embs_dir, args.qry_embs_path, args.output_path)
88 |
89 |
90 |
--------------------------------------------------------------------------------
/scripts/rerank_CE.py:
--------------------------------------------------------------------------------
1 | from beir.reranking import Rerank
2 | from beir.reranking.models import CrossEncoder
3 | from beir import util, LoggingHandler
4 | from beir.datasets.data_loader import GenericDataLoader
5 | from beir.retrieval.evaluation import EvaluateRetrieval
6 | import csv
7 | import os
8 | import logging
9 | import json
10 | from argparse import ArgumentParser
11 |
12 | def rerank_beir_outputs(output_path, data_dir, dataset, data_type, top_k):
13 | if data_type == "beir":
14 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
15 | out_dir = os.path.join(data_dir, "datasets")
16 | data_path = util.download_and_unzip(url, out_dir)
17 | else:
18 | data_path = data_dir + dataset
19 |
20 | if dataset == "msmarco":
21 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="dev")
22 | else:
23 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
24 | model_output_path = os.path.join(os.path.join(output_path, dataset), "rank.tsv")
25 | csv_reader = csv.reader(open(model_output_path), delimiter="\t", quotechar='|')
26 | results = dict()
27 | for row in csv_reader:
28 | qid = str(row[0])
29 | pid = str(row[1])
30 | score = float(row[2])
31 | if qid not in results:
32 | results[qid] = dict()
33 | results[qid][pid] = score
34 |
35 | retriever = EvaluateRetrieval()
36 |
37 | #### Evaluate your retrieval using NDCG@k, MAP@K ...
38 | print("Retriever evaluation")
39 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, [1,3,5,10,20,100])
40 | print(ndcg)
41 | mrr = retriever.evaluate_custom(qrels, results, [1,3,5,10,20,100], "mrr")
42 | print(mrr)
43 | if dataset == "trec-covid":
44 | recall_cap = retriever.evaluate_custom(qrels, results, [1,3,5,10,20,100], "recall_cap")
45 | print(recall_cap)
46 | else:
47 | print(recall)
48 |
49 | if data_type == "beir":
50 | cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
51 | else:
52 | cross_encoder_model = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')
53 | reranker = Rerank(cross_encoder_model, batch_size=64)
54 |
55 | rerank_results = reranker.rerank(corpus, queries, results, top_k=int(top_k))
56 |
57 |
58 | #### Evaluate your retrieval using NDCG@k, MAP@K ...
59 | print("Re-ranker evaluation")
60 | ndcg, _map, recall, precision = retriever.evaluate(qrels, rerank_results, [1,3,5,10,20,100])
61 | print(ndcg)
62 | mrr = retriever.evaluate_custom(qrels, rerank_results, [1,3,5,10,20,100], "mrr")
63 | print(mrr)
64 | if dataset == "trec-covid":
65 | recall_cap = retriever.evaluate_custom(qrels, rerank_results, [1,3,5,10,20,100], "recall_cap")
66 | print(recall_cap)
67 | else:
68 | print(recall)
69 |
70 | rerank_path = os.path.join(os.path.join(output_path, dataset), "rerank_" + str(top_k) + "_ce.json")
71 | json.dump(rerank_results, open(rerank_path, "w"), indent=4)
72 |
73 | if __name__ == "__main__":
74 | parser = ArgumentParser()
75 | parser.add_argument('--dataset', required=True)
76 | parser.add_argument('--output_dir', required=True)
77 | parser.add_argument('--data_dir', required=True)
78 | parser.add_argument('--data_type', required=True)
79 | parser.add_argument('--top_k', required=True, type=int)
80 | args = parser.parse_args()
81 |
82 | assert args.data_type == "beir" or args.data_type == "mrtydi"
83 |
84 | rerank_beir_outputs(args.output_dir, args.data_dir, args.dataset, args.data_type, args.top_k)
85 |
--------------------------------------------------------------------------------
/scripts/rerank_llm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from argparse import ArgumentParser
4 | from beir.datasets.data_loader import GenericDataLoader
5 | from utils.result import Result, ResultsLoader
6 | from utils.llm_util import evaluate_results, get_results_to_eval, save_rerank_results, rerank_beir_outputs_llm
7 |
8 | def rerank_beir_outputs(model, output_path, data_dir, dataset, data_type, use_logits, use_alpha, llm_top_k, window_size, step_size, batched, context_size, rerank_type="text", code_prompt_type="docstring"):
9 | try:
10 | # Load dataset based on type
11 | if rerank_type == "code":
12 | data_path = os.path.join(data_dir, dataset)
13 | else: # text reranking
14 | data_path = os.path.join(data_dir, "beir", dataset)
15 |
16 | # Handle dataset loading
17 | split = "dev" if dataset == "msmarco" else "test"
18 | try:
19 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
20 | except Exception as e:
21 | print(f"Error loading dataset {dataset} from {data_path}: {e}")
22 | return
23 |
24 | # Load converted retriever results
25 | try:
26 | if rerank_type == "code":
27 | results_output_path = os.path.join(output_path, dataset, 'rank_100.json')
28 | else:
29 | results_output_path = os.path.join(output_path, "beir", dataset, 'rank_100.json')
30 |
31 | results_loader = ResultsLoader(results_output_path)
32 | results_to_rerank = results_loader.get_results(with_context=True)
33 | except Exception as e:
34 | print(f"Error loading results from {results_output_path}: {e}")
35 | return
36 |
37 | # Reranking
38 | try:
39 | reranked_results = rerank_beir_outputs_llm(
40 | model, results_to_rerank, use_logits=use_logits, use_alpha=use_alpha,
41 | top_k=llm_top_k, window_size=window_size, step_size=step_size,
42 | batched=batched, context_size=context_size,
43 | rerank_type=rerank_type, code_prompt_type=code_prompt_type
44 | )
45 |
46 | # Evaluate results
47 | converted_results = get_results_to_eval(reranked_results)
48 |
49 | if rerank_type == "code":
50 | mrr_at_k = evaluate_results(dataset, qrels, converted_results, rerank_type="code")
51 | print("\nMean Reciprocal Rank (MRR) at different cutoffs:")
52 | for k, mrr in mrr_at_k.items():
53 | print(f"MRR@{k}: {mrr:.4f}")
54 | else:
55 | ndcg, _map, recall, precision = evaluate_results(dataset, qrels, converted_results)
56 | print(f"\nNDCG (Normalized Discounted Cumulative Gain):\n {ndcg}")
57 | print(f"\nRecall:\n {recall}\n")
58 |
59 | # Save rerank results to appropriate directory
60 | if rerank_type == "code":
61 | save_path = os.path.join(output_path, "code_datasets", dataset)
62 | else:
63 | save_path = os.path.join(output_path, "beir", dataset)
64 |
65 | # Create directory if it doesn't exist
66 | os.makedirs(save_path, exist_ok=True)
67 |
68 | save_rerank_results(save_path, dataset, converted_results, llm_top_k,
69 | use_logits, use_alpha, is_llm_result=True)
70 | print(f"Reranked results saved successfully for dataset {dataset}")
71 |
72 | except Exception as e:
73 | print(f"Error during reranking process: {e}")
74 | raise
75 |
76 | except Exception as e:
77 | print(f"Unexpected error in rerank_beir_outputs: {e}")
78 | raise
79 |
80 | if __name__ == "__main__":
81 | parser = ArgumentParser()
82 | parser.add_argument('--model', default="rryisthebest/First_Model")
83 | parser.add_argument('--dataset', required=True)
84 | parser.add_argument('--output_dir', required=True)
85 | parser.add_argument('--data_dir', required=True)
86 | parser.add_argument('--data_type', required=True)
87 | parser.add_argument('--use_logits', default=0, type=int)
88 | parser.add_argument('--use_alpha', default=0, type=int)
89 | parser.add_argument('--context_size', default=32768, type=int)
90 | parser.add_argument('--llm_top_k', default=20, type=int)
91 | parser.add_argument('--window_size', default=9, type=int)
92 | parser.add_argument('--step_size', default=9, type=int)
93 | parser.add_argument('--do_batched', default=0, type=int)
94 | parser.add_argument('--rerank_type', type=str, default="text", choices=["text", "code"],
95 | help="Whether to perform code or text reranking")
96 | parser.add_argument('--code_prompt_type', type=str, default="docstring",
97 | choices=["docstring", "github_issue"],
98 | help="Type of code prompt to use (only applicable when rerank_type is 'code')")
99 | args = parser.parse_args()
100 |
101 | # Validate arguments
102 | if args.rerank_type == "text" and args.code_prompt_type != "docstring":
103 | print("Warning: code_prompt_type is ignored when rerank_type is 'text'")
104 |
105 | if args.rerank_type == "code" and (args.use_logits or args.use_alpha):
106 | print("Warning: Code reranking does not support logits or alpha mode. These will be disabled.")
107 | args.use_logits = 0
108 | args.use_alpha = 0
109 |
110 | rerank_beir_outputs(args.model, args.output_dir, args.data_dir, args.dataset,
111 | args.data_type, args.use_logits, args.use_alpha, args.llm_top_k,
112 | args.window_size, args.step_size, args.do_batched, args.context_size,
113 | args.rerank_type, args.code_prompt_type)
--------------------------------------------------------------------------------
/scripts/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gangiswag/llm-reranker/2d7cba423ad555064bdfc719313570b5f9525887/scripts/utils/__init__.py
--------------------------------------------------------------------------------
/scripts/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | from dataclasses import dataclass, field
4 | from typing import Dict, Optional, Sequence
5 |
6 | import torch
7 | import transformers
8 | from torch.utils.data import Dataset
9 | from ftfy import fix_text
10 |
11 | max_psg_num = 20
12 | START_IDX = ord('A')
13 | IGNORE_INDEX = -100
14 |
15 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
16 | """Tokenize a list of strings."""
17 | tokenized_list = [
18 | tokenizer(
19 | text,
20 | return_tensors="pt",
21 | padding="longest",
22 | truncation=True,
23 | )
24 | for text in strings
25 | ]
26 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
27 | input_ids_lens = labels_lens = [
28 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
29 | ]
30 | return dict(
31 | input_ids=input_ids,
32 | labels=labels,
33 | input_ids_lens=input_ids_lens,
34 | labels_lens=labels_lens,
35 | )
36 |
37 | def preprocess(
38 | sources: Sequence[str],
39 | targets: Sequence[str],
40 | tokenizer: transformers.PreTrainedTokenizer,
41 | ) -> Dict:
42 | """Preprocess the data by tokenizing."""
43 | examples = [s + t for s, t in zip(sources, targets)]
44 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
45 | input_ids = examples_tokenized["input_ids"]
46 | labels = copy.deepcopy(input_ids)
47 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
48 | label[:source_len] = IGNORE_INDEX
49 | return input_ids, labels, sources_tokenized["input_ids_lens"]
50 |
51 | class RankingDataset(Dataset):
52 | def __init__(self, raw_data, model_tokenizer, type) -> None:
53 | self.raw_data = raw_data
54 | self.tokenizer = model_tokenizer
55 | self.tokenizer.padding_side="left"
56 | self.type = type
57 | self.system_message_supported = "system" in self.tokenizer.chat_template
58 |
59 | def __getitem__(self, index):
60 | conversation = self.raw_data[index]["conversations"]
61 | sys_msg = conversation[0]['value']
62 | input_context = conversation[1]['value']
63 | target_generation = conversation[2]["value"]
64 |
65 | if self.system_message_supported:
66 | messages = [
67 | {"role": "system", "content": sys_msg},
68 | {"role": "user", "content": input_context}
69 | ]
70 | else:
71 | messages = [
72 | {"role": "user", "content": sys_msg + "\n " + input_context}
73 | ]
74 | prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
75 | prompt += "["
76 | prompt = fix_text(prompt)
77 |
78 | if self.type == "train":
79 | label_map = {}
80 | label_rank = 0
81 | for token in target_generation:
82 | if token.isalpha():
83 | label_map[token] = label_rank
84 | label_rank += 1
85 |
86 | label = [label_map[chr(c)] for c in range(START_IDX, START_IDX+len(label_map))]
87 |
88 | elif self.type == "eval":
89 | label = [self.raw_data[index]["id"]] + self.raw_data[index]["docids"] + self.raw_data[index]["scores"]
90 | else:
91 | raise Exception("Invalid run type specified for Dataset. Choose from ['train', 'eval']")
92 | return prompt, label
93 |
94 | def __len__(self):
95 | return len(self.raw_data)
96 |
97 | class GenerationDataset(Dataset):
98 | def __init__(self, raw_data, model_tokenizer, combined=False) -> None:
99 | self.raw_data = raw_data
100 | self.tokenizer = model_tokenizer
101 | self.combined = combined
102 | self.system_message_supported = "system" in self.tokenizer.chat_template
103 |
104 | def __getitem__(self, index):
105 | conversation = self.raw_data[index]["conversations"]
106 | sys_msg = conversation[0]['value']
107 | input_context = conversation[1]['value']
108 | label = conversation[2]["value"]
109 | label += self.tokenizer.eos_token
110 |
111 | if self.system_message_supported:
112 | messages = [
113 | {"role": "system", "content": sys_msg},
114 | {"role": "user", "content": input_context}
115 | ]
116 | else:
117 | messages = [
118 | {"role": "user", "content": sys_msg + "\n " + input_context}
119 | ]
120 | prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
121 | prompt = fix_text(prompt)
122 | if self.combined:
123 | label_map = {}
124 | label_rank = 0
125 | for token in conversation[2]["value"]:
126 | if token.isalpha():
127 | label_map[token] = label_rank
128 | label_rank += 1
129 |
130 | rank_label = [label_map[chr(c)] for c in range(START_IDX, START_IDX+len(label_map))]
131 | return prompt, label, rank_label
132 | else:
133 | return prompt, label
134 |
135 | def __len__(self):
136 | return len(self.raw_data)
137 |
138 | def ranking_collate_fn(data, tokenizer):
139 | prompts, labels = list(zip(*data))
140 | tokenized_inputs = tokenizer(prompts, padding="longest", truncation=False, return_tensors="pt")
141 | return tokenized_inputs, labels
142 |
143 | def generation_collate_fn(data, tokenizer):
144 | prompts, labels = list(zip(*data))
145 | tokenized_inputs, labels, source_lens = preprocess(prompts, labels, tokenizer)
146 | tokenized_inputs = torch.nn.utils.rnn.pad_sequence(
147 | tokenized_inputs, batch_first=True, padding_value=tokenizer.pad_token_id
148 | )
149 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
150 | return tokenized_inputs, labels
151 |
152 | def combined_collate_fn(data, tokenizer):
153 | prompts, labels, rank_labels = list(zip(*data))
154 | tokenized_inputs, labels, source_lens = preprocess(prompts, labels, tokenizer)
155 | tokenized_inputs = torch.nn.utils.rnn.pad_sequence(
156 | tokenized_inputs, batch_first=True, padding_value=tokenizer.pad_token_id
157 | )
158 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
159 | return tokenized_inputs, labels, rank_labels, source_lens
160 |
--------------------------------------------------------------------------------
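The `preprocess` helper above is the core of the generation objective: source (prompt) tokens are masked out of the labels so that cross-entropy is computed only on the target ranking string. A minimal, self-contained sketch of that masking idea (toy tensors, not the repo's code):

```python
# Minimal sketch of the label masking performed by preprocess():
# positions belonging to the prompt are set to IGNORE_INDEX so that
# torch.nn.CrossEntropyLoss (whose default ignore_index is -100) skips them.
import copy
import torch

IGNORE_INDEX = -100

def mask_sources(input_ids, source_lens):
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, source_lens):
        label[:source_len] = IGNORE_INDEX  # prompt tokens contribute no loss
    return labels

# Toy example: a 6-token sequence whose first 4 tokens are the prompt.
ids = [torch.tensor([5, 8, 2, 9, 13, 7])]
print(mask_sources(ids, source_lens=[4]))
# -> [tensor([-100, -100, -100, -100,   13,    7])]
```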
/scripts/utils/llm_util.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import logging
4 | import json
5 | import math
6 | from beir.retrieval.evaluation import EvaluateRetrieval
7 | from utils.result import Result, ResultsLoader
8 | from utils.rankllm import PromptMode, RankLLM
9 | from utils.reranker import Reranker
10 | from utils.rank_listwise_os_llm import RankListwiseOSLLM
11 |
12 | def evaluate_results(dataset, qrels, rerank_results, rerank_type="text"):
13 | """
14 | Evaluate reranking results for both text and code datasets
15 |
16 | Args:
17 | dataset: Name of the dataset
18 | qrels: Ground truth relevance judgments
19 | rerank_results: Results to evaluate
20 | rerank_type: Either "text" or "code" reranking
21 |
22 | Returns:
23 | For text reranking: (ndcg, _map, recall, precision)
24 | For code reranking: Dictionary of MRR values at different cutoffs
25 | """
26 | metrics_to_evaluate = [1, 3, 5, 10, 20, 100]
27 |
28 | if rerank_type == "code":
29 | mrr_at_k = {}
30 |
31 | for k in metrics_to_evaluate:
32 | mrr_sum = 0.0
33 | num_queries = 0
34 |
35 | for qid in qrels:
36 | if qid in rerank_results:
37 | sorted_docs = sorted(rerank_results[qid].items(), key=lambda x: x[1], reverse=True)[:k]
38 |
39 | for rank, (doc_id, _) in enumerate(sorted_docs, start=1):
40 | if doc_id in qrels[qid] and qrels[qid][doc_id] > 0:
41 | mrr_sum += 1.0 / rank
42 | break
43 | num_queries += 1
44 |
45 | mrr = mrr_sum / num_queries if num_queries > 0 else 0.0
46 | mrr_at_k[k] = mrr
47 |
48 | return mrr_at_k
49 | else: # text reranking
50 | retriever = EvaluateRetrieval()
51 |
52 | ndcg, _map, recall, precision = retriever.evaluate(qrels, rerank_results, metrics_to_evaluate)
53 |
54 | if dataset == "trec-covid":
55 | recall_cap_metrics = metrics_to_evaluate + [125]
56 | recall = retriever.evaluate_custom(qrels, rerank_results, recall_cap_metrics, metric="recall_cap")
57 |
58 | return ndcg, _map, recall, precision
59 |
60 | def get_results_to_eval(results):
61 | eval_results = {}
62 |
63 | for result in results:
64 | hits = result.hits
65 | qid = hits[0]['qid']
66 | eval_results[qid] = {hit['docid']: hit['score'] for hit in hits}
67 |
68 | return eval_results
69 |
70 | def save_rerank_results(output_path, dataset, results, top_k, use_logits=False, use_alpha=False, is_llm_result=False):
71 | suffix_parts = []
72 |
73 | if is_llm_result:
74 | suffix_parts.append("_llm")
75 | suffix_parts.append("_FIRST" if use_logits else "_gen")
76 | suffix_parts.append("_alpha" if use_alpha else "_num")
77 | else:
78 | suffix_parts.append("_ce")
79 |
80 | suffix = "".join(suffix_parts)
81 |     if output_path.endswith(dataset):
82 |         rerank_path = os.path.join(output_path, f"rerank_{top_k}{suffix}.json")
83 |     else:
84 |         rerank_path = os.path.join(output_path, dataset, f"rerank_{top_k}{suffix}.json")
85 |
86 | os.makedirs(os.path.dirname(rerank_path), exist_ok=True)
87 |
88 | print(f"Saved to: {rerank_path}")
89 | with open(rerank_path, "w") as f:
90 | json.dump(results, f, indent=4)
91 |
92 | def rerank_beir_outputs_llm(model, results_for_rerank, use_logits, use_alpha, top_k, window_size, step_size, batched, context_size, rerank_type="text", code_prompt_type="docstring"):
93 | """
94 | Rerank outputs using either text or code reranking
95 |
96 | Args:
97 | rerank_type (str): Whether to perform "text" or "code" reranking
98 | code_prompt_type (str): For code reranking, whether to use "docstring" or "github_issue" prompts
99 | """
100 | # Validate parameters for code reranking
101 | if rerank_type == "code":
102 | if use_logits or use_alpha:
103 | print("Warning: Code reranking does not support logits or alpha mode. These will be disabled.")
104 | use_logits = False
105 | use_alpha = False
106 |
107 | # Select appropriate system message based on rerank type and prompt type
108 | if rerank_type == "code":
109 | if code_prompt_type == "docstring":
110 | system_message = "You are CodeRanker, an intelligent code reviewer that can analyze doc strings and rank code snippets based on their relevance to the doc string."
111 | elif code_prompt_type == "github_issue":
112 | system_message = "You are CodeRanker, an intelligent code reviewer that can analyze GitHub issues and rank code functions based on their relevance to contain the faults causing the GitHub issue."
113 | else:
114 | raise ValueError(f"Invalid code_prompt_type: {code_prompt_type}")
115 | else: # text reranking
116 | system_message = "You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query"
117 |
118 | # Initialize the ranking model
119 | agent = RankListwiseOSLLM(
120 | model=model,
121 | context_size=context_size,
122 | prompt_mode=PromptMode.RANK_GPT,
123 | num_few_shot_examples=0,
124 | device="cuda",
125 | num_gpus=1,
126 | variable_passages=True,
127 | window_size=window_size,
128 | system_message=system_message,
129 | batched=batched,
130 | rerank_type=rerank_type,
131 | code_prompt_type=code_prompt_type
132 | )
133 |
134 | # Perform reranking
135 | reranker = Reranker(agent=agent)
136 | reranked_results = reranker.rerank(
137 | retrieved_results=results_for_rerank,
138 | use_logits=use_logits,
139 | use_alpha=use_alpha,
140 | rank_start=0,
141 | rank_end=top_k,
142 | window_size=window_size,
143 | step=step_size,
144 | logging=False,
145 | batched=batched
146 | )
147 |
148 | for result in reranked_results:
149 | for rank, hit in enumerate(result.hits, start=1):
150 | hit['rank'] = rank
151 |
152 | return reranked_results
153 |
--------------------------------------------------------------------------------
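For reference, here is what the `rerank_type="code"` branch of `evaluate_results` computes, shown as a tiny standalone example with toy qrels (not data from the repo): MRR@k averages, over queries, the reciprocal rank of the first relevant document within the top-k.

```python
# Toy MRR@k computation mirroring the "code" branch of evaluate_results().
qrels = {"q1": {"d1": 1}, "q2": {"d9": 1}}
rerank_results = {
    "q1": {"d1": 0.9, "d2": 0.5},            # relevant doc at rank 1
    "q2": {"d3": 0.8, "d9": 0.6, "d4": 0.1}, # relevant doc at rank 2
}

k = 10
mrr_sum, num_queries = 0.0, 0
for qid in qrels:
    if qid in rerank_results:
        ranked = sorted(rerank_results[qid].items(), key=lambda x: x[1], reverse=True)[:k]
        for rank, (doc_id, _) in enumerate(ranked, start=1):
            if qrels[qid].get(doc_id, 0) > 0:
                mrr_sum += 1.0 / rank
                break
        num_queries += 1

print(mrr_sum / num_queries)  # (1/1 + 1/2) / 2 = 0.75
```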
/scripts/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch import Tensor
5 | import torch.nn.functional as F
6 | from itertools import product
7 | import numpy as np
8 |
9 | def lambdarank(y_pred, y_true=None, eps=1e-10, padded_value_indicator=-100, weighing_scheme="ndcgLoss2_scheme", k=None,
10 | sigma=1., mu=10., reduction="mean", reduction_log="binary"):
11 | """
12 | LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
13 | Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.
14 | :param y_pred: predictions from the model, shape [batch_size, slate_length]
15 | :param y_true: ground truth labels, shape [batch_size, slate_length]
16 | :param eps: epsilon value, used for numerical stability
17 | :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
18 | :param weighing_scheme: a string corresponding to a name of one of the weighing schemes
19 | :param k: rank at which the loss is truncated
20 | :param sigma: score difference weight used in the sigmoid function
21 | :param mu: optional weight used in NDCGLoss2++ weighing scheme
22 | :param reduction: losses reduction method, could be either a sum or a mean
23 | :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural
24 | :return: loss value, a torch.Tensor
25 | """
26 | if y_true is None:
27 | y_true = torch.zeros_like(y_pred).to(y_pred.device)
28 | y_true[:, 0] = 1
29 |
30 | device = y_pred.device
31 |
32 | # sort the true and predicted relevancy scores.
33 | y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
34 | y_true_sorted, _ = y_true.sort(descending=True, dim=-1)
35 |
36 | # mask out the pairs of indices (i, j) containing index of a padded element.
37 | true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
38 | true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
39 | padded_pairs_mask = torch.isfinite(true_diffs)
40 |
41 | if weighing_scheme != "ndcgLoss1_scheme":
42 | padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)
43 |
44 | ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device)
45 | ndcg_at_k_mask[:k, :k] = 1
46 |
47 | # clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
48 | true_sorted_by_preds.clamp_(min=0.)
49 | y_true_sorted.clamp_(min=0.)
50 |
51 | # find the gains, discounts and ideal DCGs per slate.
52 | pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
53 | D = torch.log2(1. + pos_idxs.float())[None, :]
54 | maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps)
55 | G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]
56 |
57 | # apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0)
58 | if weighing_scheme is None:
59 | weights = 1.
60 | else:
61 | weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds) # type: ignore
62 |
63 | # clamping the array entries to maintain correct backprop (log(0) and division by 0)
64 | scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
65 |     scores_diffs.masked_fill_(torch.isnan(scores_diffs), 0.)  # in-place; masked_fill() would discard the result
66 | weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
67 | if reduction_log == "natural":
68 | losses = torch.log(weighted_probas)
69 | elif reduction_log == "binary":
70 | losses = torch.log2(weighted_probas)
71 | else:
72 | raise ValueError("Reduction logarithm base can be either natural or binary")
73 |
74 | if reduction == "sum":
75 | loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask])
76 | elif reduction == "mean":
77 | loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask])
78 | else:
79 | raise ValueError("Reduction method can be either sum or mean")
80 |
81 | return loss
82 |
83 | def ndcgLoss1_scheme(G, D, *args):
84 | return (G / D)[:, :, None]
85 |
86 |
87 | def ndcgLoss2_scheme(G, D, *args):
88 | pos_idxs = torch.arange(1, G.shape[1] + 1, device=G.device)
89 | delta_idxs = torch.abs(pos_idxs[:, None] - pos_idxs[None, :])
90 | deltas = torch.abs(torch.pow(torch.abs(D[0, delta_idxs - 1]), -1.) - torch.pow(torch.abs(D[0, delta_idxs]), -1.))
91 | deltas.diagonal().zero_()
92 |
93 | return deltas[None, :, :] * torch.abs(G[:, :, None] - G[:, None, :])
94 |
95 | def rank_net(y_pred, y_true, weighted=False, use_rank=False, weight_by_diff=False,
96 | weight_by_diff_powed=False):
97 | """
98 | RankNet loss introduced in "Learning to Rank using Gradient Descent".
99 | :param y_pred: predictions from the model, shape [batch_size, slate_length]
100 | :param y_true: ground truth labels, shape [batch_size, slate_length]
101 | :param weight_by_diff: flag indicating whether to weight the score differences by ground truth differences.
102 | :param weight_by_diff_powed: flag indicating whether to weight the score differences by the squared ground truth differences.
103 | :return: loss value, a torch.Tensor
104 | """
105 |     if use_rank:  # convert graded labels into reciprocal-rank targets
106 | y_true = torch.tensor([[1 / (np.argsort(y_true)[::-1][i] + 1) for i in range(y_pred.size(1))]] * y_pred.size(0)).cuda()
107 |
108 | # generate every pair of indices from the range of document length in the batch
109 | document_pairs_candidates = list(product(range(y_true.shape[1]), repeat=2))
110 |
111 | pairs_true = y_true[:, document_pairs_candidates]
112 | selected_pred = y_pred[:, document_pairs_candidates]
113 |
114 | # calculate the relative true relevance of every candidate pair
115 | true_diffs = pairs_true[:, :, 0] - pairs_true[:, :, 1]
116 | pred_diffs = selected_pred[:, :, 0] - selected_pred[:, :, 1]
117 |
118 |     # filter just the pairs that are 'positive' and did not involve a padded instance;
119 |     # since the candidate pairs are symmetric, we can keep only the positive ones
120 |     # for a simpler loss function formulation
121 | the_mask = (true_diffs > 0) & (~torch.isinf(true_diffs))
122 |
123 | pred_diffs = pred_diffs[the_mask]
124 |
125 | weight = None
126 | if weighted:
127 | values, indices = torch.sort(y_true, descending=True)
128 | ranks = torch.zeros_like(indices)
129 | ranks.scatter_(1, indices, torch.arange(1, y_true.numel() + 1).to(y_true.device).view_as(indices))
130 | pairs_ranks = ranks[:, document_pairs_candidates]
131 | rank_sum = pairs_ranks.sum(-1)
132 | weight = 1/rank_sum[the_mask] #Relevance Feedback
133 | # rank_prod=pairs_ranks[:, :, 0]*pairs_ranks[:, :, 1]
134 | # weight = rank_sum[the_mask]/rank_prod[the_mask]
135 | else:
136 | if weight_by_diff:
137 | abs_diff = torch.abs(true_diffs)
138 | weight = abs_diff[the_mask]
139 | elif weight_by_diff_powed:
140 | true_pow_diffs = torch.pow(pairs_true[:, :, 0], 2) - torch.pow(pairs_true[:, :, 1], 2)
141 | abs_diff = torch.abs(true_pow_diffs)
142 | weight = abs_diff[the_mask]
143 |
144 | # 'binarize' true relevancy diffs since for a pairwise loss we just need to know
145 | # whether one document is better than the other and not about the actual difference in
146 | # their relevancy levels
147 | true_diffs = (true_diffs > 0).type(torch.float32)
148 | true_diffs = true_diffs[the_mask]
149 |
150 | return nn.BCEWithLogitsLoss(weight=weight)(pred_diffs, true_diffs)
151 |
152 |
153 | class ADRMSELoss(nn.Module):
154 | def __init__(self, eps: float = 1e-10):
155 | super().__init__()
156 | self.eps = eps
157 |
158 | def forward(self, scores: torch.Tensor, reranked_target_loss: float = None) -> torch.Tensor:
159 | """
160 | Compute the Approx Discounted Rank MSE (ADR-MSE) loss.
161 | :param scores: Tensor of shape [batch_size, slate_length] containing scores for each passage.
162 | :param reranked_target_loss: An additional parameter that is ignored in the computation.
163 | :return: Scalar tensor representing the ADR-MSE loss.
164 | """
165 | batch_size, slate_length = scores.size()
166 |
167 | # Compute the approximated ranks
168 | softmax_scores = torch.softmax(scores, dim=1)
169 | approx_ranks = torch.cumsum(softmax_scores, dim=1)
170 |
171 | # Compute the actual ranks
172 | sorted_indices = torch.argsort(scores, dim=1, descending=True)
173 | ranks = torch.argsort(sorted_indices, dim=1) + 1 # Convert to 1-based ranks
174 |
175 | # Compute the logarithmic discount
176 | log_discounts = torch.log2(ranks.float() + 1)
177 |
178 | rank_diffs = (ranks.float() - approx_ranks) ** 2
179 |
180 | # Apply the logarithmic discount
181 | discounted_diffs = rank_diffs / log_discounts
182 |
183 | loss = discounted_diffs.mean()
184 | return loss
185 |
186 |
187 | def listNet(y_pred, y_true, eps=1e-10, padded_value_indicator=-1):
188 | y_pred = y_pred.clone()
189 | y_true = y_true.clone()
190 |
191 | mask = y_true == padded_value_indicator
192 | y_pred[mask] = float('-inf')
193 | y_true[mask] = float('-inf')
194 |
195 | preds_smax = F.softmax(y_pred, dim=1)
196 | true_smax = F.softmax(y_true, dim=1)
197 |
198 | preds_smax = preds_smax + eps
199 | preds_log = torch.log(preds_smax)
200 |
201 | return torch.mean(-torch.sum(true_smax * preds_log, dim=1))
202 |
203 | loss_dict = {
204 | "lambdarank": lambdarank,
205 | "ranknet": rank_net,
206 | "listnet_loss": listNet,
207 | "adr_mse_loss": ADRMSELoss()
208 | }
--------------------------------------------------------------------------------
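All three losses share the same `(y_pred, y_true)` slate interface, which is what makes them interchangeable behind `loss_dict`. A usage sketch, assuming `scripts/` is on `PYTHONPATH` so that `utils.loss` is importable:

```python
# Each loss takes model scores and relevance labels for a slate of passages,
# shaped [batch_size, slate_length], and returns a scalar tensor.
import torch
from utils.loss import lambdarank, rank_net, listNet

y_pred = torch.tensor([[2.0, 0.5, 1.0, -0.3]], requires_grad=True)  # model scores
y_true = torch.tensor([[3.0, 0.0, 1.0, 0.0]])                       # graded relevance

print(lambdarank(y_pred, y_true))  # NDCGLoss2-weighted LambdaLoss
print(rank_net(y_pred, y_true))    # pairwise RankNet (BCE over score diffs)
print(listNet(y_pred, y_true))     # listwise cross-entropy between softmaxes

lambdarank(y_pred, y_true).backward()  # gradients flow back into y_pred
```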
/scripts/utils/reranker.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from pathlib import Path
3 | from typing import List
4 | import os
5 | import time
6 |
7 | from tqdm import tqdm
8 |
9 | from utils.rankllm import RankLLM
10 | from utils.result import Result
11 |
12 | class Reranker:
13 | def __init__(self, agent: RankLLM) -> None:
14 | self._agent = agent
15 |
16 | def rerank(
17 | self,
18 | retrieved_results: List[Result],
19 | use_logits: bool = False,
20 | use_alpha: bool = False,
21 | rank_start: int = 0,
22 | rank_end: int = 100,
23 | window_size: int = 20,
24 | step: int = 10,
25 | logging: bool = False,
26 | batched: bool = False
27 | ) -> List[Result]:
28 | """
29 | Reranks a list of retrieved results using the RankLLM agent.
30 |
31 | This function applies a sliding window algorithm to rerank the results.
32 | Each window of results is processed by the RankLLM agent to obtain a new ranking.
33 |
34 | Args:
35 | retrieved_results (List[Result]): The list of results to be reranked.
36 | rank_start (int, optional): The starting rank for processing. Defaults to 0.
37 | rank_end (int, optional): The end rank for processing. Defaults to 100.
38 | window_size (int, optional): The size of each sliding window. Defaults to 20.
39 | step (int, optional): The step size for moving the window. Defaults to 10.
40 | logging (bool, optional): Enables logging of the reranking process. Defaults to False.
41 |
42 | Returns:
43 | List[Result]: A list containing the reranked results.
44 | """
45 | if batched:
46 | return self._agent.sliding_windows_batched(
47 | retrieved_results,
48 | use_logits=use_logits,
49 | use_alpha=use_alpha,
50 | rank_start=max(rank_start, 0),
51 |                 rank_end=min(rank_end, len(retrieved_results[0].hits)),  # TODO: assumes every query has the same number of hits
52 | window_size=window_size,
53 | step=step,
54 | logging=logging,
55 | )
56 |
57 | rerank_results = []
58 | for result in tqdm(retrieved_results):
59 | rerank_result = self._agent.sliding_windows(
60 | result,
61 | use_logits=use_logits,
62 | use_alpha=use_alpha,
63 | rank_start=max(rank_start, 0),
64 | rank_end=min(rank_end, len(result.hits)),
65 | window_size=window_size,
66 | step=step,
67 | logging=logging,
68 | )
69 | rerank_results.append(rerank_result)
70 | return rerank_results
71 |
--------------------------------------------------------------------------------
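To make the sliding-window schedule concrete: windows of `window_size` candidates are reranked from the bottom of the list upward, overlapping by `window_size - step` so that strong candidates can bubble toward the top across passes. A standalone illustration of that schedule (a hypothetical helper sketching the typical RankGPT-style traversal, not the agent's actual loop):

```python
# Enumerate the (start, end) candidate ranges a RankGPT-style sliding window
# visits, moving from the bottom of the ranked list back to the top.
def window_schedule(rank_start=0, rank_end=100, window_size=20, step=10):
    windows = []
    start = max(rank_end - window_size, rank_start)
    end = rank_end
    while True:
        windows.append((start, end))
        if start == rank_start:
            break
        start = max(start - step, rank_start)
        end = start + window_size
    return windows

print(window_schedule(rank_start=0, rank_end=30, window_size=20, step=10))
# -> [(10, 30), (0, 20)]: two overlapping passes over the top 30 candidates
```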
/scripts/utils/result.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any, Dict, List, Optional
3 |
4 |
5 | class RankingExecInfo:
6 | def __init__(
7 | self, prompt, response: str, input_token_count: int, output_token_count: int
8 | ):
9 | self.prompt = prompt
10 | self.response = response
11 | self.input_token_count = input_token_count
12 | self.output_token_count = output_token_count
13 |
14 | def __repr__(self):
15 | return str(self.__dict__)
16 |
17 |
18 | class Result:
19 | def __init__(
20 | self,
21 | query: str,
22 | hits: List[Dict[str, Any]],
23 |         ranking_exec_summary: Optional[List[RankingExecInfo]] = None,
24 | ):
25 | self.query = query
26 | self.hits = hits
27 | self.ranking_exec_summary = ranking_exec_summary
28 |
29 | def __repr__(self):
30 | return str(self.__dict__)
31 |
32 |
33 | class ResultsWriter:
34 | def __init__(self, results: List[Result], append: bool = False):
35 | self._results = results
36 | self._append = append
37 |
38 | def write_in_json_format(self, filename: str):
39 | results = []
40 | for result in self._results:
41 | results.append({"query": result.query, "hits": result.hits})
42 | with open(filename, "a" if self._append else "w") as f:
43 | json.dump(results, f, indent=2)
44 |
45 | class ResultsLoader:
46 | def __init__(self, filename: str):
47 | data = json.load(open(filename, 'r'))
48 | self._results = []
49 | for item in data:
50 | hits = []
51 | for hit in item['hits']:
52 | hits.append({'qid': hit['qid'], 'docid': hit['docid'], 'score': float(hit['score']), 'content': hit['content']})
53 | self._results.append(Result(query=item['query'], hits=hits))
54 |
55 | def get_results(self, with_context: bool):
56 | if with_context:
57 | return self._results
58 | else:
59 | results = dict()
60 | for result in self._results:
61 | query = result.query
62 | hits = result.hits
63 | qid = hits[0]['qid']
64 | results[qid] = dict()
65 | for hit in hits:
66 | pid = hit['docid']
67 | score = hit['score']
68 | results[qid][pid] = score
69 | return results
--------------------------------------------------------------------------------
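A round-trip sketch for these classes, assuming `scripts/` is on `PYTHONPATH` so that `utils.result` is importable (the file name and hit contents below are made up for illustration):

```python
# ResultsWriter serializes a list of Result objects to JSON; ResultsLoader
# reads them back, either with full hit contents or flattened into the
# {qid: {docid: score}} form that the evaluation utilities expect.
from utils.result import Result, ResultsWriter, ResultsLoader

hits = [{"qid": "q1", "docid": "d1", "score": 12.3, "content": "passage text"}]
ResultsWriter([Result(query="what is bm25?", hits=hits)]).write_in_json_format("demo.json")

loader = ResultsLoader("demo.json")
print(loader.get_results(with_context=False))  # {'q1': {'d1': 12.3}}
```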
/scripts/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import json
4 | import argparse
5 | from transformers import SchedulerType, CONFIG_MAPPING, MODEL_MAPPING
6 |
7 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
8 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
9 |
10 | def load_data(data_path):
11 |     """Load a jsonl file into a list of records."""
12 |     data = []
13 |     with open(data_path, 'r') as f:
14 |         for line in f:
15 |             data.append(json.loads(line))
16 |     return data
17 |
18 | # NEFTune: Noisy Embedding
19 | def NEFTune(model, noise_alpha=5):
20 | def noised_embed(orig_embed, noise_alpha):
21 | def new_func(x):
22 | # during training, we add noise to the embedding
23 | # during generation, we don't add noise to the embedding
24 | if model.training:
25 | embed_init = orig_embed(x)
26 | dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
27 | mag_norm = noise_alpha/torch.sqrt(dims)
28 | return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
29 | else:
30 | return orig_embed(x)
31 | return new_func
32 | orig_forward = model.base_model.embed_tokens.forward
33 | model.base_model.embed_tokens.forward = noised_embed(orig_forward, noise_alpha)
34 | return model
35 |
36 |
37 | def parse_args():
38 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
39 |
40 | # parser.add_argument(
41 | # "--train_file", type=str, default=None, help="A csv, txt or a json file containing the training data."
42 | # )
43 | # parser.add_argument(
44 | # "--validation_file", type=str, default=None, help="A csv, txt or a json file containing the validation data."
45 | # )
46 | parser.add_argument(
47 | "--model_name_or_path",
48 | type=str,
49 | help="Path to pretrained model or model identifier from huggingface.co/models.",
50 | required=False,
51 | )
52 | parser.add_argument(
53 | "--config_name",
54 | type=str,
55 | default=None,
56 | help="Pretrained config name or path if not the same as model_name",
57 | )
58 | parser.add_argument(
59 | "--tokenizer_name",
60 | type=str,
61 | default=None,
62 | help="Pretrained tokenizer name or path if not the same as model_name",
63 | )
64 | parser.add_argument(
65 | "--train_dataset_path",
66 | type=str,
67 | required=True,
68 | help="Training dataset path in jsonl format"
69 | )
70 | parser.add_argument(
71 | "--eval_dataset_path",
72 | type=str,
73 | required=True,
74 | help="Validation dataset path in jsonl format"
75 | )
76 | parser.add_argument(
77 | "--beir_data_path",
78 | type=str,
79 | required=True,
80 | help="BEIR(MSMARCO) dataset for validation"
81 | )
82 | parser.add_argument(
83 | "--cache_dir",
84 | type=str,
85 | help="Path to cache"
86 | )
87 | parser.add_argument(
88 | "--per_device_train_batch_size",
89 | type=int,
90 | default=8,
91 | help="Batch size (per device) for the training dataloader.",
92 | )
93 | parser.add_argument(
94 | "--per_device_eval_batch_size",
95 | type=int,
96 | default=8,
97 | help="Batch size (per device) for the evaluation dataloader.",
98 | )
99 | parser.add_argument(
100 | "--learning_rate",
101 | type=float,
102 | default=5e-6,
103 | help="Initial learning rate (after the potential warmup period) to use.",
104 | )
105 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
106 | parser.add_argument("--noisy_embedding_alpha", type=int, default=None, help="NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings")
107 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
108 | parser.add_argument(
109 | "--max_train_steps",
110 | type=int,
111 | default=None,
112 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
113 | )
114 | parser.add_argument(
115 | "--gradient_accumulation_steps",
116 | type=int,
117 | default=1,
118 | help="Number of updates steps to accumulate before performing a backward/update pass.",
119 | )
120 | parser.add_argument(
121 | "--ranking_loss",
122 | type=str,
123 | default="lambda",
124 | help="Ranking loss to use",
125 | choices=["lambda", "listnet", "ranknet"]
126 | )
127 | parser.add_argument(
128 | "--weighted", action="store_true", help="Use weighting with Ranknet"
129 | )
130 | parser.add_argument(
131 | "--objective",
132 | type=str,
133 | default="generation",
134 | help="Training objective for reranker training",
135 | choices=["ranking", "generation", "combined"]
136 | )
137 | parser.add_argument(
138 | "--lr_scheduler_type",
139 | type=SchedulerType,
140 | default="linear",
141 | help="The scheduler type to use.",
142 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
143 | )
144 | parser.add_argument(
145 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
146 | )
147 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
148 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
149 | parser.add_argument(
150 | "--model_type",
151 | type=str,
152 | default=None,
153 | help="Model type to use if training from scratch.",
154 | choices=MODEL_TYPES,
155 | )
156 |
157 | parser.add_argument(
158 | "--preprocessing_num_workers",
159 | type=int,
160 | default=None,
161 | help="The number of processes to use for the preprocessing.",
162 | )
163 | parser.add_argument(
164 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
165 | )
166 | parser.add_argument(
167 | "--trust_remote_code",
168 | type=bool,
169 | default=True,
170 | help=(
171 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
172 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
173 | "execute code present on the Hub on your local machine."
174 | ),
175 | )
176 | parser.add_argument(
177 | "--checkpointing_steps",
178 | type=str,
179 | default="epoch",
180 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
181 | )
182 | parser.add_argument(
183 | "--eval_steps",
184 | type=int,
185 | default=100,
186 | help="Evaluate the model at the end of every n steps",
187 | )
188 | parser.add_argument(
189 | "--resume_from_checkpoint",
190 | type=str,
191 | default=None,
192 | help="If the training should continue from a checkpoint folder.",
193 | )
194 | parser.add_argument(
195 | "--with_tracking",
196 | action="store_true",
197 | help="Whether to enable experiment trackers for logging.",
198 | )
199 | parser.add_argument(
200 | "--gradient_checkpointing",
201 | action="store_true",
202 | help="Whether to use gradient checkpointing.",
203 | )
204 | parser.add_argument(
205 | "--report_to",
206 | type=str,
207 | default="all",
208 | help=(
209 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
210 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
211 | "Only applicable when `--with_tracking` is passed."
212 | ),
213 | )
214 | parser.add_argument(
215 | "--low_cpu_mem_usage",
216 | action="store_true",
217 | help=(
218 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
219 | "If passed, LLM loading time and RAM consumption will be benefited."
220 | ),
221 | )
222 | args = parser.parse_args()
223 | assert args.output_dir is not None
224 | return args
--------------------------------------------------------------------------------
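A usage sketch for `NEFTune` above. It assumes a causal LM whose input embeddings live at `model.base_model.embed_tokens` (true of most Llama/Mistral-style Huggingface models); the checkpoint name is just an example:

```python
# During training, NEFTune adds uniform noise of magnitude
# noise_alpha / sqrt(seq_len * hidden_dim) to the token embeddings;
# in eval mode (model.training == False) the original embeddings are used.
from transformers import AutoModelForCausalLM
from utils.train_utils import NEFTune

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
model = NEFTune(model, noise_alpha=5)  # 5 is the paper's default

model.train()  # forward passes now see noised embeddings
model.eval()   # noise disabled; embeddings pass through unchanged
```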
/tevatron/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | *.egg-info/
--------------------------------------------------------------------------------
/tevatron/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/tevatron/README.md:
--------------------------------------------------------------------------------
1 | # Tevatron
2 | Tevatron is a simple and efficient toolkit for training and running dense retrievers with deep language models.
3 | The toolkit has a modularized design for easy research; a set of command line tools is also provided for fast
4 | development and testing. Easy-to-use interfaces to Huggingface's state-of-the-art pre-trained transformers
5 | give Tevatron strong out-of-the-box performance.
6 |
7 | *Tevatron is currently in its initial development stage. We will be actively adding new features, and API changes
8 | may happen. Suggestions, feature requests and PRs are welcome.*
9 |
10 | ## Features
11 | - Command line interface for dense retriever training/encoding and dense index search.
12 | - Flexible and extendable Pytorch retriever models.
13 | - Highly efficient Trainer, a subclass of the Huggingface Trainer, that natively supports training performance features like mixed precision and distributed data parallel.
14 | - Fast and memory-efficient train/inference data access based on memory mapping with Apache Arrow through Huggingface datasets.
15 | - Jax/Flax training/encoding on TPU
16 |
17 | ## Installation
18 | First install neural network and similarity search backends,
19 | namely Pytorch (or Jax) and FAISS.
20 | Check out the official installation guides for [Pytorch](https://pytorch.org/get-started/locally/#start-locally)
21 | , [Jax](https://github.com/google/jax) / [Flax](https://flax.readthedocs.io/en/latest/installation.html)
22 | and [FAISS](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md) accordingly.
23 |
24 | Then install Tevatron with pip,
25 | ```bash
26 | pip install tevatron
27 | ```
28 |
29 | Or typically for development and research, clone this repo and install as editable,
30 | ```
31 | git clone https://github.com/texttron/tevatron
32 | cd tevatron
33 | pip install --editable .
34 | ```
35 |
36 | > Note: The current code base has been tested with `torch==1.10.1`, `faiss-cpu==1.7.2`, `transformers==4.15.0`, `datasets==1.17.0`
37 |
38 | Optionally, you can also install GradCache to enable the gradient cache feature during training:
39 | ```bash
40 | git clone https://github.com/luyug/GradCache
41 | cd GradCache
42 | pip install .
43 | ```
44 |
45 | ## Documentation
46 | - [**Please view the documentation here**](http://tevatron.ai/)
47 |
48 |
49 | ## Examples
50 | In the `/examples` folder, we provide full pipeline instructions for various IR/QA tasks.
51 |
52 | ## Citation
53 | If you find Tevatron helpful, please consider citing our [paper](https://arxiv.org/abs/2203.05765).
54 | ```
55 | @article{Gao2022TevatronAE,
56 | title={Tevatron: An Efficient and Flexible Toolkit for Dense Retrieval},
57 | author={Luyu Gao and Xueguang Ma and Jimmy J. Lin and Jamie Callan},
58 | journal={ArXiv},
59 | year={2022},
60 | volume={abs/2203.05765}
61 | }
62 | ```
63 |
64 | ## Contacts
65 | If you have a toolkit specific question, feel free to open an issue.
66 |
67 | You can also reach out to us for general comments/suggestions/questions through email.
68 | - Luyu Gao luyug@cs.cmu.edu
69 | - Xueguang Ma x93ma@uwaterloo.ca
70 |
--------------------------------------------------------------------------------
/tevatron/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Tevatron
2 | nav:
3 | - Home: index.md
4 | - Datasets: datasets.md
5 | - Training: training.md
6 | - Encoding: encoding.md
7 | - Retrieval: retrieval.md
8 |
9 | theme:
10 | name: 'material'
11 | palette:
12 | primary: 'green'
13 | accent: 'green'
14 |
15 | markdown_extensions:
16 | - pymdownx.highlight
17 | - pymdownx.superfences
18 | - toc:
19 | permalink: true
20 | - attr_list
21 |
22 | repo_name: 'texttron/tevatron'
23 | repo_url: 'https://github.com/texttron/tevatron'
--------------------------------------------------------------------------------
/tevatron/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='tevatron',
5 | version='0.0.1',
6 | packages=find_packages("src"),
7 | package_dir={'': 'src'},
8 | url='https://github.com/texttron/tevatron',
9 | license='Apache 2.0',
10 | author='Luyu Gao',
11 | author_email='luyug@cs.cmu.edu',
12 | description='Tevatron: A toolkit for learning and running deep dense retrieval models.',
13 | python_requires='>=3.7',
14 | install_requires=[
15 | "torch==2.0.1",
16 | "transformers==4.30.2",
17 | "datasets==2.13.1",
18 | "accelerate==0.20.3",
19 | "faiss-cpu==1.7.2",
20 | "sentencepiece==0.1.99",
21 | "tokenizers==0.13.3"
22 | ]
23 | )
24 |
--------------------------------------------------------------------------------
/tevatron/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gangiswag/llm-reranker/2d7cba423ad555064bdfc719313570b5f9525887/tevatron/src/.DS_Store
--------------------------------------------------------------------------------
/tevatron/src/tevatron/__init__.py:
--------------------------------------------------------------------------------
1 | from .faiss_retriever import BaseFaissIPRetriever
2 | from . import utils
--------------------------------------------------------------------------------
/tevatron/src/tevatron/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from typing import Optional, List
4 | from transformers import TrainingArguments
5 |
6 |
7 | @dataclass
8 | class ModelArguments:
9 | model_name_or_path: str = field(
10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
11 | )
12 | target_model_path: str = field(
13 | default=None,
14 | metadata={"help": "Path to pretrained reranker target model"}
15 | )
16 | config_name: Optional[str] = field(
17 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
18 | )
19 | tokenizer_name: Optional[str] = field(
20 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
21 | )
22 | cache_dir: Optional[str] = field(
23 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
24 | )
25 |
26 | # modeling
27 | untie_encoder: bool = field(
28 | default=False,
29 | metadata={"help": "no weight sharing between qry passage encoders"}
30 | )
31 |
32 | # out projection
33 | add_pooler: bool = field(default=False)
34 | projection_in_dim: int = field(default=768)
35 | projection_out_dim: int = field(default=768)
36 | use_mean_pooling: bool = field(default=False)
37 |
38 | # for Jax training
39 | dtype: Optional[str] = field(
40 | default="float32",
41 | metadata={
42 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one "
43 | "of `[float32, float16, bfloat16]`. "
44 | },
45 | )
46 |
47 |
48 | @dataclass
49 | class DataArguments:
50 | train_dir: str = field(
51 | default=None, metadata={"help": "Path to train directory"}
52 | )
53 | dataset_name: str = field(
54 | default=None, metadata={"help": "huggingface dataset name"}
55 | )
56 | passage_field_separator: str = field(default=' ')
57 | dataset_proc_num: int = field(
58 | default=12, metadata={"help": "number of proc used in dataset preprocess"}
59 | )
60 | train_n_passages: int = field(default=8)
61 | positive_passage_no_shuffle: bool = field(
62 | default=False, metadata={"help": "always use the first positive passage"})
63 | negative_passage_no_shuffle: bool = field(
64 | default=False, metadata={"help": "always use the first negative passages"})
65 |
66 | encode_in_path: List[str] = field(default=None, metadata={"help": "Path to data to encode"})
67 |     encoded_save_path: str = field(default=None, metadata={"help": "where to save the encoded representations"})
68 | encode_is_qry: bool = field(default=False)
69 | encode_num_shard: int = field(default=1)
70 | encode_shard_index: int = field(default=0)
71 |
72 | q_max_len: int = field(
73 | default=32,
74 | metadata={
75 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
76 | "than this will be truncated, sequences shorter will be padded."
77 | },
78 | )
79 | p_max_len: int = field(
80 | default=128,
81 | metadata={
82 | "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
83 | "than this will be truncated, sequences shorter will be padded."
84 | },
85 | )
86 | data_cache_dir: Optional[str] = field(
87 | default=None, metadata={"help": "Where do you want to store the data downloaded from huggingface"}
88 | )
89 | normalize_text: bool = field(default=False, metadata={"help": "Whether to normalize text"})
90 |
91 | lower_case: bool = field(default=False, metadata={"help": "Whether to lower case text"})
92 |
93 |
94 | def __post_init__(self):
95 | if self.dataset_name is not None:
96 | info = self.dataset_name.split('/')
97 | self.dataset_split = info[-1] if len(info) == 3 else 'train'
98 | self.dataset_name = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)
99 | self.dataset_language = 'default'
100 | if ':' in self.dataset_name:
101 | self.dataset_name, self.dataset_language = self.dataset_name.split(':')
102 | else:
103 | self.dataset_name = 'json'
104 | self.dataset_split = 'train'
105 | self.dataset_language = 'default'
106 | if self.train_dir is not None:
107 | files = os.listdir(self.train_dir)
108 | self.train_path = [
109 | os.path.join(self.train_dir, f)
110 | for f in files
111 | if f.endswith('jsonl') or f.endswith('json')
112 | ]
113 | else:
114 | self.train_path = None
115 |
116 |
117 | @dataclass
118 | class TevatronTrainingArguments(TrainingArguments):
119 | warmup_ratio: float = field(default=0.1)
120 | negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
121 | do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
122 |
123 | grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"})
124 | gc_q_chunk_size: int = field(default=4)
125 | gc_p_chunk_size: int = field(default=32)
126 |
--------------------------------------------------------------------------------
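The `__post_init__` hook above unpacks the split and language that may be packed into `dataset_name`. A quick illustration of how the string is parsed (runnable from a checkout with tevatron installed):

```python
# dataset_name may carry a split ("name/split") and a language ("name:lang");
# __post_init__ separates them into dataset_split and dataset_language.
from tevatron.arguments import DataArguments

args = DataArguments(dataset_name="Tevatron/msmarco-passage/dev")
print(args.dataset_name, args.dataset_split, args.dataset_language)
# -> Tevatron/msmarco-passage dev default

args = DataArguments(dataset_name="Tevatron/wikipedia-nq:en")
print(args.dataset_name, args.dataset_split, args.dataset_language)
# -> Tevatron/wikipedia-nq train en
```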
/tevatron/src/tevatron/data.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass
3 | from typing import List, Tuple
4 |
5 | import datasets
6 | from torch.utils.data import Dataset
7 | from transformers import PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding
8 |
9 |
10 | from .arguments import DataArguments
11 | from .trainer import TevatronTrainer
12 |
13 | import logging
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class TrainDataset(Dataset):
18 | def __init__(
19 | self,
20 | data_args: DataArguments,
21 | dataset: datasets.Dataset,
22 | tokenizer: PreTrainedTokenizer,
23 | trainer: TevatronTrainer = None,
24 | ):
25 | self.train_data = dataset
26 | self.tok = tokenizer
27 | self.trainer = trainer
28 |
29 | self.data_args = data_args
30 | self.total_len = len(self.train_data)
31 |
32 | def create_one_example(self, text_encoding: List[int], is_query=False):
33 | item = self.tok.encode_plus(
34 | text_encoding,
35 | truncation='only_first',
36 | max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len,
37 | padding=False,
38 | return_attention_mask=False,
39 | return_token_type_ids=False,
40 | )
41 | return item
42 |
43 | def __len__(self):
44 | return self.total_len
45 |
46 | def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
47 | group = self.train_data[item]
48 | epoch = int(self.trainer.state.epoch)
49 |
50 | _hashed_seed = hash(item + self.trainer.args.seed)
51 |
52 | qry = group['query']
53 | encoded_query = self.create_one_example(qry, is_query=True)
54 |
55 | encoded_passages = []
56 | group_positives = group['positives']
57 | group_negatives = group['negatives']
58 |
59 | if self.data_args.positive_passage_no_shuffle:
60 | pos_psg = group_positives[0]
61 | else:
62 | pos_psg = group_positives[(_hashed_seed + epoch) % len(group_positives)]
63 | encoded_passages.append(self.create_one_example(pos_psg))
64 |
65 | negative_size = self.data_args.train_n_passages - 1
66 | if len(group_negatives) < negative_size:
67 | negs = random.choices(group_negatives, k=negative_size)
68 | elif self.data_args.train_n_passages == 1:
69 | negs = []
70 | elif self.data_args.negative_passage_no_shuffle:
71 | negs = group_negatives[:negative_size]
72 | else:
73 | _offset = epoch * negative_size % len(group_negatives)
74 | negs = [x for x in group_negatives]
75 | random.Random(_hashed_seed).shuffle(negs)
76 | negs = negs * 2
77 | negs = negs[_offset: _offset + negative_size]
78 |
79 | for neg_psg in negs:
80 | encoded_passages.append(self.create_one_example(neg_psg))
81 |
82 | return encoded_query, encoded_passages
83 |
84 |
85 | class EncodeDataset(Dataset):
86 | input_keys = ['text_id', 'text']
87 |
88 | def __init__(self, dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer, max_len=128):
89 | self.encode_data = dataset
90 | self.tok = tokenizer
91 | self.max_len = max_len
92 |
93 | def __len__(self):
94 | return len(self.encode_data)
95 |
96 | def __getitem__(self, item) -> Tuple[str, BatchEncoding]:
97 | text_id, text = (self.encode_data[item][f] for f in self.input_keys)
98 | encoded_text = self.tok.encode_plus(
99 | text,
100 | max_length=self.max_len,
101 | truncation='only_first',
102 | padding=False,
103 | return_token_type_ids=False,
104 | )
105 | return text_id, encoded_text
106 |
107 |
108 | @dataclass
109 | class QPCollator(DataCollatorWithPadding):
110 | """
111 |     Wrapper that converts a List[Tuple[encode_qry, encode_psg]] into List[qry], List[psg]
112 |     and passes each batch separately to the actual collator.
113 |     Abstracts the data details away from the model.
114 | """
115 | max_q_len: int = 32
116 | max_p_len: int = 128
117 |
118 | def __call__(self, features):
119 | qq = [f[0] for f in features]
120 | dd = [f[1] for f in features]
121 |
122 | if isinstance(qq[0], list):
123 | qq = sum(qq, [])
124 | if isinstance(dd[0], list):
125 | dd = sum(dd, [])
126 |
127 | q_collated = self.tokenizer.pad(
128 | qq,
129 | padding='max_length',
130 | max_length=self.max_q_len,
131 | return_tensors="pt",
132 | )
133 | d_collated = self.tokenizer.pad(
134 | dd,
135 | padding='max_length',
136 | max_length=self.max_p_len,
137 | return_tensors="pt",
138 | )
139 |
140 | return q_collated, d_collated
141 |
142 |
143 | @dataclass
144 | class EncodeCollator(DataCollatorWithPadding):
145 | def __call__(self, features):
146 | text_ids = [x[0] for x in features]
147 | text_features = [x[1] for x in features]
148 | collated_features = super().__call__(text_features)
149 | return text_ids, collated_features
--------------------------------------------------------------------------------
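The negative-sampling logic in `TrainDataset.__getitem__` is worth spelling out: negatives are shuffled once with a per-example seed, doubled so a window can wrap around the end of the list, and sliced at an offset that advances with the epoch. A toy standalone illustration (made-up data):

```python
# Each epoch sees a different window of negatives for the same example,
# while the shuffle itself stays deterministic per (item, seed).
import random

group_negatives = ["n0", "n1", "n2", "n3", "n4"]
negative_size = 3            # train_n_passages - 1
_hashed_seed = hash(7 + 42)  # hash(item_index + seed)

for epoch in range(3):
    _offset = epoch * negative_size % len(group_negatives)
    negs = list(group_negatives)
    random.Random(_hashed_seed).shuffle(negs)  # same order every epoch
    negs = (negs * 2)[_offset: _offset + negative_size]
    print(epoch, negs)
```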
/tevatron/src/tevatron/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import HFTrainDataset, HFQueryDataset, HFCorpusDataset
2 | from .preprocessor import TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor
3 |
--------------------------------------------------------------------------------
/tevatron/src/tevatron/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from transformers import PreTrainedTokenizer
3 | from .preprocessor import TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor
4 | from ..arguments import DataArguments
5 |
6 | DEFAULT_PROCESSORS = [TrainPreProcessor, QueryPreProcessor, CorpusPreProcessor]
7 | PROCESSOR_INFO = {
8 | 'Tevatron/wikipedia-nq': DEFAULT_PROCESSORS,
9 | 'Tevatron/wikipedia-trivia': DEFAULT_PROCESSORS,
10 | 'Tevatron/wikipedia-curated': DEFAULT_PROCESSORS,
11 | 'Tevatron/wikipedia-wq': DEFAULT_PROCESSORS,
12 | 'Tevatron/wikipedia-squad': DEFAULT_PROCESSORS,
13 | 'Tevatron/scifact': DEFAULT_PROCESSORS,
14 | 'Tevatron/msmarco-passage': DEFAULT_PROCESSORS,
15 | 'json': [None, None, None]
16 | }
17 |
18 |
19 | class HFTrainDataset:
20 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str):
21 | data_files = data_args.train_path
22 | if data_files:
23 | data_files = {data_args.dataset_split: data_files}
24 | self.dataset = load_dataset(data_args.dataset_name,
25 | data_args.dataset_language,
26 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split]
27 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][0] if data_args.dataset_name in PROCESSOR_INFO\
28 | else DEFAULT_PROCESSORS[0]
29 | self.tokenizer = tokenizer
30 | self.q_max_len = data_args.q_max_len
31 | self.p_max_len = data_args.p_max_len
32 | self.proc_num = data_args.dataset_proc_num
33 | self.neg_num = data_args.train_n_passages - 1
34 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator)
35 |
36 | def process(self, shard_num=1, shard_idx=0):
37 | self.dataset = self.dataset.shard(shard_num, shard_idx)
38 | print("Processing Dataset")
39 | if self.preprocessor is not None:
40 | self.dataset = self.dataset.map(
41 | self.preprocessor(self.tokenizer, self.q_max_len, self.p_max_len, self.separator),
42 | batched=False,
43 | num_proc=self.proc_num,
44 | remove_columns=self.dataset.column_names,
45 | desc="Running tokenizer on train dataset",
46 | )
47 | return self.dataset
48 |
49 |
50 | class HFQueryDataset:
51 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str):
52 | data_files = data_args.encode_in_path
53 | if data_files:
54 | data_files = {data_args.dataset_split: data_files}
55 | self.dataset = load_dataset(data_args.dataset_name,
56 | data_args.dataset_language,
57 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split]
58 | self.preprocessor = PROCESSOR_INFO[data_args.dataset_name][1] if data_args.dataset_name in PROCESSOR_INFO \
59 | else DEFAULT_PROCESSORS[1]
60 | self.tokenizer = tokenizer
61 | self.q_max_len = data_args.q_max_len
62 | self.proc_num = data_args.dataset_proc_num
63 |
64 | def process(self, shard_num=1, shard_idx=0):
65 | self.dataset = self.dataset.shard(shard_num, shard_idx)
66 | if self.preprocessor is not None:
67 | self.dataset = self.dataset.map(
68 | self.preprocessor(self.tokenizer, self.q_max_len),
69 | batched=False,
70 | num_proc=self.proc_num,
71 | remove_columns=self.dataset.column_names,
72 | desc="Running tokenization",
73 | )
74 | return self.dataset
75 |
76 |
77 | class HFCorpusDataset:
78 | def __init__(self, tokenizer: PreTrainedTokenizer, data_args: DataArguments, cache_dir: str):
79 | data_files = data_args.encode_in_path
80 | if data_files:
81 | data_files = {data_args.dataset_split: data_files}
82 | self.dataset = load_dataset(data_args.dataset_name,
83 | data_args.dataset_language,
84 | data_files=data_files, cache_dir=cache_dir)[data_args.dataset_split]
85 | script_prefix = data_args.dataset_name
86 | if script_prefix.endswith('-corpus'):
87 | script_prefix = script_prefix[:-7]
88 | self.preprocessor = PROCESSOR_INFO[script_prefix][2] \
89 | if script_prefix in PROCESSOR_INFO else DEFAULT_PROCESSORS[2]
90 | self.tokenizer = tokenizer
91 | self.p_max_len = data_args.p_max_len
92 | self.proc_num = data_args.dataset_proc_num
93 | self.separator = getattr(self.tokenizer, data_args.passage_field_separator, data_args.passage_field_separator)
94 |
95 | def process(self, shard_num=1, shard_idx=0):
96 | self.dataset = self.dataset.shard(shard_num, shard_idx)
97 | if self.preprocessor is not None:
98 | self.dataset = self.dataset.map(
99 | self.preprocessor(self.tokenizer, self.p_max_len, self.separator),
100 | batched=False,
101 | num_proc=self.proc_num,
102 | remove_columns=self.dataset.column_names,
103 | desc="Running tokenization",
104 | )
105 | return self.dataset
106 |
--------------------------------------------------------------------------------
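The PROCESSOR_INFO lookup above decides whether raw examples get tokenized at load time; a small sketch of that selection logic, using values from the table above:

from tevatron.datasets.dataset import PROCESSOR_INFO, DEFAULT_PROCESSORS

name = "Tevatron/msmarco-passage"
train_proc = PROCESSOR_INFO[name][0] if name in PROCESSOR_INFO else DEFAULT_PROCESSORS[0]
# train_proc is TrainPreProcessor here; for name == "json" the entry is None,
# signalling that the input JSON lines are already tokenized and used as-is.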
/tevatron/src/tevatron/datasets/preprocessor.py:
--------------------------------------------------------------------------------
1 | class TrainPreProcessor:
2 | def __init__(self, tokenizer, query_max_length=32, text_max_length=256, separator=' '):
3 | self.tokenizer = tokenizer
4 | self.query_max_length = query_max_length
5 | self.text_max_length = text_max_length
6 | self.separator = separator
7 |
8 | def __call__(self, example):
9 | query = self.tokenizer.encode(example['query'],
10 | add_special_tokens=False,
11 | max_length=self.query_max_length,
12 | truncation=True)
13 | positives = []
14 | for pos in example['positive_passages']:
15 | text = pos['title'] + self.separator + pos['text'] if 'title' in pos else pos['text']
16 | positives.append(self.tokenizer.encode(text,
17 | add_special_tokens=False,
18 | max_length=self.text_max_length,
19 | truncation=True))
20 | negatives = []
21 | for neg in example['negative_passages']:
22 | text = neg['title'] + self.separator + neg['text'] if 'title' in neg else neg['text']
23 | negatives.append(self.tokenizer.encode(text,
24 | add_special_tokens=False,
25 | max_length=self.text_max_length,
26 | truncation=True))
27 | return {'query': query, 'positives': positives, 'negatives': negatives}
28 |
29 |
30 | class QueryPreProcessor:
31 | def __init__(self, tokenizer, query_max_length=32):
32 | self.tokenizer = tokenizer
33 | self.query_max_length = query_max_length
34 |
35 | def __call__(self, example):
36 | query_id = example['query_id']
37 | query = self.tokenizer.encode(example['query'],
38 | add_special_tokens=False,
39 | max_length=self.query_max_length,
40 | truncation=True)
41 | return {'text_id': query_id, 'text': query}
42 |
43 |
44 | class CorpusPreProcessor:
45 | def __init__(self, tokenizer, text_max_length=256, separator=' '):
46 | self.tokenizer = tokenizer
47 | self.text_max_length = text_max_length
48 | self.separator = separator
49 |
50 | def __call__(self, example):
51 | docid = example['docid']
52 | text = example['title'] + self.separator + example['text'] if 'title' in example else example['text']
53 | text = self.tokenizer.encode(text,
54 | add_special_tokens=False,
55 | max_length=self.text_max_length,
56 | truncation=True)
57 | return {'text_id': docid, 'text': text}
58 |
--------------------------------------------------------------------------------
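For concreteness, a sketch of what TrainPreProcessor emits for one raw example (the tokenizer choice and texts are assumptions):

from transformers import AutoTokenizer
from tevatron.datasets.preprocessor import TrainPreProcessor

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
proc = TrainPreProcessor(tok, query_max_length=32, text_max_length=256)
example = {
    "query": "who wrote faust",
    "positive_passages": [{"title": "Faust", "text": "Goethe wrote Faust."}],
    "negative_passages": [{"text": "An unrelated passage."}],
}
out = proc(example)
# out == {"query": [...], "positives": [[...]], "negatives": [[...]]}, all token id lists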
/tevatron/src/tevatron/driver/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gangiswag/llm-reranker/2d7cba423ad555064bdfc719313570b5f9525887/tevatron/src/tevatron/driver/__init__.py
--------------------------------------------------------------------------------
/tevatron/src/tevatron/driver/encode.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import sys
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | import torch
11 |
12 | from torch.utils.data import DataLoader
13 | from transformers import AutoConfig, AutoTokenizer
14 | from transformers import (
15 | HfArgumentParser,
16 | )
17 |
18 | from tevatron.arguments import ModelArguments, DataArguments, \
19 | TevatronTrainingArguments as TrainingArguments
20 | from tevatron.data import EncodeDataset, EncodeCollator
21 | from tevatron.modeling import EncoderOutput, DenseModel
22 | from tevatron.datasets import HFQueryDataset, HFCorpusDataset
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def main():
28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
31 | else:
32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
33 | model_args: ModelArguments
34 | data_args: DataArguments
35 | training_args: TrainingArguments
36 |
37 | if training_args.local_rank > 0 or training_args.n_gpu > 1:
38 | raise NotImplementedError('Multi-GPU encoding is not supported.')
39 |
40 | # Setup logging
41 | logging.basicConfig(
42 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
43 | datefmt="%m/%d/%Y %H:%M:%S",
44 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
45 | )
46 |
47 | num_labels = 1
48 | config = AutoConfig.from_pretrained(
49 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
50 | num_labels=num_labels,
51 | cache_dir=model_args.cache_dir,
52 | )
53 | tokenizer = AutoTokenizer.from_pretrained(
54 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
55 | cache_dir=model_args.cache_dir,
56 | use_fast=False,
57 | )
58 |
59 | model = DenseModel.load(
60 | model_name_or_path=model_args.model_name_or_path,
61 | config=config,
62 | cache_dir=model_args.cache_dir,
63 | )
64 |
65 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len
66 | if data_args.encode_is_qry:
67 | encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args,
68 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
69 | else:
70 | encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args,
71 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
72 | encode_dataset = EncodeDataset(encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index),
73 | tokenizer, max_len=text_max_length)
74 |
75 | encode_loader = DataLoader(
76 | encode_dataset,
77 | batch_size=training_args.per_device_eval_batch_size,
78 | collate_fn=EncodeCollator(
79 | tokenizer,
80 | max_length=text_max_length,
81 | padding='max_length'
82 | ),
83 | shuffle=False,
84 | drop_last=False,
85 | num_workers=training_args.dataloader_num_workers,
86 | )
87 | encoded = []
88 | lookup_indices = []
89 | model = model.to(training_args.device)
90 | model.eval()
91 |
92 | for (batch_ids, batch) in tqdm(encode_loader):
93 | lookup_indices.extend(batch_ids)
94 | with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext():
95 | with torch.no_grad():
96 | for k, v in batch.items():
97 | batch[k] = v.to(training_args.device)
98 | if data_args.encode_is_qry:
99 | model_output: EncoderOutput = model(query=batch)
100 | encoded.append(model_output.q_reps.cpu().detach().numpy())
101 | else:
102 | model_output: EncoderOutput = model(passage=batch)
103 | encoded.append(model_output.p_reps.cpu().detach().numpy())
104 |
105 | encoded = np.concatenate(encoded)
106 |
107 | with open(data_args.encoded_save_path, 'wb') as f:
108 | pickle.dump((encoded, lookup_indices), f)
109 |
110 |
111 | if __name__ == "__main__":
112 | main()
113 |
--------------------------------------------------------------------------------
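The embeddings pickled to --encoded_save_path can be read back as below (the path is hypothetical):

import pickle
import numpy as np

with open("corpus_emb.pkl", "rb") as f:
    reps, lookup = pickle.load(f)  # (num_texts, dim) float array and a parallel list of text ids
assert isinstance(reps, np.ndarray) and len(reps) == len(lookup)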
/tevatron/src/tevatron/driver/encode_new.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import sys
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | from tqdm import tqdm
9 | from copy import deepcopy
10 | import torch
11 | import json
12 |
13 | from torch.utils.data import DataLoader
14 | from transformers import AutoConfig, AutoTokenizer
15 | from transformers import (
16 | HfArgumentParser,
17 | )
18 | import csv
19 |
20 | from tevatron.arguments import ModelArguments, DataArguments, \
21 | TevatronTrainingArguments as TrainingArguments
22 | from tevatron.data import EncodeDataset, EncodeCollator
23 | from tevatron.modeling import EncoderOutput, DenseModel
24 | from tevatron.utils.normalize_text import normalize
25 | logger = logging.getLogger(__name__)
26 |
27 | def main():
28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
31 | else:
32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
33 | model_args: ModelArguments
34 | data_args: DataArguments
35 | training_args: TrainingArguments
36 |
37 | if training_args.local_rank > 0 or training_args.n_gpu > 1:
38 | raise NotImplementedError('Multi-GPU encoding is not supported.')
39 |
40 | # Setup logging
41 | logging.basicConfig(
42 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
43 | datefmt="%m/%d/%Y %H:%M:%S",
44 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
45 | )
46 |
47 | num_labels = 1
48 | config = AutoConfig.from_pretrained(
49 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
50 | num_labels=num_labels,
51 | cache_dir=model_args.cache_dir,
52 | )
53 | tokenizer = AutoTokenizer.from_pretrained(
54 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
55 | cache_dir=model_args.cache_dir,
56 | use_fast=False,
57 | )
58 | model = DenseModel.load(
59 | model_name_or_path=model_args.model_name_or_path,
60 | config=config,
61 | cache_dir=model_args.cache_dir,
62 | )
63 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len
64 |
65 | encoded = []
66 | lookup_indices = []
67 | model = model.to(training_args.device)
68 | model.eval()
69 | corpus = list()
70 |
71 | if data_args.encode_is_qry:
72 | tsv_reader = csv.reader(open(data_args.encode_in_path[0]), delimiter="\t")
73 | for row in tsv_reader:
74 | corpus.append({"id": row[0], "text": row[1]})
75 | corpus_items = [(c["id"], c["text"]) for c in corpus]
76 |
77 | else:
78 | corpus_lines = open(data_args.encode_in_path[0]).readlines()
79 | for line in corpus_lines:
80 | corpus.append(deepcopy(json.loads(line)))
81 |
82 | corpus_items = [(c["id"], c["title"] + " " + c["text"]) for c in corpus]
83 |
84 | if data_args.normalize_text:
85 | corpus_items = [(c[0], normalize(c[1])) for c in corpus_items]
86 |
87 | if data_args.lower_case:
88 | corpus_items = [(c[0], c[1].lower()) for c in corpus_items]
89 |
90 | corpus_items = sorted(corpus_items, key=lambda item: len(item[1]), reverse=True)  # longest first, so each batch pads to similar lengths
91 |
92 | batch_size = training_args.per_device_eval_batch_size
93 | nbatch = (len(corpus_items) + batch_size - 1) // batch_size  # ceiling division; avoids an empty trailing batch
94 |
95 | with torch.no_grad():
96 | for k in tqdm(range(nbatch)):
97 | try:
98 | start_idx = k * batch_size
99 | end_idx = min((k + 1) * batch_size, len(corpus_items))
100 |
101 | batch_items = corpus_items[start_idx:end_idx]
102 | batch_ids = [item[0] for item in batch_items]
103 | batch_text = [item[1] for item in batch_items]
104 |
105 | cencode = tokenizer(
106 | batch_text,
107 | padding=True,
108 | truncation='longest_first',
109 | return_tensors="pt",
110 | )
111 |
112 | cencode = {key: value.cuda() for key, value in cencode.items()}
113 |
114 | if data_args.encode_is_qry:
115 | model_output: EncoderOutput = model(query=cencode)
116 | encoded.append(model_output.q_reps.cpu().detach().numpy())
117 | else:
118 | model_output: EncoderOutput = model(passage=cencode)
119 | encoded.append(model_output.p_reps.cpu().detach().numpy())
120 |
121 | lookup_indices.extend(batch_ids)
122 | except Exception as e:
123 | logger.warning(f"Skipping batch {k}: {e}")
124 | continue
125 |
126 |
127 | encoded = np.concatenate(encoded)
128 |
129 | with open(data_args.encoded_save_path, 'wb') as f:
130 | pickle.dump((encoded, lookup_indices), f)
131 |
132 |
133 |
134 |
135 | if __name__ == "__main__":
136 | main()
--------------------------------------------------------------------------------
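Unlike encode.py, this script reads raw files directly; a sketch of the two input formats it expects (file names are hypothetical):

import csv
import json

# queries: TSV rows of (id, text)
with open("queries.tsv", "w", newline="") as f:
    csv.writer(f, delimiter="\t").writerow(["q1", "what is dense retrieval"])

# corpus: JSONL rows with "id", "title", and "text" fields
with open("corpus.jsonl", "w") as f:
    f.write(json.dumps({"id": "d1", "title": "DR", "text": "Dense retrieval maps text to vectors."}) + "\n")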
/tevatron/src/tevatron/driver/jax_encode.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import sys
5 |
6 | import datasets
7 | import jax
8 | import numpy as np
9 | from flax.training.common_utils import shard
10 | from jax import pmap
11 | from tevatron.arguments import DataArguments
12 | from tevatron.arguments import TevatronTrainingArguments as TrainingArguments
13 | from tevatron.arguments import ModelArguments
14 | from tevatron.data import EncodeCollator, EncodeDataset
15 | from tevatron.datasets import HFQueryDataset, HFCorpusDataset
16 | from torch.utils.data import DataLoader
17 | from tqdm import tqdm
18 | from flax.training.train_state import TrainState
19 | from flax import jax_utils
20 | import optax
21 | from transformers import (AutoConfig, AutoTokenizer, FlaxAutoModel,
22 | HfArgumentParser, TensorType)
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def main():
28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
31 | else:
32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
33 | model_args: ModelArguments
34 | data_args: DataArguments
35 | training_args: TrainingArguments
36 |
37 | # Setup logging
38 | logging.basicConfig(
39 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
40 | datefmt="%m/%d/%Y %H:%M:%S",
41 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
42 | )
43 |
44 | num_labels = 1
45 | config = AutoConfig.from_pretrained(
46 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
47 | num_labels=num_labels,
48 | cache_dir=model_args.cache_dir,
49 | )
50 | tokenizer = AutoTokenizer.from_pretrained(
51 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
52 | cache_dir=model_args.cache_dir,
53 | use_fast=False,
54 | )
55 |
56 | model = FlaxAutoModel.from_pretrained(model_args.model_name_or_path, config=config, from_pt=False)
57 |
58 | text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len
59 | if data_args.encode_is_qry:
60 | encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args,
61 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
62 | else:
63 | encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args,
64 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
65 | encode_dataset = EncodeDataset(encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index),
66 | tokenizer, max_len=text_max_length)
67 |
68 | # prepare padding batch (for last nonfull batch)
69 | dataset_size = len(encode_dataset)
70 | padding_prefix = "padding_"
71 | total_batch_size = len(jax.devices()) * training_args.per_device_eval_batch_size
72 | features = list(encode_dataset.encode_data.features.keys())
73 | padding_batch = {features[0]: [], features[1]: []}
74 | for i in range(total_batch_size - (dataset_size % total_batch_size)):
75 | padding_batch["text_id"].append(f"{padding_prefix}{i}")
76 | padding_batch["text"].append([0])
77 | padding_batch = datasets.Dataset.from_dict(padding_batch)
78 | encode_dataset.encode_data = datasets.concatenate_datasets([encode_dataset.encode_data, padding_batch])
79 |
80 | encode_loader = DataLoader(
81 | encode_dataset,
82 | batch_size=training_args.per_device_eval_batch_size * len(jax.devices()),
83 | collate_fn=EncodeCollator(
84 | tokenizer,
85 | max_length=text_max_length,
86 | padding='max_length',
87 | pad_to_multiple_of=16,
88 | return_tensors=TensorType.NUMPY,
89 | ),
90 | shuffle=False,
91 | drop_last=False,
92 | num_workers=training_args.dataloader_num_workers,
93 | )
94 |
95 | # craft a fake state for now to replicate on devices
96 | adamw = optax.adamw(0.0001)
97 | state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
98 |
99 | def encode_step(batch, state):
100 | embedding = state.apply_fn(**batch, params=state.params, train=False)[0]
101 | return embedding[:, 0]
102 |
103 | p_encode_step = pmap(encode_step)
104 | state = jax_utils.replicate(state)
105 |
106 | encoded = []
107 | lookup_indices = []
108 |
109 | for (batch_ids, batch) in tqdm(encode_loader):
110 | lookup_indices.extend(batch_ids)
111 | batch_embeddings = p_encode_step(shard(batch.data), state)
112 | encoded.extend(np.concatenate(batch_embeddings, axis=0))
113 | with open(data_args.encoded_save_path, 'wb') as f:
114 | pickle.dump((encoded[:dataset_size], lookup_indices[:dataset_size]), f)
115 |
116 |
117 | if __name__ == "__main__":
118 | main()
119 |
--------------------------------------------------------------------------------
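One arithmetic detail in the padding logic above: when the dataset size divides the total batch size exactly, total_batch_size - (dataset_size % total_batch_size) appends a whole extra batch of padding rows; the final encoded[:dataset_size] slice discards them either way. A quick check:

total_batch_size = 32  # illustrative: len(jax.devices()) * per-device batch size
for dataset_size in (100, 96):
    pad = total_batch_size - (dataset_size % total_batch_size)
    print(dataset_size, pad)  # 100 -> 28 padding rows; 96 -> 32 (a full extra batch)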
/tevatron/src/tevatron/driver/jax_train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from functools import partial
5 |
6 | import datasets
7 | import jax
8 | import jax.numpy as jnp
9 | import optax
10 | from flax import jax_utils, traverse_util
11 | from flax.jax_utils import prefetch_to_device
12 | from flax.training.common_utils import get_metrics, shard
13 | from torch.utils.data import DataLoader, IterableDataset
14 | from tqdm import tqdm
15 | from transformers import AutoConfig, AutoTokenizer, FlaxAutoModel
16 | from transformers import (
17 | HfArgumentParser,
18 | set_seed,
19 | )
20 |
21 | from tevatron.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
22 | from tevatron.tevax.training import TiedParams, RetrieverTrainState, retriever_train_step, grad_cache_train_step, \
23 | DualParams
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | def main():
29 | parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
30 |
31 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
32 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
33 | else:
34 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
35 | model_args: ModelArguments
36 | data_args: DataArguments
37 | training_args: TevatronTrainingArguments
38 |
39 | if (
40 | os.path.exists(training_args.output_dir)
41 | and os.listdir(training_args.output_dir)
42 | and training_args.do_train
43 | and not training_args.overwrite_output_dir
44 | ):
45 | raise ValueError(
46 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
47 | )
48 |
49 | # Setup logging
50 | logging.basicConfig(
51 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
52 | datefmt="%m/%d/%Y %H:%M:%S",
53 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
54 | )
55 | logger.warning(
56 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
57 | training_args.local_rank,
58 | training_args.device,
59 | training_args.n_gpu,
60 | bool(training_args.local_rank != -1),
61 | training_args.fp16,
62 | )
63 | logger.info("Training/evaluation parameters %s", training_args)
64 | logger.info("MODEL parameters %s", model_args)
65 |
66 | set_seed(training_args.seed)
67 |
68 | config = AutoConfig.from_pretrained(
69 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
70 | cache_dir=model_args.cache_dir,
71 | )
72 | tokenizer = AutoTokenizer.from_pretrained(
73 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
74 | cache_dir=model_args.cache_dir,
75 | )
76 | try:
77 | model = FlaxAutoModel.from_pretrained(
78 | model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
79 | )
80 | except Exception:  # fall back to converting PyTorch weights
81 | model = FlaxAutoModel.from_pretrained(
82 | model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype),
83 | from_pt=True
84 | )
85 |
86 | if data_args.train_dir:
87 | data_files = {
88 | 'train': data_args.train_path
89 | }
90 | else:
91 | data_files = None
92 |
93 | train_dataset = \
94 | datasets.load_dataset(data_args.dataset_name, data_args.dataset_language, cache_dir=model_args.cache_dir,
95 | data_files=data_files)[data_args.dataset_split]
96 |
97 | def tokenize_train(example):
98 | tokenize = partial(tokenizer, return_attention_mask=False, return_token_type_ids=False, padding=False,
99 | truncation=True)
100 | query = example['query']
101 | pos_psgs = [p['title'] + " " + p['text'] for p in example['positive_passages']]
102 | neg_psgs = [p['title'] + " " + p['text'] for p in example['negative_passages']]
103 |
104 | example['query_input_ids'] = dict(tokenize(query, max_length=32))
105 | example['pos_psgs_input_ids'] = [dict(tokenize(x, max_length=data_args.p_max_len)) for x in pos_psgs]
106 | example['neg_psgs_input_ids'] = [dict(tokenize(x, max_length=data_args.p_max_len)) for x in neg_psgs]
107 |
108 | return example
109 |
110 | train_data = train_dataset.map(
111 | tokenize_train,
112 | batched=False,
113 | num_proc=data_args.dataset_proc_num,
114 | desc="Running tokenizer on train dataset",
115 | )
116 | train_data = train_data.filter(
117 | function=lambda data: len(data["pos_psgs_input_ids"]) >= 1 and \
118 | len(data["neg_psgs_input_ids"]) >= data_args.train_n_passages-1, num_proc=64
119 | )
120 |
121 | class TrainDataset:
122 | def __init__(self, train_data, group_size, tokenizer):
123 | self.group_size = group_size
124 | self.data = train_data
125 | self.tokenizer = tokenizer
126 |
127 | def __len__(self):
128 | return len(self.data)
129 |
130 | def get_example(self, i, epoch):
131 | example = self.data[i]
132 | q = example['query_input_ids']
133 |
134 | pp = example['pos_psgs_input_ids']
135 | p = pp[0]
136 |
137 | nn = example['neg_psgs_input_ids']
138 | off = epoch * (self.group_size - 1) % len(nn)
139 | nn = nn * 2  # duplicate so the epoch-dependent slice can wrap around
140 | nn = nn[off: off + self.group_size - 1]
141 |
142 | return q, [p] + nn
143 |
144 | def get_batch(self, indices, epoch):
145 | qq, dd = zip(*[self.get_example(i, epoch) for i in map(int, indices)])
146 | dd = sum(dd, [])
147 | return dict(tokenizer.pad(qq, max_length=32, padding='max_length', return_tensors='np')), dict(
148 | tokenizer.pad(dd, max_length=data_args.p_max_len, padding='max_length', return_tensors='np'))
149 |
150 | train_dataset = TrainDataset(train_data, data_args.train_n_passages, tokenizer)
151 |
152 | def create_learning_rate_fn(
153 | train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int,
154 | learning_rate: float
155 | ):
156 | """Returns a linear warmup, linear_decay learning rate function."""
157 | steps_per_epoch = train_ds_size // train_batch_size
158 | num_train_steps = steps_per_epoch * num_train_epochs
159 | warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
160 | decay_fn = optax.linear_schedule(
161 | init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
162 | )
163 | schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
164 | return schedule_fn
165 |
166 | def _decay_mask_fn(params):
167 | flat_params = traverse_util.flatten_dict(params)
168 | layer_norm_params = [
169 | (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
170 | ]
171 | flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
172 | return traverse_util.unflatten_dict(flat_mask)
173 |
174 | def decay_mask_fn(params):
175 | param_nodes, treedef = jax.tree_flatten(params, lambda v: isinstance(v, dict))
176 | masks = [_decay_mask_fn(param_node) for param_node in param_nodes]
177 | return jax.tree_unflatten(treedef, masks)
178 |
179 | num_epochs = int(training_args.num_train_epochs)
180 | train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
181 | steps_per_epoch = len(train_dataset) // train_batch_size
182 | total_train_steps = steps_per_epoch * num_epochs
183 |
184 | linear_decay_lr_schedule_fn = create_learning_rate_fn(
185 | len(train_dataset),
186 | train_batch_size,
187 | int(training_args.num_train_epochs),
188 | int(total_train_steps * 0.1),
189 | training_args.learning_rate,
190 | )
191 |
192 | adamw = optax.adamw(
193 | learning_rate=linear_decay_lr_schedule_fn,
194 | b1=training_args.adam_beta1,
195 | b2=training_args.adam_beta2,
196 | eps=training_args.adam_epsilon,
197 | weight_decay=training_args.weight_decay,
198 | mask=decay_mask_fn,
199 | )
200 |
201 | if model_args.untie_encoder:
202 | params = DualParams.create(model.params)
203 | else:
204 | params = TiedParams.create(model.params)
205 | state = RetrieverTrainState.create(apply_fn=model.__call__, params=params, tx=adamw)
206 |
207 | if training_args.grad_cache:
208 | q_n_subbatch = train_batch_size // training_args.gc_q_chunk_size
209 | p_n_subbatch = train_batch_size * data_args.train_n_passages // training_args.gc_p_chunk_size
210 | p_train_step = jax.pmap(
211 | partial(grad_cache_train_step, q_n_subbatch=q_n_subbatch, p_n_subbatch=p_n_subbatch),
212 | "device"
213 | )
214 | else:
215 | p_train_step = jax.pmap(
216 | retriever_train_step,
217 | "device"
218 | )
219 |
220 | state = jax_utils.replicate(state)
221 | rng = jax.random.PRNGKey(training_args.seed)
222 | dropout_rngs = jax.random.split(rng, jax.local_device_count())
223 |
224 | class IterableTrain(IterableDataset):
225 | def __init__(self, dataset, batch_idx, epoch):
226 | super().__init__()
227 | self.dataset = dataset
228 | self.batch_idx = batch_idx
229 | self.epoch = epoch
230 |
231 | def __iter__(self):
232 | for idx in self.batch_idx:
233 | batch = self.dataset.get_batch(idx, self.epoch)
234 | batch = shard(batch)
235 | yield batch
236 |
237 | logger.info("***** Running training *****")
238 | logger.info(f" Num examples = {len(train_dataset)}")
239 | logger.info(f" Num Epochs = {num_epochs}")
240 | logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
241 | logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
242 | logger.info(f" Total optimization steps = {total_train_steps}")
243 |
244 | train_metrics = []
245 | for epoch in tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0):
246 | # ======================== Training ================================
247 | # Create sampling rng
248 | rng, input_rng = jax.random.split(rng)
249 |
250 | steps_per_epoch = len(train_dataset) // train_batch_size
251 |
252 | batch_idx = jax.random.permutation(input_rng, len(train_dataset))
253 | batch_idx = batch_idx[: steps_per_epoch * train_batch_size]
254 | batch_idx = batch_idx.reshape((steps_per_epoch, train_batch_size)).tolist()
255 |
256 | train_loader = prefetch_to_device(
257 | iter(DataLoader(
258 | IterableTrain(train_dataset, batch_idx, epoch),
259 | num_workers=16, prefetch_factor=256, batch_size=None, collate_fn=lambda v: v)
260 | ), 2)
261 |
262 | # train
263 | epochs = tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)
264 | for step in epochs:
265 | cur_step = epoch * (len(train_dataset) // train_batch_size) + step
266 | batch = next(train_loader)
267 |
268 | loss, state, dropout_rngs = p_train_step(state, *batch, dropout_rngs)
269 | train_metrics.append({'loss': loss})
270 |
271 | if cur_step % training_args.logging_steps == 0 and cur_step > 0:
272 | train_metrics = get_metrics(train_metrics)
273 | print(
274 | f"Step... ({cur_step} | Loss: {train_metrics['loss'].mean()},"
275 | f" Learning Rate: {linear_decay_lr_schedule_fn(cur_step)})",
276 | flush=True,
277 | )
278 | train_metrics = []
279 |
280 | epochs.write(
281 | f"Epoch... ({epoch + 1}/{num_epochs})"
282 | )
283 |
284 | params = jax_utils.unreplicate(state.params)
285 |
286 | if model_args.untie_encoder:
287 | os.makedirs(training_args.output_dir, exist_ok=True)
288 | model.save_pretrained(os.path.join(training_args.output_dir, 'query_encoder'), params=params.q_params)
289 | model.save_pretrained(os.path.join(training_args.output_dir, 'passage_encoder'), params=params.p_params)
290 | else:
291 | model.save_pretrained(training_args.output_dir, params=params.p_params)
292 | tokenizer.save_pretrained(training_args.output_dir)
293 |
294 |
295 | if __name__ == "__main__":
296 | main()
297 |
--------------------------------------------------------------------------------
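A standalone sketch of the warmup-then-decay schedule that create_learning_rate_fn assembles (the step counts and peak rate are illustrative):

import optax

warmup = optax.linear_schedule(init_value=0.0, end_value=1e-5, transition_steps=100)
decay = optax.linear_schedule(init_value=1e-5, end_value=0.0, transition_steps=900)
schedule = optax.join_schedules(schedules=[warmup, decay], boundaries=[100])
# the learning rate rises linearly to 1e-5 over 100 steps, then decays linearly to 0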
/tevatron/src/tevatron/driver/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | from transformers import AutoConfig, AutoTokenizer
6 | from transformers import (
7 | HfArgumentParser,
8 | set_seed,
9 | )
10 |
11 | from tevatron.arguments import ModelArguments, DataArguments, \
12 | TevatronTrainingArguments as TrainingArguments
13 | from tevatron.data import TrainDataset, QPCollator
14 | from tevatron.modeling import DenseModel
15 | from tevatron.trainer import TevatronTrainer as Trainer, GCTrainer
16 | from tevatron.datasets import HFTrainDataset
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | def main():
22 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
23 |
24 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
25 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
26 | else:
27 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
28 | model_args: ModelArguments
29 | data_args: DataArguments
30 | training_args: TrainingArguments
31 |
32 | if (
33 | os.path.exists(training_args.output_dir)
34 | and os.listdir(training_args.output_dir)
35 | and training_args.do_train
36 | and not training_args.overwrite_output_dir
37 | ):
38 | raise ValueError(
39 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
40 | )
41 |
42 | # Setup logging
43 | logging.basicConfig(
44 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
45 | datefmt="%m/%d/%Y %H:%M:%S",
46 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
47 | )
48 | logger.warning(
49 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
50 | training_args.local_rank,
51 | training_args.device,
52 | training_args.n_gpu,
53 | bool(training_args.local_rank != -1),
54 | training_args.fp16,
55 | )
56 | logger.info("Training/evaluation parameters %s", training_args)
57 | logger.info("MODEL parameters %s", model_args)
58 |
59 | set_seed(training_args.seed)
60 |
61 | num_labels = 1
62 | config = AutoConfig.from_pretrained(
63 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
64 | num_labels=num_labels,
65 | cache_dir=model_args.cache_dir,
66 | )
67 | tokenizer = AutoTokenizer.from_pretrained(
68 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
69 | cache_dir=model_args.cache_dir,
70 | use_fast=False,
71 | )
72 | model = DenseModel.build(
73 | model_args,
74 | training_args,
75 | config=config,
76 | cache_dir=model_args.cache_dir,
77 | )
78 |
79 | train_dataset = HFTrainDataset(tokenizer=tokenizer, data_args=data_args,
80 | cache_dir=data_args.data_cache_dir or model_args.cache_dir)
81 | train_dataset = TrainDataset(data_args, train_dataset.process(), tokenizer)
82 |
83 | trainer_cls = GCTrainer if training_args.grad_cache else Trainer
84 | trainer = trainer_cls(
85 | model=model,
86 | args=training_args,
87 | train_dataset=train_dataset,
88 | data_collator=QPCollator(
89 | tokenizer,
90 | max_p_len=data_args.p_max_len,
91 | max_q_len=data_args.q_max_len
92 | ),
93 | )
94 | train_dataset.trainer = trainer
95 |
96 | trainer.train() # TODO: resume training
97 | trainer.save_model()
98 | if trainer.is_world_process_zero():
99 | tokenizer.save_pretrained(training_args.output_dir)
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
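train.py accepts either CLI flags or a single JSON config file; a sketch of the flag path, assuming the usual required fields (values are placeholders):

from transformers import HfArgumentParser
from tevatron.arguments import ModelArguments, DataArguments, TevatronTrainingArguments

parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    ["--model_name_or_path", "bert-base-uncased", "--output_dir", "./out"]
)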
/tevatron/src/tevatron/faiss_retriever/__init__.py:
--------------------------------------------------------------------------------
1 | from .retriever import BaseFaissIPRetriever, BaseFaissHNSWRetriever
2 |
--------------------------------------------------------------------------------
/tevatron/src/tevatron/faiss_retriever/__main__.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import numpy as np
4 | import glob
5 | from argparse import ArgumentParser
6 | from itertools import chain
7 | from tqdm import tqdm
8 | import os
9 | import json
10 | from .retriever import BaseFaissIPRetriever, BaseFaissHNSWRetriever
11 |
12 | import logging
13 | logger = logging.getLogger(__name__)
14 | logging.basicConfig(
15 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
16 | datefmt="%m/%d/%Y %H:%M:%S",
17 | level=logging.INFO,
18 | )
19 |
20 |
21 | def search_queries(retriever, q_reps, p_lookup, args):
22 | if args.batch_size > 0:
23 | all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size)
24 | else:
25 | all_scores, all_indices = retriever.search(q_reps, args.depth)
26 |
27 | psg_indices = [[str(p_lookup[x]) for x in q_dd] for q_dd in all_indices]
28 | psg_indices = np.array(psg_indices)
29 | return all_scores, psg_indices
30 |
31 |
32 | def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file):
33 | with open(ranking_save_file, 'w') as f:
34 | for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices):
35 | score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)]
36 | score_list = sorted(score_list, key=lambda x: x[0], reverse=True)
37 | for s, idx in score_list:
38 | f.write(f'{qid}\t{idx}\t{s}\n')
39 |
40 | def pickle_load(path):
41 | with open(path, 'rb') as f:
42 | reps, lookup = pickle.load(f)
43 | return np.array(reps, dtype=np.float32), lookup
44 |
45 | def pickle_save(obj, path):
46 | with open(path, 'wb') as f:
47 | pickle.dump(obj, f)
48 |
49 | def main():
50 | parser = ArgumentParser()
51 | parser.add_argument('--query_reps', required=True)
52 | parser.add_argument('--passage_reps', required=True)
53 | parser.add_argument('--batch_size', type=int, default=128)
54 | parser.add_argument('--depth', type=int, default=1000)
55 | parser.add_argument('--save_ranking_to', required=True)
56 | parser.add_argument('--save_text', action='store_true')
57 | parser.add_argument('--use_hnsw', action='store_true')
58 | parser.add_argument('--save_index_path')  # effectively required when --use_hnsw is set
59 | parser.add_argument('--quiet', action='store_true')
60 |
61 | args = parser.parse_args()
62 |
63 | index_files = glob.glob(args.passage_reps)
64 | logger.info(f'Pattern match found {len(index_files)} files; loading them into index.')
65 |
66 | if args.use_hnsw:
67 | p_reps_0, p_lookup_0 = pickle_load(index_files[0])
68 | retriever = BaseFaissHNSWRetriever(p_reps_0)
69 |
70 | if os.path.exists(os.path.join(args.save_index_path, "index.faiss")):
71 | logger.info("Loading from saved index")
72 | retriever.load(os.path.join(args.save_index_path, "index.faiss"))
73 | look_up = json.load(open(os.path.join(args.save_index_path, "index.json")))
74 |
75 | else:
76 | shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:]))
77 | if len(index_files) > 1:
78 | shards = tqdm(shards, desc='Loading shards into index', total=len(index_files))
79 | look_up = []
80 | all_p_reps = list()
81 | for p_reps, p_lookup in shards:
82 | all_p_reps.extend(p_reps)
83 | look_up += p_lookup
84 |
85 | all_p_reps = np.array(all_p_reps, dtype=np.float32)
86 | logger.info("Building index")
87 | retriever.build(all_p_reps)
88 | logger.info("Saving index")
89 | retriever.save(os.path.join(args.save_index_path, "index.faiss"))
90 | json.dump(look_up, open(os.path.join(args.save_index_path, "index.json"), "w"))
91 |
92 | else:
93 | p_reps_0, p_lookup_0 = pickle_load(index_files[0])
94 | retriever = BaseFaissIPRetriever(p_reps_0)
95 |
96 | shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:]))
97 | if len(index_files) > 1:
98 | shards = tqdm(shards, desc='Loading shards into index', total=len(index_files))
99 | look_up = []
100 | for p_reps, p_lookup in shards:
101 | retriever.add(p_reps)
102 | look_up += p_lookup
103 |
104 | q_reps, q_lookup = pickle_load(args.query_reps)
105 |
106 |
107 | logger.info('Index Search Start')
108 | all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args)
109 | logger.info('Index Search Finished')
110 |
111 | if args.save_text:
112 | write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to)
113 | else:
114 | pickle_save((all_scores, psg_indices), args.save_ranking_to)
115 |
116 |
117 | if __name__ == '__main__':
118 | main()
119 |
--------------------------------------------------------------------------------
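With --save_text, write_ranking emits plain qid<TAB>docid<TAB>score lines, sorted by descending score within each query; a sketch of reading such a file back (the path is hypothetical):

from collections import defaultdict

run = defaultdict(list)
with open("rank.tsv") as f:
    for line in f:
        qid, docid, score = line.rstrip("\n").split("\t")
        run[qid].append((docid, float(score)))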
/tevatron/src/tevatron/faiss_retriever/reducer.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import faiss
3 | from argparse import ArgumentParser
4 | from tqdm import tqdm
5 | from typing import Iterable, Tuple
6 | from numpy import ndarray
7 | from .__main__ import pickle_load, write_ranking
8 |
9 |
10 | def combine_faiss_results(results: Iterable[Tuple[ndarray, ndarray]]):
11 | rh = None
12 | for scores, indices in results:
13 | if rh is None:
14 | print(f'Initializing Heap. Assuming {scores.shape[0]} queries.')
15 | rh = faiss.ResultHeap(scores.shape[0], scores.shape[1])
16 | rh.add_result(-scores, indices)
17 | rh.finalize()
18 | corpus_scores, corpus_indices = -rh.D, rh.I
19 |
20 | return corpus_scores, corpus_indices
21 |
22 |
23 | def main():
24 | parser = ArgumentParser()
25 | parser.add_argument('--score_dir', required=True)
26 | parser.add_argument('--query', required=True)
27 | parser.add_argument('--save_ranking_to', required=True)
28 | args = parser.parse_args()
29 |
30 | partitions = glob.glob(f'{args.score_dir}/*')
31 |
32 | corpus_scores, corpus_indices = combine_faiss_results(map(pickle_load, tqdm(partitions)))
33 |
34 | _, q_lookup = pickle_load(args.query)
35 | write_ranking(corpus_indices, corpus_scores, q_lookup, args.save_ranking_to)
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
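A hand-made sketch of the heap merge combine_faiss_results performs on two score shards; scores are negated because faiss.ResultHeap keeps the smallest values (toy numbers, one query, top-2):

import numpy as np
import faiss

s1 = np.array([[3.0, 1.0]], dtype=np.float32); i1 = np.array([[10, 11]], dtype=np.int64)
s2 = np.array([[2.5, 0.5]], dtype=np.float32); i2 = np.array([[20, 21]], dtype=np.int64)
rh = faiss.ResultHeap(1, 2)  # 1 query, keep the top 2 overall
rh.add_result(-s1, i1)
rh.add_result(-s2, i2)
rh.finalize()
print(-rh.D, rh.I)  # scores [[3.0, 2.5]], ids [[10, 20]]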
/tevatron/src/tevatron/faiss_retriever/retriever.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import faiss
3 | from tqdm.autonotebook import trange
4 | import time
5 |
6 | import logging
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | class BaseFaissHNSWRetriever:
11 | def __init__(self, init_reps: np.ndarray, hnsw_store_n: int = 512, hnsw_ef_search: int = 128, hnsw_ef_construction: int = 64, similarity_metric=faiss.METRIC_INNER_PRODUCT):
12 | self.index = faiss.IndexHNSWFlat(init_reps.shape[1], hnsw_store_n, similarity_metric)
13 | self.index.hnsw.efSearch = hnsw_ef_search
14 | self.index.hnsw.efConstruction = hnsw_ef_construction
15 | # self.index = faiss.index_factory(init_reps.shape[1], "HNSW" + str(hnsw_store_n))
16 |
17 | def load(self, fname: str):
18 | self.index = faiss.read_index(fname)
19 |
20 | def save(self, fname: str):
21 | faiss.write_index(self.index, fname)
22 |
23 | def build(self, p_reps: np.ndarray, buffer_size: int = 1000):
24 | # sq_norms = (p_reps ** 2).sum(1)
25 | # max_sq_norm = float(sq_norms.max())
26 | # aux_dims = np.sqrt(max_sq_norm - sq_norms)
27 | # p_reps = np.hstack((p_reps, aux_dims.reshape(-1, 1)))
28 | for start in trange(0, p_reps.shape[0], buffer_size):
29 | self.index.add(p_reps[start : start + buffer_size])
30 |
31 | def search(self, q_reps: np.ndarray, k: int):
32 | # q_reps = np.hstack((q_reps, np.zeros((q_reps.shape[0], 1), dtype=np.float32)))
33 | return self.index.search(q_reps, k)
34 |
35 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int):
36 | num_query = q_reps.shape[0]
37 | # q_reps = np.hstack((q_reps, np.zeros((q_reps.shape[0], 1), dtype=np.float32)))
38 | all_scores = []
39 | all_indices = []
40 | for start_idx in trange(0, num_query, batch_size):
41 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k)
42 | all_scores.append(nn_scores)
43 | all_indices.append(nn_indices)
44 | all_scores = np.concatenate(all_scores, axis=0)
45 | all_indices = np.concatenate(all_indices, axis=0)
46 |
47 | return all_scores, all_indices
48 |
49 | class BaseFaissIPRetriever:
50 | def __init__(self, init_reps: np.ndarray):
51 | index = faiss.IndexFlatIP(init_reps.shape[1])
52 | self.index = index
53 |
54 | def add(self, p_reps: np.ndarray):
55 | self.index.add(p_reps)
56 |
57 | def search(self, q_reps: np.ndarray, k: int):
58 | return self.index.search(q_reps, k)
59 |
60 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int):
61 | num_query = q_reps.shape[0]
62 | all_scores = []
63 | all_indices = []
64 | for start_idx in trange(0, num_query, batch_size):
65 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k)
66 | all_scores.append(nn_scores)
67 | all_indices.append(nn_indices)
68 | all_scores = np.concatenate(all_scores, axis=0)
69 | all_indices = np.concatenate(all_indices, axis=0)
70 |
71 | return all_scores, all_indices
72 |
73 |
74 | class FaissRetriever(BaseFaissIPRetriever):
75 |
76 | def __init__(self, init_reps: np.ndarray, factory_str: str):
77 | index = faiss.index_factory(init_reps.shape[1], factory_str)
78 | self.index = index
79 | self.index.verbose = True
80 | if not self.index.is_trained:
81 | self.index.train(init_reps)
82 |
--------------------------------------------------------------------------------
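A quick end-to-end sketch of BaseFaissIPRetriever on random vectors (dimensions and counts are arbitrary):

import numpy as np
from tevatron.faiss_retriever import BaseFaissIPRetriever

p_reps = np.random.rand(1000, 64).astype(np.float32)
q_reps = np.random.rand(4, 64).astype(np.float32)
retriever = BaseFaissIPRetriever(p_reps)  # the constructor only reads the dimensionality
retriever.add(p_reps)
scores, indices = retriever.batch_search(q_reps, k=10, batch_size=2)  # both (4, 10)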
/tevatron/src/tevatron/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 | from torch import distributed as dist
5 |
6 |
7 | class SimpleContrastiveLoss:
8 |
9 | def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
10 | if target is None:
11 | target_per_qry = y.size(0) // x.size(0)
12 | target = torch.arange(
13 | 0, x.size(0) * target_per_qry, target_per_qry, device=x.device, dtype=torch.long)
14 | logits = torch.matmul(x, y.transpose(0, 1))
15 | return F.cross_entropy(logits, target, reduction=reduction)
16 |
17 |
18 | class DistributedContrastiveLoss(SimpleContrastiveLoss):
19 | def __init__(self, n_target: int = 0, scale_loss: bool = True):
20 | assert dist.is_initialized(), "Distributed training has not been properly initialized."
21 | super().__init__()
22 | self.world_size = dist.get_world_size()
23 | self.rank = dist.get_rank()
24 | self.scale_loss = scale_loss
25 |
26 | def __call__(self, x: Tensor, y: Tensor, **kwargs):
27 | dist_x = self.gather_tensor(x)
28 | dist_y = self.gather_tensor(y)
29 | loss = super().__call__(dist_x, dist_y, **kwargs)
30 | if self.scale_loss:
31 | loss = loss * self.world_size
32 | return loss
33 |
34 | def gather_tensor(self, t):
35 | gathered = [torch.empty_like(t) for _ in range(self.world_size)]
36 | dist.all_gather(gathered, t)
37 | gathered[self.rank] = t
38 | return torch.cat(gathered, dim=0)
--------------------------------------------------------------------------------
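To make the implicit target construction in SimpleContrastiveLoss concrete: with 2 queries and 2 passages per query, the inferred labels are [0, 2], i.e. each query's first passage is treated as its positive. A toy sketch:

import torch
from tevatron.loss import SimpleContrastiveLoss

x = torch.randn(2, 8)  # query reps
y = torch.randn(4, 8)  # passage reps, ordered [q0_pos, q0_neg, q1_pos, q1_neg]
loss = SimpleContrastiveLoss()(x, y)  # internally uses target == tensor([0, 2])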
/tevatron/src/tevatron/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import EncoderModel, EncoderPooler, EncoderOutput
2 | from .dense import DenseModel
3 | from .unicoil import UniCoilModel
4 | from .splade import SpladeModel
5 | from .colbert import ColbertModel
6 |
--------------------------------------------------------------------------------
/tevatron/src/tevatron/modeling/colbert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 | import logging
5 | from .encoder import EncoderPooler, EncoderModel
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class ColbertPooler(EncoderPooler):
11 | def __init__(self, input_dim: int = 768, output_dim: int = 32, tied=True):
12 | super(ColbertPooler, self).__init__()
13 | self.linear_q = nn.Linear(input_dim, output_dim)
14 | if tied:
15 | self.linear_p = self.linear_q
16 | else:
17 | self.linear_p = nn.Linear(input_dim, output_dim)
18 | self._config = {'input_dim': input_dim, 'output_dim': output_dim, 'tied': tied}
19 |
20 | def forward(self, q: Tensor = None, p: Tensor = None, **kwargs):
21 | if q is not None:
22 | return self.linear_q(q)
23 | elif p is not None:
24 | return self.linear_p(p)
25 | else:
26 | raise ValueError
27 |
28 |
29 | class ColbertModel(EncoderModel):
30 | def encode_passage(self, psg):
31 | if psg is None:
32 | return None
33 | psg_out = self.lm_p(**psg, return_dict=True)
34 | p_hidden = psg_out.last_hidden_state
35 | p_reps = self.pooler(p=p_hidden)
36 | p_reps *= psg['attention_mask'][:, :, None].float()
37 | return p_reps
38 |
39 | def encode_query(self, qry):
40 | if qry is None:
41 | return None
42 | qry_out = self.lm_q(**qry, return_dict=True)
43 | q_hidden = qry_out.last_hidden_state
44 | q_reps = self.pooler(q=q_hidden)
45 | q_reps *= qry['attention_mask'][:, :, None].float()
46 | return q_reps
47 |
48 | def compute_similarity(self, q_reps, p_reps):
49 | token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)
50 | scores, _ = token_scores.max(-1)
51 | scores = scores.sum(1)
52 | return scores
53 |
54 | @staticmethod
55 | def load_pooler(model_weights_file, **config):
56 | pooler = ColbertPooler(**config)
57 | pooler.load(model_weights_file)
58 | return pooler
59 |
60 | @staticmethod
61 | def build_pooler(model_args):
62 | pooler = ColbertPooler(
63 | model_args.projection_in_dim,
64 | model_args.projection_out_dim,
65 | tied=not model_args.untie_encoder
66 | )
67 | pooler.load(model_args.model_name_or_path)
68 | return pooler
69 |
--------------------------------------------------------------------------------
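The einsum in compute_similarity is ColBERT-style MaxSim: max over passage tokens, sum over query tokens. A standalone sketch with toy shapes:

import torch

q_reps = torch.randn(2, 5, 32)  # (queries, query tokens, dim)
p_reps = torch.randn(3, 7, 32)  # (passages, passage tokens, dim)
token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)
scores = token_scores.max(-1).values.sum(1)  # (2, 3) query-passage score matrix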
/tevatron/src/tevatron/modeling/dense.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 | import logging
5 | from .encoder import EncoderPooler, EncoderModel
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class DensePooler(EncoderPooler):
11 | def __init__(self, input_dim: int = 768, output_dim: int = 768, tied=True):
12 | super(DensePooler, self).__init__()
13 | self.linear_q = nn.Linear(input_dim, output_dim)
14 | if tied:
15 | self.linear_p = self.linear_q
16 | else:
17 | self.linear_p = nn.Linear(input_dim, output_dim)
18 | self._config = {'input_dim': input_dim, 'output_dim': output_dim, 'tied': tied}
19 |
20 | def forward(self, q: Tensor = None, p: Tensor = None, **kwargs):
21 | if q is not None:
22 | return self.linear_q(q[:, 0])
23 | elif p is not None:
24 | return self.linear_p(p[:, 0])
25 | else:
26 | raise ValueError
27 |
28 |
29 | class DenseModel(EncoderModel):
30 | def encode_passage(self, psg):
31 | if psg is None:
32 | return None
33 | psg_out = self.lm_p(**psg, return_dict=True)
34 | p_hidden = psg_out.last_hidden_state
35 | if self.pooler is not None:
36 | p_reps = self.pooler(p=p_hidden) # D * d
37 | else:
38 | if self.use_mean_pooling:
39 | p_reps = p_hidden.masked_fill(~psg['attention_mask'][..., None].bool(), 0.)
40 | p_reps = p_reps.sum(dim=1) / psg['attention_mask'].sum(dim=1)[..., None]
41 | else:
42 | p_reps = p_hidden[:, 0]
43 | return p_reps
44 |
45 | def encode_query(self, qry):
46 | if qry is None:
47 | return None
48 | qry_out = self.lm_q(**qry, return_dict=True)
49 | q_hidden = qry_out.last_hidden_state
50 | if self.pooler is not None:
51 | q_reps = self.pooler(q=q_hidden)
52 | else:
53 | if self.use_mean_pooling:
54 | q_reps = q_hidden.masked_fill(~qry['attention_mask'][..., None].bool(), 0.)
55 | q_reps = q_reps.sum(dim=1) / qry['attention_mask'].sum(dim=1)[..., None]
56 | else:
57 | q_reps = q_hidden[:, 0]
58 | return q_reps
59 |
60 | def compute_similarity(self, q_reps, p_reps):
61 | return torch.matmul(q_reps, p_reps.transpose(0, 1))
62 |
63 | @staticmethod
64 | def load_pooler(model_weights_file, **config):
65 | pooler = DensePooler(**config)
66 | pooler.load(model_weights_file)
67 | return pooler
68 |
69 | @staticmethod
70 | def build_pooler(model_args):
71 | pooler = DensePooler(
72 | model_args.projection_in_dim,
73 | model_args.projection_out_dim,
74 | tied=not model_args.untie_encoder
75 | )
76 | pooler.load(model_args.model_name_or_path)
77 | return pooler
--------------------------------------------------------------------------------
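The mean-pooling fallback above zeroes padded positions before averaging; a standalone sketch of the same masking arithmetic:

import torch

hidden = torch.randn(2, 4, 8)                      # (batch, seq, dim)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # attention_mask
reps = hidden.masked_fill(~mask[..., None].bool(), 0.)
reps = reps.sum(dim=1) / mask.sum(dim=1)[..., None]  # (2, 8) mean over real tokens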
/tevatron/src/tevatron/modeling/encoder.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import os
4 | from dataclasses import dataclass
5 | from typing import Dict, Optional
6 |
7 | import torch
8 | from torch import nn, Tensor
9 | import torch.distributed as dist
10 | from transformers import PreTrainedModel, AutoModel
11 | from transformers.file_utils import ModelOutput
12 |
13 | from tevatron.arguments import ModelArguments, \
14 | TevatronTrainingArguments as TrainingArguments
15 |
16 | import logging
17 | import pickle
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | @dataclass
23 | class EncoderOutput(ModelOutput):
24 | q_reps: Optional[Tensor] = None
25 | p_reps: Optional[Tensor] = None
26 | loss: Optional[Tensor] = None
27 | scores: Optional[Tensor] = None
28 |
29 |
30 | class EncoderPooler(nn.Module):
31 | def __init__(self, **kwargs):
32 | super(EncoderPooler, self).__init__()
33 | self._config = {}
34 |
35 | def forward(self, q_reps, p_reps):
36 | raise NotImplementedError('EncoderPooler is an abstract class')
37 |
38 | def load(self, model_dir: str):
39 | pooler_path = os.path.join(model_dir, 'pooler.pt')
40 | if os.path.exists(pooler_path):
41 |     logger.info(f'Loading Pooler from {pooler_path}')
42 |     state_dict = torch.load(pooler_path, map_location='cpu')
43 |     self.load_state_dict(state_dict)
44 |     return
45 | logger.info("Training Pooler from scratch")
46 | return
47 |
48 |
49 | def save_pooler(self, save_path):
50 | torch.save(self.state_dict(), os.path.join(save_path, 'pooler.pt'))
51 | with open(os.path.join(save_path, 'pooler_config.json'), 'w') as f:
52 | json.dump(self._config, f)
53 |
54 |
55 | class EncoderModel(nn.Module):
56 | TRANSFORMER_CLS = AutoModel
57 |
58 | def __init__(self,
59 | lm_q: PreTrainedModel,
60 | lm_p: PreTrainedModel,
61 | pooler: nn.Module = None,
62 | untie_encoder: bool = False,
63 | negatives_x_device: bool = False,
64 | model_args: ModelArguments = None
65 | ):
66 | super().__init__()
67 | self.lm_q = lm_q
68 | self.lm_p = lm_p
69 | self.pooler = pooler
70 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
71 | self.negatives_x_device = negatives_x_device
72 | self.untie_encoder = untie_encoder
73 | if model_args:
74 | self.use_mean_pooling = model_args.use_mean_pooling
75 | else:
76 | logger.info("Mean Pooling Enabled")
77 | self.use_mean_pooling = True
78 | if self.negatives_x_device:
79 | if not dist.is_initialized():
80 | raise ValueError('Distributed training has not been initialized for representation all gather.')
81 | self.process_rank = dist.get_rank()
82 | self.world_size = dist.get_world_size()
83 |
84 | def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None):
85 | q_reps = self.encode_query(query)
86 | p_reps = self.encode_passage(passage)
87 |
88 | # for inference
89 | if q_reps is None or p_reps is None:
90 | return EncoderOutput(
91 | q_reps=q_reps,
92 | p_reps=p_reps
93 | )
94 |
95 | # for training
96 | if self.training:
97 | if self.negatives_x_device:
98 | q_reps = self._dist_gather_tensor(q_reps)
99 | p_reps = self._dist_gather_tensor(p_reps)
100 |
101 | scores = self.compute_similarity(q_reps, p_reps)
102 | scores = scores.view(q_reps.size(0), -1)
103 |
104 | target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)  # one label per query
105 | target = target * (p_reps.size(0) // q_reps.size(0))  # each query's positive sits at stride (passages per query)
106 |
107 | loss = self.compute_loss(scores, target)
108 | if self.negatives_x_device:
109 | loss = loss * self.world_size # counteract the cross-device averaging of the loss
110 | # for eval
111 | else:
112 | scores = self.compute_similarity(q_reps, p_reps)
113 | loss = None
114 | return EncoderOutput(
115 | loss=loss,
116 | scores=scores,
117 | q_reps=q_reps,
118 | p_reps=p_reps,
119 | )
120 |
121 | @staticmethod
122 | def build_pooler(model_args):
123 | return None
124 |
125 | @staticmethod
126 | def load_pooler(weights, **config):
127 | return None
128 |
129 | def encode_passage(self, psg):
130 | raise NotImplementedError('EncoderModel is an abstract class')
131 |
132 | def encode_query(self, qry):
133 | raise NotImplementedError('EncoderModel is an abstract class')
134 |
135 | def compute_similarity(self, q_reps, p_reps):
136 | return torch.matmul(q_reps, p_reps.transpose(0, 1))
137 |
138 | def compute_loss(self, scores, target):
139 | return self.cross_entropy(scores, target)
140 |
141 | def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
142 | if t is None:
143 | return None
144 | t = t.contiguous()
145 |
146 | all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
147 | dist.all_gather(all_tensors, t)
148 |
149 | all_tensors[self.process_rank] = t
150 | all_tensors = torch.cat(all_tensors, dim=0)
151 |
152 | return all_tensors
153 |
154 | @classmethod
155 | def build(
156 | cls,
157 | model_args: ModelArguments,
158 | train_args: TrainingArguments,
159 | **hf_kwargs,
160 | ):
161 | # load local
162 | if os.path.isdir(model_args.model_name_or_path):
163 | if model_args.untie_encoder:
164 | _qry_model_path = os.path.join(model_args.model_name_or_path, 'query_model')
165 | _psg_model_path = os.path.join(model_args.model_name_or_path, 'passage_model')
166 | if not os.path.exists(_qry_model_path):
167 | _qry_model_path = model_args.model_name_or_path
168 | _psg_model_path = model_args.model_name_or_path
169 | logger.info(f'loading query model weight from {_qry_model_path}')
170 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(
171 | _qry_model_path,
172 | **hf_kwargs
173 | )
174 | logger.info(f'loading passage model weight from {_psg_model_path}')
175 | lm_p = cls.TRANSFORMER_CLS.from_pretrained(
176 | _psg_model_path,
177 | **hf_kwargs
178 | )
179 | else:
180 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
181 | lm_p = lm_q
182 | # load pre-trained
183 | else:
184 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
185 | lm_p = copy.deepcopy(lm_q) if model_args.untie_encoder else lm_q
186 |
187 | if model_args.add_pooler:
188 | pooler = cls.build_pooler(model_args)
189 | else:
190 | pooler = None
191 |
192 | model = cls(
193 | lm_q=lm_q,
194 | lm_p=lm_p,
195 | pooler=pooler,
196 | negatives_x_device=train_args.negatives_x_device,
197 | untie_encoder=model_args.untie_encoder,
198 | model_args=model_args
199 | )
200 | return model
201 |
202 | @classmethod
203 | def load(
204 | cls,
205 | model_name_or_path,
206 | **hf_kwargs,
207 | ):
208 | # load local
209 | untie_encoder = True
210 | if os.path.isdir(model_name_or_path):
211 | _qry_model_path = os.path.join(model_name_or_path, 'query_model')
212 | _psg_model_path = os.path.join(model_name_or_path, 'passage_model')
213 | if os.path.exists(_qry_model_path):
214 |             logger.info('found separate weights for query/passage encoders')
215 | logger.info(f'loading query model weight from {_qry_model_path}')
216 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(
217 | _qry_model_path,
218 | **hf_kwargs
219 | )
220 | logger.info(f'loading passage model weight from {_psg_model_path}')
221 | lm_p = cls.TRANSFORMER_CLS.from_pretrained(
222 | _psg_model_path,
223 | **hf_kwargs
224 | )
225 | untie_encoder = False
226 | else:
227 |             logger.info('trying to load tied weights')
228 | logger.info(f'loading model weight from {model_name_or_path}')
229 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs)
230 | lm_p = lm_q
231 | else:
232 |         logger.info('trying to load tied weights')
233 | logger.info(f'loading model weight from {model_name_or_path}')
234 | lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs)
235 | lm_p = lm_q
236 |
237 | pooler_weights = os.path.join(model_name_or_path, 'pooler.pt')
238 | pooler_config = os.path.join(model_name_or_path, 'pooler_config.json')
239 | if os.path.exists(pooler_weights) and os.path.exists(pooler_config):
240 | logger.info(f'found pooler weight and configuration')
241 | with open(pooler_config) as f:
242 | pooler_config_dict = json.load(f)
243 | pooler = cls.load_pooler(model_name_or_path, **pooler_config_dict)
244 | else:
245 | pooler = None
246 |
247 | model_args_path = os.path.join(model_name_or_path, 'model_args.bin')
248 | model_args = None
249 | if os.path.exists(model_args_path):
250 | model_args = pickle.load(open(model_args_path, "rb"))
251 |
252 | model = cls(
253 | lm_q=lm_q,
254 | lm_p=lm_p,
255 | pooler=pooler,
256 | untie_encoder=untie_encoder,
257 |             model_args=model_args
258 | )
259 | return model
260 |
261 | def save(self, output_dir: str):
262 | if self.untie_encoder:
263 |             os.makedirs(os.path.join(output_dir, 'query_model'), exist_ok=True)
264 |             os.makedirs(os.path.join(output_dir, 'passage_model'), exist_ok=True)
265 | self.lm_q.save_pretrained(os.path.join(output_dir, 'query_model'))
266 | self.lm_p.save_pretrained(os.path.join(output_dir, 'passage_model'))
267 | else:
268 | self.lm_q.save_pretrained(output_dir)
269 | if self.pooler:
270 | self.pooler.save_pooler(output_dir)
--------------------------------------------------------------------------------
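The training branch of `EncoderModel.forward` above pairs each query with a positive passage sitting at a fixed stride in the flattened passage batch: with `n` queries and `m` passages per query, query `i`'s positive is at column `i * m` of the score matrix. A minimal standalone sketch of that target construction (toy tensors, not repository code):

```python
import torch

# Toy setup: 2 queries, 2 passages per query (1 positive + 1 negative),
# mirroring the in-batch negative loss in EncoderModel.forward.
q_reps = torch.randn(2, 8)                              # (n_queries, dim)
p_reps = torch.randn(4, 8)                              # (n_queries * m, dim)

scores = torch.matmul(q_reps, p_reps.transpose(0, 1))   # (2, 4) similarities
target = torch.arange(scores.size(0), dtype=torch.long)
target = target * (p_reps.size(0) // q_reps.size(0))    # positives at columns 0 and 2

loss = torch.nn.CrossEntropyLoss(reduction='mean')(scores, target)
print(target.tolist())  # [0, 2]
```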
/tevatron/src/tevatron/modeling/splade.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | from transformers import AutoModelForMaskedLM
4 | from .encoder import EncoderModel
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class SpladeModel(EncoderModel):
10 | TRANSFORMER_CLS = AutoModelForMaskedLM
11 |
12 | def encode_passage(self, psg):
13 | if psg is None:
14 | return None
15 | psg_out = self.lm_p(**psg, return_dict=True).logits
16 | aggregated_psg_out, _ = torch.max(torch.log(1 + torch.relu(psg_out)) * psg['attention_mask'].unsqueeze(-1), dim=1)
17 | return aggregated_psg_out
18 |
19 | def encode_query(self, qry):
20 | if qry is None:
21 | return None
22 | qry_out = self.lm_q(**qry, return_dict=True).logits
23 |         aggregated_qry_out, _ = torch.max(torch.log(1 + torch.relu(qry_out)) * qry['attention_mask'].unsqueeze(-1), dim=1)
24 |         return aggregated_qry_out
25 |
--------------------------------------------------------------------------------
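The aggregation in `SpladeModel` above turns per-token MLM logits into one sparse, vocabulary-sized vector: `log(1 + relu(logits))`, masked by attention, then max-pooled over token positions. A toy sketch with made-up shapes:

```python
import torch

# Made-up shapes: batch of 2 sequences, 5 tokens each, vocab size 10.
logits = torch.randn(2, 5, 10)          # MLM logits per token position
attention_mask = torch.ones(2, 5)
attention_mask[1, 3:] = 0               # pretend the second sequence is padded

weights = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
reps, _ = torch.max(weights, dim=1)     # max-pool over the token dimension
print(reps.shape)                       # torch.Size([2, 10]): one sparse vector per input
```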
/tevatron/src/tevatron/modeling/unicoil.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, nn
3 | import logging
4 |
5 | from .encoder import EncoderPooler, EncoderModel
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class UniCoilPooler(EncoderPooler):
11 | def __init__(self, input_dim: int = 768, tied=True):
12 | super(UniCoilPooler, self).__init__()
13 | self.linear_q = nn.Linear(input_dim, 1)
14 | if tied:
15 | self.linear_p = self.linear_q
16 | else:
17 | self.linear_p = nn.Linear(input_dim, 1)
18 | self._config = {'input_dim': input_dim, 'tied': tied}
19 |
20 | def forward(self, q: Tensor = None, p: Tensor = None):
21 | if q is not None:
22 | return self.linear_q(q)
23 | elif p is not None:
24 | return self.linear_p(p)
25 | else:
26 | raise ValueError
27 |
28 |
29 | class UniCoilModel(EncoderModel):
30 | def encode_passage(self, psg):
31 | if psg is None:
32 | return None
33 | psg_out = self.lm_p(**psg, return_dict=True)
34 | p_hidden = psg_out.last_hidden_state
35 | p_reps = self.pooler(p=p_hidden)
36 | return self._weights_to_vec(psg['input_ids'], p_reps)
37 |
38 | def encode_query(self, qry):
39 | if qry is None:
40 | return None
41 | qry_out = self.lm_q(**qry, return_dict=True)
42 | q_hidden = qry_out.last_hidden_state
43 | q_reps = self.pooler(q=q_hidden)
44 | return self._weights_to_vec(qry['input_ids'], q_reps)
45 |
46 | def compute_similarity(self, q_reps, p_reps):
47 | return torch.matmul(q_reps, p_reps.transpose(0, 1))
48 |
49 | def _weights_to_vec(self, input_ids, tok_weights):
50 | input_shape = input_ids.size()
51 | tok_weights = torch.relu(tok_weights)
52 | tok_emb = torch.zeros(input_shape[0], input_shape[1], self.lm_p.config.vocab_size, dtype=tok_weights.dtype,
53 | device=input_ids.device)
54 | tok_emb = torch.scatter(tok_emb, dim=-1, index=input_ids.unsqueeze(-1), src=tok_weights)
55 | disabled_token_ids = [0, 101, 102, 103] # hard code for bert for now, can pass in a tokenizer in the future
56 | tok_emb = torch.max(tok_emb, dim=1).values
57 | tok_emb[:, disabled_token_ids] *= 0
58 | return tok_emb
59 |
60 | @staticmethod
61 | def build_pooler(model_args):
62 | pooler = UniCoilPooler(
63 | model_args.projection_in_dim,
64 | tied=not model_args.untie_encoder
65 | )
66 | pooler.load(model_args.model_name_or_path)
67 | return pooler
68 |
69 | @staticmethod
70 | def load_pooler(model_weights_file, **config):
71 | pooler = UniCoilPooler(**config)
72 | pooler.load(model_weights_file)
73 | return pooler
74 |
--------------------------------------------------------------------------------
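`UniCoilModel._weights_to_vec` above scatters each token's scalar weight into a vocabulary-sized vector and keeps the maximum weight per token id. A toy re-run of that logic with a tiny vocabulary (sizes and ids are made up):

```python
import torch

vocab_size = 12
input_ids = torch.tensor([[5, 7, 5, 0]])   # token 5 repeats; 0 stands in for [PAD]
tok_weights = torch.tensor([[[0.4], [1.2], [0.9], [0.3]]])

tok_weights = torch.relu(tok_weights)
tok_emb = torch.zeros(1, input_ids.size(1), vocab_size, dtype=tok_weights.dtype)
tok_emb = torch.scatter(tok_emb, dim=-1, index=input_ids.unsqueeze(-1), src=tok_weights)
vec = torch.max(tok_emb, dim=1).values     # keep the max weight per vocabulary entry
vec[:, [0]] *= 0                           # zero out special tokens, as in the model
print(round(vec[0, 5].item(), 2), round(vec[0, 7].item(), 2))  # 0.9 (max of 0.4, 0.9) and 1.2
```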
/tevatron/src/tevatron/preprocessor/__init__.py:
--------------------------------------------------------------------------------
1 | from .preprocessor_tsv import SimpleTrainPreProcessor as MarcoPassageTrainPreProcessor, \
2 | SimpleCollectionPreProcessor as MarcoPassageCollectionPreProcessor
--------------------------------------------------------------------------------
/tevatron/src/tevatron/preprocessor/normalize_text.py:
--------------------------------------------------------------------------------
1 | #: Control characters.
2 | CONTROLS = {
3 | '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
4 | '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
5 | }
6 | # There are further control characters, but they are instead replaced with a space by unicode normalization
7 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f'
8 |
9 |
10 | #: Hyphen and dash characters.
11 | HYPHENS = {
12 | '-', # \u002d Hyphen-minus
13 | '‐', # \u2010 Hyphen
14 | '‑', # \u2011 Non-breaking hyphen
15 | '⁃', # \u2043 Hyphen bullet
16 | '‒', # \u2012 figure dash
17 | '–', # \u2013 en dash
18 | '—', # \u2014 em dash
19 | '―', # \u2015 horizontal bar
20 | }
21 |
22 | #: Minus characters.
23 | MINUSES = {
24 | '-', # \u002d Hyphen-minus
25 | '−', # \u2212 Minus
26 |     '－', # \uff0d Full-width Hyphen-minus
27 | '⁻', # \u207b Superscript minus
28 | }
29 |
30 | #: Plus characters.
31 | PLUSES = {
32 | '+', # \u002b Plus
33 |     '＋', # \uff0b Full-width Plus
34 | '⁺', # \u207a Superscript plus
35 | }
36 |
37 | #: Slash characters.
38 | SLASHES = {
39 | '/', # \u002f Solidus
40 | '⁄', # \u2044 Fraction slash
41 | '∕', # \u2215 Division slash
42 | }
43 |
44 | #: Tilde characters.
45 | TILDES = {
46 | '~', # \u007e Tilde
47 | '˜', # \u02dc Small tilde
48 | '⁓', # \u2053 Swung dash
49 | '∼', # \u223c Tilde operator #in mbert vocab
50 | '∽', # \u223d Reversed tilde
51 | '∿', # \u223f Sine wave
52 | '〜', # \u301c Wave dash #in mbert vocab
53 |     '～', # \uff5e Full-width tilde #in mbert vocab
54 | }
55 |
56 | #: Apostrophe characters.
57 | APOSTROPHES = {
58 | "'", # \u0027
59 | '’', # \u2019
60 | '՚', # \u055a
61 | 'Ꞌ', # \ua78b
62 | 'ꞌ', # \ua78c
63 |     '＇', # \uff07 Full-width apostrophe
64 | }
65 |
66 | #: Single quote characters.
67 | SINGLE_QUOTES = {
68 | "'", # \u0027
69 | '‘', # \u2018
70 | '’', # \u2019
71 | '‚', # \u201a
72 | '‛', # \u201b
73 |
74 | }
75 |
76 | #: Double quote characters.
77 | DOUBLE_QUOTES = {
78 | '"', # \u0022
79 | '“', # \u201c
80 | '”', # \u201d
81 | '„', # \u201e
82 | '‟', # \u201f
83 | }
84 |
85 | #: Accent characters.
86 | ACCENTS = {
87 | '`', # \u0060
88 | '´', # \u00b4
89 | }
90 |
91 | #: Prime characters.
92 | PRIMES = {
93 | '′', # \u2032
94 | '″', # \u2033
95 | '‴', # \u2034
96 | '‵', # \u2035
97 | '‶', # \u2036
98 | '‷', # \u2037
99 | '⁗', # \u2057
100 | }
101 |
102 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
103 | QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES
104 |
105 | def normalize(text):
106 | for control in CONTROLS:
107 | text = text.replace(control, '')
108 | text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')
109 |
110 | for hyphen in HYPHENS | MINUSES:
111 | text = text.replace(hyphen, '-')
112 | text = text.replace('\u00ad', '')
113 |
114 | for double_quote in DOUBLE_QUOTES:
115 | text = text.replace(double_quote, '"') # \u0022
116 | for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
117 | text = text.replace(single_quote, "'") # \u0027
118 | text = text.replace('′', "'") # \u2032 prime
119 | text = text.replace('‵', "'") # \u2035 reversed prime
120 | text = text.replace('″', "''") # \u2033 double prime
121 | text = text.replace('‶', "''") # \u2036 reversed double prime
122 | text = text.replace('‴', "'''") # \u2034 triple prime
123 | text = text.replace('‷', "'''") # \u2037 reversed triple prime
124 | text = text.replace('⁗', "''''") # \u2057 quadruple prime
125 |
126 | text = text.replace('…', '...').replace(' . . . ', ' ... ') # \u2026
127 |
128 | for slash in SLASHES:
129 | text = text.replace(slash, '/')
130 |
131 | #for tilde in TILDES:
132 | # text = text.replace(tilde, '~')
133 |
134 | return text
--------------------------------------------------------------------------------
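A quick illustration of what `normalize` above does, assuming the package is installed editable as described in the repository README: unicode punctuation variants collapse to their ASCII counterparts, while tildes are deliberately left untouched (that loop is commented out).

```python
from tevatron.preprocessor.normalize_text import normalize

# Curly quotes, dashes and the ellipsis collapse to plain ASCII.
print(normalize('“smart quotes” — em dash… and a ’curly’ apostrophe'))
# "smart quotes" - em dash... and a 'curly' apostrophe
```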
/tevatron/src/tevatron/preprocessor/preprocessor_tsv.py:
--------------------------------------------------------------------------------
1 | import json
2 | import csv
3 | import datasets
4 | from transformers import PreTrainedTokenizer
5 | from dataclasses import dataclass
6 | from .normalize_text import normalize
7 |
8 | @dataclass
9 | class SimpleTrainPreProcessor:
10 | query_file: str
11 | collection_file: str
12 | tokenizer: PreTrainedTokenizer
13 |
14 | max_length: int = 128
15 | columns = ['text_id', 'title', 'text']
16 | title_field = 'title'
17 | text_field = 'text'
18 |
19 | def __post_init__(self):
20 | self.queries = self.read_queries(self.query_file)
21 | self.collection = datasets.load_dataset(
22 | 'csv',
23 | data_files=self.collection_file,
24 | column_names=self.columns,
25 | delimiter='\t',
26 | )['train']
27 |
28 | @staticmethod
29 | def read_queries(queries):
30 | qmap = {}
31 | with open(queries) as f:
32 | for l in f:
33 | qid, qry = l.strip().split('\t')
34 | qmap[qid] = qry
35 | return qmap
36 |
37 | @staticmethod
38 | def read_qrel(relevance_file):
39 | qrel = {}
40 | with open(relevance_file, encoding='utf8') as f:
41 | tsvreader = csv.reader(f, delimiter="\t")
42 | for [topicid, _, docid, rel] in tsvreader:
43 | assert rel == "1"
44 | if topicid in qrel:
45 | qrel[topicid].append(docid)
46 | else:
47 | qrel[topicid] = [docid]
48 | return qrel
49 |
50 | def get_query(self, q):
51 | query_encoded = self.tokenizer.encode(
52 | self.queries[q],
53 | add_special_tokens=False,
54 | max_length=self.max_length,
55 | truncation=True
56 | )
57 | return query_encoded
58 |
59 | def get_passage(self, p):
60 | entry = self.collection[int(p)]
61 | title = entry[self.title_field]
62 | title = "" if title is None else title
63 | body = entry[self.text_field]
64 | content = title + self.tokenizer.sep_token + body
65 |
66 | passage_encoded = self.tokenizer.encode(
67 | content,
68 | add_special_tokens=False,
69 | max_length=self.max_length,
70 | truncation=True
71 | )
72 |
73 | return passage_encoded
74 |
75 | def process_one(self, train):
76 | q, pp, nn = train
77 | train_example = {
78 | 'query': self.get_query(q),
79 | 'positives': [self.get_passage(p) for p in pp],
80 | 'negatives': [self.get_passage(n) for n in nn],
81 | }
82 |
83 | return json.dumps(train_example)
84 |
85 |
86 | @dataclass
87 | class SimpleCollectionPreProcessor:
88 | tokenizer: PreTrainedTokenizer
89 | separator: str = '\t'
90 | max_length: int = 128
91 | is_query: bool = False
92 | text_first: bool = False
93 | lower_case: bool = False
94 | normalize_text: bool = False
95 |
96 | def process_line(self, line: str):
97 | xx = line.strip().split(self.separator)
98 | text_id, text = xx[0], xx[1:]
99 | if len(text) == 1 and (not self.is_query):
100 | text.append("")
101 | if self.text_first:
102 | text = text[::-1]
103 | for i in range(0, len(text)):
104 | if self.normalize_text:
105 | text[i] = normalize(text[i])
106 | if self.lower_case:
107 | text[i] = text[i].lower()
108 | text_encoded = self.tokenizer.encode(
109 | self.tokenizer.sep_token.join(text),
110 | add_special_tokens=False,
111 | max_length=self.max_length,
112 | truncation=True
113 | )
114 | encoded = {
115 | 'text_id': text_id,
116 | 'text': text_encoded
117 | }
118 | return json.dumps(encoded)
119 |
--------------------------------------------------------------------------------
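A minimal usage sketch for `SimpleCollectionPreProcessor.process_line` above (exported from the package as `MarcoPassageCollectionPreProcessor`); the tokenizer name is only an example and any pretrained tokenizer with a `sep_token` would do:

```python
from transformers import AutoTokenizer
from tevatron.preprocessor import MarcoPassageCollectionPreProcessor

# Turn one TSV collection line (id \t title \t text) into the tokenized
# JSON record the rest of the pipeline consumes.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
proc = MarcoPassageCollectionPreProcessor(tokenizer=tokenizer, max_length=128)
print(proc.process_line('0\tExample Title\tExample passage text.'))
# {"text_id": "0", "text": [...BERT token ids...]}
```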
/tevatron/src/tevatron/tevax/__init__.py:
--------------------------------------------------------------------------------
1 | from .training import TiedParams, DualParams, RetrieverTrainState, retriever_train_step
2 |
--------------------------------------------------------------------------------
/tevatron/src/tevatron/tevax/loss.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import lax
3 | import optax
4 | import chex
5 |
6 |
7 | def _onehot(labels: chex.Array, num_classes: int) -> chex.Array:
8 | x = labels[..., None] == jnp.arange(num_classes).reshape((1,) * labels.ndim + (-1,))
9 | x = lax.select(x, jnp.ones(x.shape), jnp.zeros(x.shape))
10 | return x.astype(jnp.float32)
11 |
12 |
13 | def p_contrastive_loss(ss: chex.Array, tt: chex.Array, axis: str = 'device') -> chex.Array:
14 | per_shard_targets = tt.shape[0]
15 | per_sample_targets = int(tt.shape[0] / ss.shape[0])
16 | labels = jnp.arange(0, per_shard_targets, per_sample_targets) + per_shard_targets * lax.axis_index(axis)
17 |
18 | tt = lax.all_gather(tt, axis).reshape((-1, ss.shape[-1]))
19 | scores = jnp.dot(ss, jnp.transpose(tt))
20 |
21 | return optax.softmax_cross_entropy(scores, _onehot(labels, scores.shape[-1]))
22 |
--------------------------------------------------------------------------------
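`p_contrastive_loss` above assumes it runs inside a `pmap` that supplies the named axis, so passages can be gathered across devices while each shard keeps its local queries; the positive for local query `i` on shard `r` lands at index `r * per_shard_targets + i * per_sample_targets`. A single-host sanity check (a sketch, assuming jax, optax and chex are installed):

```python
import jax
import jax.numpy as jnp
from tevatron.tevax.loss import p_contrastive_loss

# pmap over the locally visible devices supplies the 'device' axis that
# p_contrastive_loss needs for lax.all_gather / lax.axis_index.
n_dev = jax.local_device_count()
ss = jnp.ones((n_dev, 2, 4))   # per device: 2 query reps of dim 4
tt = jnp.ones((n_dev, 4, 4))   # per device: 4 passage reps (2 per query)

loss = jax.pmap(p_contrastive_loss, axis_name='device')(ss, tt)
print(loss.shape)              # (n_dev, 2): one loss value per local query
```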
/tevatron/src/tevatron/tevax/training.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Tuple, Any, Union
3 |
4 | import jax
5 | from jax import numpy as jnp
6 |
7 | from flax.training.train_state import TrainState
8 | from flax.core import FrozenDict
9 | from flax.struct import PyTreeNode
10 |
11 | from .loss import p_contrastive_loss
12 |
13 |
14 | class TiedParams(PyTreeNode):
15 | params: FrozenDict[str, Any]
16 |
17 | @property
18 | def q_params(self):
19 | return self.params
20 |
21 | @property
22 | def p_params(self):
23 | return self.params
24 |
25 | @classmethod
26 | def create(cls, params):
27 | return cls(params=params)
28 |
29 |
30 | class DualParams(PyTreeNode):
31 | params: Tuple[FrozenDict[str, Any], FrozenDict[str, Any]]
32 |
33 | @property
34 | def q_params(self):
35 | return self.params[0]
36 |
37 | @property
38 | def p_params(self):
39 | return self.params[1]
40 |
41 | @classmethod
42 | def create(cls, *ps):
43 | if len(ps) == 1:
44 | return cls(params=ps*2)
45 | else:
46 |             q_params, p_params = ps
47 |             return cls(params=[q_params, p_params])
48 |
49 |
50 | class RetrieverTrainState(TrainState):
51 | params: Union[TiedParams, DualParams]
52 |
53 |
54 | def retriever_train_step(state, queries, passages, dropout_rng, axis='device'):
55 | q_dropout_rng, p_dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, 3)
56 |
57 | def compute_loss(params):
58 | q_reps = state.apply_fn(**queries, params=params.q_params, dropout_rng=q_dropout_rng, train=True)[0][:, 0, :]
59 | p_reps = state.apply_fn(**passages, params=params.p_params, dropout_rng=p_dropout_rng, train=True)[0][:, 0, :]
60 | return jnp.mean(p_contrastive_loss(q_reps, p_reps, axis=axis))
61 |
62 | loss, grad = jax.value_and_grad(compute_loss)(state.params)
63 | loss, grad = jax.lax.pmean([loss, grad], axis)
64 |
65 | new_state = state.apply_gradients(grads=grad)
66 |
67 | return loss, new_state, new_dropout_rng
68 |
69 |
70 | def grad_cache_train_step(state, queries, passages, dropout_rng, axis='device', q_n_subbatch=1, p_n_subbatch=1):
71 | try:
72 | from grad_cache import cachex
73 | except ImportError:
74 |         raise ModuleNotFoundError('GradCache package needs to be installed for running grad_cache_train_step')
75 |
76 | def encode_query(params, **kwargs):
77 | return state.apply_fn(**kwargs, params=params.q_params, train=True)[0][:, 0, :]
78 |
79 | def encode_passage(params, **kwargs):
80 | return state.apply_fn(**kwargs, params=params.p_params, train=True)[0][:, 0, :]
81 |
82 | queries, passages = cachex.tree_chunk(queries, q_n_subbatch), cachex.tree_chunk(passages, p_n_subbatch)
83 | q_rngs, p_rngs, new_rng = jax.random.split(dropout_rng, 3)
84 | q_rngs = jax.random.split(q_rngs, q_n_subbatch)
85 | p_rngs = jax.random.split(p_rngs, p_n_subbatch)
86 |
87 | q_reps = cachex.chunk_encode(partial(encode_query, state.params))(**queries, dropout_rng=q_rngs)
88 | p_reps = cachex.chunk_encode(partial(encode_passage, state.params))(**passages, dropout_rng=p_rngs)
89 |
90 | @cachex.unchunk_args(axis=0, argnums=(0, 1))
91 | def compute_loss(xx, yy):
92 | return jnp.mean(p_contrastive_loss(xx, yy, axis=axis))
93 |
94 | loss, (q_grads, p_grads) = jax.value_and_grad(compute_loss, argnums=(0, 1))(q_reps, p_reps)
95 |
96 | grads = jax.tree_map(lambda v: jnp.zeros_like(v), state.params)
97 | grads = cachex.cache_grad(encode_query)(state.params, grads, q_grads, **queries, dropout_rng=q_rngs)
98 | grads = cachex.cache_grad(encode_passage)(state.params, grads, p_grads, **passages, dropout_rng=p_rngs)
99 |
100 | loss, grads = jax.lax.pmean([loss, grads], axis)
101 | new_state = state.apply_gradients(grads=grads)
102 | return loss, new_state, new_rng
103 |
--------------------------------------------------------------------------------
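A small sketch of the two parameter containers defined above: `TiedParams` shares a single tree between the query and passage encoders, while `DualParams` can hold two independent trees (flax must be installed; the tree contents here are made up):

```python
from flax.core import freeze
from tevatron.tevax import TiedParams, DualParams

params = freeze({'dense': {'kernel': [[1.0]]}})

tied = TiedParams.create(params)       # one tree shared by both encoders
assert tied.q_params is tied.p_params

dual = DualParams.create(params)       # a single tree is duplicated
assert dual.q_params is dual.p_params  # still the same underlying object
```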
/tevatron/src/tevatron/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import repeat
3 | from typing import Dict, List, Tuple, Optional, Any, Union
4 |
5 | from transformers.trainer import Trainer
6 |
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import torch.distributed as dist
10 |
11 | from .loss import SimpleContrastiveLoss, DistributedContrastiveLoss
12 |
13 | import logging
14 | logger = logging.getLogger(__name__)
15 |
16 | try:
17 | from grad_cache import GradCache
18 | _grad_cache_available = True
19 | except ModuleNotFoundError:
20 | _grad_cache_available = False
21 |
22 |
23 | class TevatronTrainer(Trainer):
24 | def __init__(self, *args, **kwargs):
25 | super(TevatronTrainer, self).__init__(*args, **kwargs)
26 | self._dist_loss_scale_factor = dist.get_world_size() if self.args.negatives_x_device else 1
27 |
28 | def _save(self, output_dir: Optional[str] = None):
29 | output_dir = output_dir if output_dir is not None else self.args.output_dir
30 | os.makedirs(output_dir, exist_ok=True)
31 | logger.info("Saving model checkpoint to %s", output_dir)
32 | self.model.save(output_dir)
33 |
34 | def _prepare_inputs(
35 | self,
36 | inputs: Tuple[Dict[str, Union[torch.Tensor, Any]], ...]
37 | ) -> List[Dict[str, Union[torch.Tensor, Any]]]:
38 | prepared = []
39 | for x in inputs:
40 | if isinstance(x, torch.Tensor):
41 | prepared.append(x.to(self.args.device))
42 | else:
43 | prepared.append(super()._prepare_inputs(x))
44 | return prepared
45 |
46 | def get_train_dataloader(self) -> DataLoader:
47 | if self.train_dataset is None:
48 | raise ValueError("Trainer: training requires a train_dataset.")
49 | train_sampler = self._get_train_sampler()
50 |
51 | return DataLoader(
52 | self.train_dataset,
53 | batch_size=self.args.train_batch_size,
54 | sampler=train_sampler,
55 | collate_fn=self.data_collator,
56 | drop_last=True,
57 | num_workers=self.args.dataloader_num_workers,
58 | )
59 |
60 | def compute_loss(self, model, inputs):
61 | query, passage = inputs
62 | return model(query=query, passage=passage).loss
63 |
64 | def training_step(self, *args):
65 | return super(TevatronTrainer, self).training_step(*args) / self._dist_loss_scale_factor
66 |
67 |
68 | def split_dense_inputs(model_input: dict, chunk_size: int):
69 | assert len(model_input) == 1
70 | arg_key = list(model_input.keys())[0]
71 | arg_val = model_input[arg_key]
72 |
73 | keys = list(arg_val.keys())
74 | chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys]
75 | chunked_arg_val = [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]
76 |
77 | return [{arg_key: c} for c in chunked_arg_val]
78 |
79 |
80 | def get_dense_rep(x):
81 | if x.q_reps is None:
82 | return x.p_reps
83 | else:
84 | return x.q_reps
85 |
86 |
87 | class GCTrainer(TevatronTrainer):
88 | def __init__(self, *args, **kwargs):
89 | logger.info('Initializing Gradient Cache Trainer')
90 | if not _grad_cache_available:
91 | raise ValueError(
92 | 'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.')
93 | super(GCTrainer, self).__init__(*args, **kwargs)
94 |
95 | loss_fn_cls = DistributedContrastiveLoss if self.args.negatives_x_device else SimpleContrastiveLoss
96 | loss_fn = loss_fn_cls()
97 |
98 | self.gc = GradCache(
99 | models=[self.model, self.model],
100 | chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size],
101 | loss_fn=loss_fn,
102 | split_input_fn=split_dense_inputs,
103 | get_rep_fn=get_dense_rep,
104 | fp16=self.args.fp16,
105 | scaler=self.scaler if self.args.fp16 else None
106 | )
107 |
108 | def training_step(self, model, inputs) -> torch.Tensor:
109 | model.train()
110 | queries, passages = self._prepare_inputs(inputs)
111 | queries, passages = {'query': queries}, {'passage': passages}
112 |
113 | _distributed = self.args.local_rank > -1
114 | self.gc.models = [model, model]
115 | loss = self.gc(queries, passages, no_sync_except_last=_distributed)
116 |
117 | return loss / self._dist_loss_scale_factor
118 |
--------------------------------------------------------------------------------
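`split_dense_inputs` above is the splitting hook handed to GradCache: it expects a single top-level key ('query' or 'passage') and slices every tensor under it along the batch dimension. A toy run, assuming the tevatron install from the README:

```python
import torch
from tevatron.trainer import split_dense_inputs

batch = {'query': {'input_ids': torch.arange(8).view(4, 2),
                   'attention_mask': torch.ones(4, 2, dtype=torch.long)}}

chunks = split_dense_inputs(batch, chunk_size=2)
print(len(chunks))                         # 2 chunks of batch size 2
print(chunks[0]['query']['input_ids'])     # tensor([[0, 1], [2, 3]])
```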
/tevatron/src/tevatron/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import evaluation
--------------------------------------------------------------------------------
/tevatron/src/tevatron/utils/convert_from_dpr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 |
5 | from transformers import AutoConfig, AutoTokenizer
6 |
7 | def main():
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--dpr_model', required=True)
10 | parser.add_argument('--save_to', required=True)
11 | args = parser.parse_args()
12 |
13 | dpr_model_ckpt = torch.load(args.dpr_model, map_location='cpu')
14 | config_name = dpr_model_ckpt['encoder_params']['pretrained_model_cfg']
15 | dpr_model_dict = dpr_model_ckpt['model_dict']
16 |
17 | AutoConfig.from_pretrained(config_name).save_pretrained(args.save_to)
18 | AutoTokenizer.from_pretrained(config_name).save_pretrained(args.save_to)
19 |
20 | question_keys = [k for k in dpr_model_dict.keys() if k.startswith('question_model')]
21 | ctx_keys = [k for k in dpr_model_dict.keys() if k.startswith('ctx_model')]
22 |
23 | question_dict = dict([(k[len('question_model')+1:], dpr_model_dict[k]) for k in question_keys])
24 | ctx_dict = dict([(k[len('ctx_model')+1:], dpr_model_dict[k]) for k in ctx_keys])
25 |
26 | os.makedirs(os.path.join(args.save_to, 'query_model'), exist_ok=True)
27 | os.makedirs(os.path.join(args.save_to, 'passage_model'), exist_ok=True)
28 | torch.save(question_dict, os.path.join(args.save_to, 'query_model', 'pytorch_model.bin'))
29 | torch.save(ctx_dict, os.path.join(args.save_to, 'passage_model', 'pytorch_model.bin'))
30 |
31 |
32 | if __name__ == '__main__':
33 | main()
--------------------------------------------------------------------------------
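The converter above splits a DPR bi-encoder checkpoint into the query_model/passage_model layout that `EncoderModel.load` looks for. It should be invocable as, e.g., `python -m tevatron.utils.convert_from_dpr --dpr_model <checkpoint> --save_to <output_dir>`, where both paths are placeholders for your own files.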
/tevatron/src/tevatron/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import collections
9 | import logging
10 | import regex
11 | import string
12 | import unicodedata
13 | from functools import partial
14 | from multiprocessing import Pool as ProcessPool
15 | from typing import Tuple, List, Dict
16 | import numpy as np
17 |
18 | """
19 | Evaluation code from DPR: https://github.com/facebookresearch/DPR
20 | """
21 |
22 | class SimpleTokenizer(object):
23 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
24 | NON_WS = r'[^\p{Z}\p{C}]'
25 |
26 | def __init__(self):
27 |         """
28 |         Regex-based tokenizer that splits text into alphanumeric tokens
29 |         and single non-whitespace symbols; no annotators are used.
30 |         """
31 | self._regexp = regex.compile(
32 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
33 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
34 | )
35 |
36 | def tokenize(self, text, uncased=False):
37 | matches = [m for m in self._regexp.finditer(text)]
38 | if uncased:
39 | tokens = [m.group().lower() for m in matches]
40 | else:
41 | tokens = [m.group() for m in matches]
42 | return tokens
43 |
44 | logger = logging.getLogger(__name__)
45 |
46 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits'])
47 |
48 | def calculate_matches(data: List, workers_num: int):
49 | """
50 |     Evaluates answer presence in the retrieved documents. This function is meant to be used with a large
51 |     collection of documents and results. It internally forks multiple sub-processes for evaluation and then
52 |     merges the results.
53 |     :param data: list of question records; each record is a dict holding the question's 'answers' list and a
54 |     'ctxs' list of retrieved documents, where each document is a dict carrying the passage text
55 |     :param workers_num: number of parallel processes used to validate the data
56 |     Answer matching is token-level exact match; refer to the has_answer code for details.
57 | :return: matching information tuple.
58 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of
59 | valid matches across an entire dataset.
60 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document
61 | """
62 |
63 | logger.info('Matching answers in top docs...')
64 |
65 | tokenizer = SimpleTokenizer()
66 | get_score_partial = partial(check_answer, tokenizer=tokenizer)
67 |
68 | processes = ProcessPool(processes=workers_num)
69 | scores = processes.map(get_score_partial, data)
70 |
71 | logger.info('Per question validation results len=%d', len(scores))
72 |
73 | n_docs = len(data[0]['ctxs'])
74 | top_k_hits = [0] * n_docs
75 | for question_hits in scores:
76 | best_hit = next((i for i, x in enumerate(question_hits) if x), None)
77 | if best_hit is not None:
78 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
79 |
80 | return QAMatchStats(top_k_hits, scores)
81 |
82 | def check_answer(example, tokenizer) -> List[bool]:
83 | """Search through all the top docs to see if they have any of the answers."""
84 | answers = example['answers']
85 | ctxs = example['ctxs']
86 |
87 | hits = []
88 |
89 | for i, doc in enumerate(ctxs):
90 | text = doc['text']
91 |
92 | if text is None: # cannot find the document for some reason
93 | logger.warning("no doc in db")
94 | hits.append(False)
95 | continue
96 |
97 | hits.append(has_answer(answers, text, tokenizer))
98 |
99 | return hits
100 |
101 | def has_answer(answers, text, tokenizer) -> bool:
102 | """Check if a document contains an answer string."""
103 | text = _normalize(text)
104 | text = tokenizer.tokenize(text, uncased=True)
105 |
106 | for answer in answers:
107 | answer = _normalize(answer)
108 | answer = tokenizer.tokenize(answer, uncased=True)
109 | for i in range(0, len(text) - len(answer) + 1):
110 | if answer == text[i: i + len(answer)]:
111 | return True
112 | return False
113 |
114 | #################################################
115 | ######## READER EVALUATION ########
116 | #################################################
117 |
118 | def _normalize(text):
119 | return unicodedata.normalize('NFD', text)
120 |
121 | #Normalization from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
122 | def normalize_answer(s):
123 | def remove_articles(text):
124 | return regex.sub(r'\b(a|an|the)\b', ' ', text)
125 |
126 | def white_space_fix(text):
127 | return ' '.join(text.split())
128 |
129 | def remove_punc(text):
130 | exclude = set(string.punctuation)
131 | return ''.join(ch for ch in text if ch not in exclude)
132 |
133 | def lower(text):
134 | return text.lower()
135 |
136 | return white_space_fix(remove_articles(remove_punc(lower(s))))
137 |
138 | def exact_match_score(prediction, ground_truth):
139 | return normalize_answer(prediction) == normalize_answer(ground_truth)
140 |
141 | def ems(prediction, ground_truths):
142 | return max([exact_match_score(prediction, gt) for gt in ground_truths])
143 |
144 | ####################################################
145 | ######## RETRIEVER EVALUATION ########
146 | ####################################################
147 |
148 | def eval_batch(scores, inversions, avg_topk, idx_topk):
149 | for k, s in enumerate(scores):
150 | s = s.cpu().numpy()
151 | sorted_idx = np.argsort(-s)
152 | score(sorted_idx, inversions, avg_topk, idx_topk)
153 |
154 | def count_inversions(arr):
155 | inv_count = 0
156 | lenarr = len(arr)
157 | for i in range(lenarr):
158 | for j in range(i + 1, lenarr):
159 | if (arr[i] > arr[j]):
160 | inv_count += 1
161 | return inv_count
162 |
163 | def score(x, inversions, avg_topk, idx_topk):
164 | x = np.array(x)
165 | inversions.append(count_inversions(x))
166 | for k in avg_topk:
167 | # ratio of passages in the predicted top-k that are
168 | # also in the topk given by gold score
169 | avg_pred_topk = (x[:k]