├── src
│   ├── __init__.py
│   ├── index.py
│   ├── inbatch.py
│   ├── dist_utils.py
│   ├── slurm.py
│   ├── normalize_text.py
│   ├── moco.py
│   ├── contriever.py
│   ├── finetuning_data.py
│   ├── options.py
│   ├── evaluation.py
│   ├── utils.py
│   ├── beir_utils.py
│   └── data.py
├── requirements.txt
├── evaluate.sh
├── run_reranker.sh
├── template
│   ├── mistral.jinja
│   ├── llama2.jinja
│   └── llama3.jinja
├── aggregate.py
├── run_retriever.sh
├── llama8b.yaml
├── llama8b_dpo.yaml
├── split_for_sft_dpo.py
├── split_data.py
├── inference.sh
├── reward_trajectories.py
├── reranker_sequence.py
├── evaluate.py
├── metrics.py
├── process_training_data.py
├── construct_generator_sft.py
├── reranker.py
├── utils.py
├── sampling_dpo_trajectories.py
├── constrcut_dpo.py
├── inference.py
├── top_inference.py
├── retriever.py
└── readme.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | rouge
2 | bert_score
3 | faiss-gpu
4 | tqdm
5 | openai
6 | vllm
7 | deepspeed
8 | flash-attn
9 | llamafactory-cli
--------------------------------------------------------------------------------
/evaluate.sh:
--------------------------------------------------------------------------------
1 | # install nltk, rouge_score, spacy
2 | # python -m spacy download en_core_web_sm
3 |
4 |
5 | python evaluate.py \
6 | --results_file results/llama3_8b_nq.json \
7 | --metric match
--------------------------------------------------------------------------------
/run_reranker.sh:
--------------------------------------------------------------------------------
1 | python reranker.py --model_name_or_path models/reranker/monot5 \
2 | --input_file output/retrieval_data.jsonl \
3 | --output_file output/retrieval_data_rerank.jsonl \
4 | --device cuda
5 |
--------------------------------------------------------------------------------
/template/mistral.jinja:
--------------------------------------------------------------------------------
1 | {% if messages[0]['role'] == 'system' %}
2 | {% set loop_messages = messages[1:] %}
3 | {% set system_message = messages[0]['content'] %}
4 | {% else %}
5 | {% set loop_messages = messages %}
6 | {% endif %}
7 |
8 | {% if system_message is defined %}
9 | {{ '<s>' }}{{ system_message }}
10 | {% endif %}
11 |
12 | {% for message in loop_messages %}
13 | {% set content = message['content'] %}
14 |
15 | {% if message['role'] == 'user' %}
16 | {{ '[INST] ' + content + ' [/INST]' }}
17 | {% elif message['role'] == 'assistant' %}
18 | {{ content + '</s>' }}
19 | {% endif %}
20 | {% endfor %}
21 |
--------------------------------------------------------------------------------
/template/llama2.jinja:
--------------------------------------------------------------------------------
1 | {% if messages[0]['role'] == 'system' %}
2 | {% set loop_messages = messages[1:] %}
3 | {% set system_message = messages[0]['content'] %}
4 | {% else %}
5 | {% set loop_messages = messages %}
6 | {% endif %}
7 |
8 | {% for message in loop_messages %}
9 | {% set content = message['content'] %}
10 |
11 | {% if loop.index0 == 0 and system_message is defined %}
12 | {% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}
13 | {% endif %}
14 |
15 | {% if message['role'] == 'user' %}
16 | {{ '<s>' + '[INST] ' + content + ' [/INST]' }}
17 | {% elif message['role'] == 'assistant' %}
18 | {{ content + '</s>' }}
19 | {% endif %}
20 | {% endfor %}
21 |
--------------------------------------------------------------------------------
/aggregate.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from tqdm import tqdm
4 |
5 | input_files = [f"retrieval_split_{i}.json" for i in range(32)]
6 | output_file = "retrieval_data.jsonl"
7 |
8 |
9 | data = []
10 | id_set = set()
11 |
12 |
13 | for file in tqdm(input_files, desc="Processing files"):
14 | with open(file, 'r', encoding='utf-8') as f:
15 | for line in f:
16 | record = json.loads(line)
17 | record_id = record.get("id")
18 | if record_id not in id_set:
19 | id_set.add(record_id)
20 | data.append(record)
21 |
22 |
23 | with open(output_file, 'w', encoding='utf-8') as f:
24 | for record in tqdm(data, desc="Writing to file"):
25 | f.write(json.dumps(record, ensure_ascii=False) + "\n")
26 |
27 |
--------------------------------------------------------------------------------
/template/llama3.jinja:
--------------------------------------------------------------------------------
1 | {{ '<|begin_of_text|>' }}
2 |
3 | {% if messages[0]['role'] == 'system' %}
4 | {% set loop_messages = messages[1:] %}
5 | {% set system_message = messages[0]['content'] %}
6 | {% else %}
7 | {% set loop_messages = messages %}
8 | {% endif %}
9 |
10 | {% if system_message is defined %}
11 | {{ '<|start_header_id|>system<|end_header_id|>' }}
12 | {{ system_message }}
13 | {{ '<|eot_id|>' }}
14 | {% endif %}
15 |
16 | {% for message in loop_messages %}
17 | {% set content = message['content'] %}
18 |
19 | {% if message['role'] == 'user' %}
20 | {{ '<|start_header_id|>user<|end_header_id|>' }}
21 | {{ content }}
22 | {{ '<|eot_id|>' }}
23 | {{ '<|start_header_id|>assistant<|end_header_id|>' }}
24 | {% elif message['role'] == 'assistant' %}
25 | {{ content }}
26 | {{ '<|eot_id|>' }}
27 | {% endif %}
28 | {% endfor %}
29 |
--------------------------------------------------------------------------------
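These chat templates are standalone Jinja files (inference.sh passes template/llama3.jinja to inference.py via --template), so they can be sanity-checked by rendering a message list directly with jinja2. A minimal sketch, assuming only that the template receives a `messages` list in the usual role/content format; the conversation below is illustrative:

from jinja2 import Template

# Load the Llama-3 chat template shipped with the repo and render a toy conversation.
with open("template/llama3.jinja", "r", encoding="utf-8") as f:
    template = Template(f.read())

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Who wrote Pride and Prejudice?"},
]

print(template.render(messages=messages))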
/run_retriever.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NUM_GPUS=8
4 | INPUT_FILE="data/rag_training_data.json"
5 | SPLIT_DIR="data/splits"
6 |
7 | python split_data.py --input_file $INPUT_FILE --output_dir $SPLIT_DIR --num_splits $NUM_GPUS
8 |
9 | for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
10 | SPLIT_FILE="${SPLIT_DIR}/split_${GPU_ID}.json"
11 | OUTPUT_FILE="output/retrieval_split_${GPU_ID}.json"
12 | log_file="logs/retriever_split_${GPU_ID}.log"
13 | CUDA_VISIBLE_DEVICES=$GPU_ID python retriever.py \
14 | --model_name_or_path models/retriever \
15 | --passages data/psgs_w100.tsv \
16 | --passages_embeddings "data/wikipedia_embeddings/*" \
17 | --query $SPLIT_FILE \
18 | --output_dir $OUTPUT_FILE \
19 | --n_docs 50 \
20 | 1>"$log_file" 2>&1 &
21 |
22 | echo "Started process on GPU $GPU_ID with input $SPLIT_FILE"
23 | done
24 |
25 | wait
26 | echo "All processes completed."
27 |
--------------------------------------------------------------------------------
/llama8b.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: dynamicRAG/models/generator/llama3_8b
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: full
8 | deepspeed: examples/deepspeed/ds_z3_config.json
9 |
10 | ### dataset
11 | dataset: alpaca_data, reranker_bc, generator_sft
12 | template: llama3
13 | cutoff_len: 8192
14 | # max_samples: 1000
15 | overwrite_cache: true
16 | preprocessing_num_workers: 16
17 |
18 | ### output
19 | output_dir: saves/dynamicrag_llama3_8b/
20 | logging_steps: 1
21 | save_steps: 2000
22 | plot_loss: true
23 | overwrite_output_dir: true
24 |
25 | ### train
26 | per_device_train_batch_size: 2
27 | gradient_accumulation_steps: 2
28 | learning_rate: 1.0e-5
29 | num_train_epochs: 3.0
30 | lr_scheduler_type: cosine
31 | warmup_ratio: 0.1
32 | bf16: true
33 | ddp_timeout: 180000000
34 |
35 | ### eval
36 | eval_dataset: alpaca_en_demo
37 | per_device_eval_batch_size: 1
38 | eval_strategy: steps
39 | eval_steps: 1000000000
40 |
--------------------------------------------------------------------------------
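This YAML follows LLaMA-Factory's training-config format; assuming the datasets listed under `dataset` (alpaca_data, reranker_bc, generator_sft) are registered in LLaMA-Factory's dataset_info.json, SFT would typically be launched with the project's CLI, for example:

llamafactory-cli train llama8b.yaml

The DPO config below (llama8b_dpo.yaml) can be launched the same way once the SFT checkpoint referenced by its model_name_or_path exists.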
/llama8b_dpo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: LLaMA-Factory/saves/dynamicrag_llama3_8b
3 |
4 | ### method
5 | stage: dpo
6 | do_train: true
7 | finetuning_type: full
8 | deepspeed: examples/deepspeed/ds_z3_config.json
9 |
10 | ### dataset
11 | dataset: llama3_generator_dpo, llama3_reranker_dpo
12 | template: llama3
13 | cutoff_len: 8192
14 | # max_samples: 1000
15 | overwrite_cache: true
16 | preprocessing_num_workers: 16
17 |
18 | ### output
19 | #
20 | output_dir: saves/dynamicrag_llama3_8b_dpo/
21 | logging_steps: 1
22 | save_steps: 200
23 | plot_loss: true
24 | overwrite_output_dir: true
25 |
26 | ### train
27 | per_device_train_batch_size: 1
28 | gradient_accumulation_steps: 4
29 | learning_rate: 5.0e-6
30 | num_train_epochs: 1.0
31 | lr_scheduler_type: cosine
32 | warmup_ratio: 0.1
33 | bf16: true
34 | ddp_timeout: 180000000
35 |
36 | ### eval
37 | eval_dataset: llama3_reranker_dpo
38 | per_device_eval_batch_size: 1
39 | eval_strategy: steps
40 | eval_steps: 1000000000
--------------------------------------------------------------------------------
/split_for_sft_dpo.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | def shuffle_and_split_jsonl_to_json(input_file, n, output_file_1, output_file_2):
5 | data = []
6 | with open(input_file, 'r', encoding='utf-8') as f:
7 | for line in f:
8 | data.append(json.loads(line.strip()))
9 | print(len(data))
10 | if not isinstance(data, list):
11 | raise ValueError("The input JSONL must contain a list of items.")
12 |
13 | random.shuffle(data)
14 |
15 | data_part_1 = data[:n]
16 | data_part_2 = data[n:]
17 |
18 | with open(output_file_1, 'w', encoding='utf-8') as f:
19 | json.dump(data_part_1, f, ensure_ascii=False, indent=4)
20 |
21 | with open(output_file_2, 'w', encoding='utf-8') as f:
22 | json.dump(data_part_2, f, ensure_ascii=False, indent=4)
23 |
24 | print(f"Shuffled and split data saved to {output_file_1} and {output_file_2}")
25 |
26 | input_file = 'training_data/retrieval_data_rerank_normal.jsonl'
27 | n = 70000
28 | output_file_1 = 'training_data/training_data_sft.json'
29 | output_file_2 = 'training_data/training_data_dpo.json'
30 |
31 | shuffle_and_split_jsonl_to_json(input_file, n, output_file_1, output_file_2)
32 |
--------------------------------------------------------------------------------
/split_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import math
5 |
6 | def split_data(input_file, output_dir, num_splits):
7 | with open(input_file, 'r', encoding='utf-8') as f:
8 | data = [json.loads(line.strip()) for line in f]
9 |
10 | total_items = len(data)
11 | items_per_split = math.ceil(total_items / num_splits)
12 |
13 | os.makedirs(output_dir, exist_ok=True)
14 |
15 | for i in range(num_splits):
16 | split_data = data[i * items_per_split:(i + 1) * items_per_split]
17 | split_file = os.path.join(output_dir, f"split_{i}.json")
18 | with open(split_file, 'w', encoding='utf-8') as out_f:
19 | json.dump(split_data, out_f, indent=4, ensure_ascii=False)
20 |
21 | print(f"Data split into {num_splits} parts and saved in {output_dir}")
22 |
23 | if __name__ == "__main__":
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--input_file", type=str, required=True, help="Path to input JSONL file")
26 | parser.add_argument("--output_dir", type=str, required=True, help="Directory to save splits")
27 | parser.add_argument("--num_splits", type=int, required=True, help="Number of splits")
28 |
29 | args = parser.parse_args()
30 | split_data(args.input_file, args.output_dir, args.num_splits)
31 |
--------------------------------------------------------------------------------
/inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | LOG_DIR="eval_logs"
5 | mkdir -p $LOG_DIR
6 |
7 | run_inference() {
8 | local input_file=$1
9 | local output_file=$2
10 | local remain_output_file=$3
11 |
12 | echo "Running inference for $input_file..."
13 | python inference.py \
14 | --template template/llama3.jinja \
15 | --llm-model DynamicRAG_llama3_8b \
16 | --input-json $input_file \
17 | --output-json $output_file \
18 | --remain-output-json $remain_output_file \
19 | >> $LOG_DIR/$(basename $output_file .json)_log.txt 2>&1
20 |
21 | sleep 5
22 | }
23 |
24 |
25 | run_inference "eval_data/triviaqa.jsonl" \
26 | "results/llama3_8b_triviaqa.json" \
27 | "results/llama3_8b_triviaqa_remain.json"
28 |
29 | run_inference "eval_data/nq.jsonl" \
30 | "results/llama3_8b_nq.json" \
31 | "results/llama3_8b_nq_remain.json"
32 |
33 | run_inference "eval_data/hotpotqa.jsonl" \
34 | "results/llama3_8b_hotpotqa.json" \
35 | "results/llama3_8b_hotpotqa_remain.json"
36 |
37 | run_inference "eval_data/2wikimqa.jsonl" \
38 | "results/llama3_8b_2wikimqa.json" \
39 | "results/llama3_8b_2wikimqa_remain.json"
40 |
41 | run_inference "eval_data/fever.jsonl" \
42 | "results/llama3_8b_fever.json" \
43 | "results/llama3_8b_fever_remain.json"
44 |
45 | run_inference "eval_data/eli5.jsonl" \
46 | "results/llama3_8b_eli5.json" \
47 | "results/llama3_8b_eli5_remain.json"
48 |
49 | run_inference "eval_data/asqa_eval_gtr_top100.jsonl" \
50 | "results/llama3_8b_asqa.json" \
51 | "results/llama3_8b_asqa_remain.json"
52 |
53 | echo "All tasks completed. Logs are available in $LOG_DIR."
54 |
--------------------------------------------------------------------------------
/reward_trajectories.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from tqdm import tqdm
4 | from rouge import Rouge
5 | from utils import compute_bert_score, compute_rouge_score, compute_exact_match, inverse_reward, llm_eval
6 |
7 | def process_data(input_file, output_file):
8 | with open(input_file, 'r', encoding='utf-8') as f:
9 | data = json.load(f)
10 |
11 | rouge_score = Rouge()
12 |
13 | weights = [0.2, 0.2, 0.2, 0.2, 0.2]
14 |
15 | for item in tqdm(data):
16 | instruction = item['question']
17 | answers = [ans for ans in item.get("answer", []) if ans.strip()]
18 | responses = item["responses"]
19 | id_list = item.get("id_lists", [])
20 |
21 | scores = []
22 | if len(set(responses)) == 1:
23 | scores = [0] * len(responses)
24 | else:
25 | for response in responses:
26 | response_scores = [
27 | weights[0] * compute_bert_score(answer, response) +
28 | weights[1] * compute_rouge_score(rouge_score, answer, response) +
29 | weights[2] * compute_exact_match(answer, response) +
30 | weights[3] * inverse_reward(id_list) +
31 | weights[4] * llm_eval(instruction, answer, response)
32 | for answer in answers
33 | ]
34 | scores.append(sum(response_scores) / len(response_scores) if response_scores else 0)
35 |
36 | item["response_scores"] = scores
37 |
38 | with open(output_file, 'w', encoding='utf-8') as f:
39 | json.dump(data, f, indent=4, ensure_ascii=False)
40 |
41 | print(f"Processed data saved to {output_file}")
42 |
43 | if __name__ == "__main__":
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument("--input_file", type=str, required=True, help="Path to input JSON file")
46 | parser.add_argument("--output_file", type=str, required=True, help="Path to save processed JSON file")
47 |
48 | args = parser.parse_args()
49 | process_data(args.input_file, args.output_file)
50 |
--------------------------------------------------------------------------------
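For intuition, each response's reward is the average, over the gold answers, of a weighted sum of five signals: BERTScore, ROUGE-L, a soft exact match, an inverse-length reward over the selected document ids, and an LLM-judged score. A toy calculation with made-up component values (note that llm_eval in utils.py returns a 0-100 score while the other signals are roughly 0-1):

weights = [0.2, 0.2, 0.2, 0.2, 0.2]
# Hypothetical scores for one (answer, response) pair:
# BERTScore, ROUGE-L F1, exact match, inverse_reward for a 3-id trajectory, LLM judge score.
components = [0.85, 0.60, 1.0, 1 / (1 + 3), 90.0]
reward = sum(w * c for w, c in zip(weights, components))
print(reward)  # 0.2 * (0.85 + 0.60 + 1.0 + 0.25 + 90.0) = 18.54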
/reranker_sequence.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 |
4 |
5 | input_file = 'training_data/retrieval_data_rerank_sequence.jsonl'
6 | output_file = 'training_data/reranker_bc_training.json'
7 |
8 |
9 | data = []
10 | with open(input_file, "r", encoding="utf-8") as f:
11 | for line in f:
12 | data.append(json.loads(line.strip()))
13 |
14 |
15 | new_data = []
16 |
17 | for item in tqdm(data, desc="Processing items"):
18 | sequence = item.get("sequence", [])
19 | docs = item.get("docs", [])
20 | question = item.get("question", "")
21 |
22 | query = question
23 |
24 | system_prompt = (
25 | f"You are an expert at dynamically generating document identifiers to answer a given query.\n"
26 | f"I will provide you with a set of documents, each uniquely identified by a number within square brackets, e.g., [1], [2], etc.\n"
27 | f"Your task is to identify and generate only the identifiers of the documents that contain sufficient information to answer the query.\n"
28 | f"Stop generating identifiers as soon as the selected documents collectively provide enough information to answer the query.\n"
29 | f"If no documents are required to answer the query, output \"None\".\n"
30 | f"Output the identifiers as a comma-separated list, e.g., [1], [2] or \"None\" if no documents are needed.\n"
31 | f"Focus solely on providing the identifiers. Do not include any explanations, descriptions, or additional text."
32 | )
33 |
34 | if not sequence:
35 | output = "None"
36 | else:
37 | output = ", ".join(f"[{seq}]" for seq in sequence)
38 |
39 | retrieved_content = "Retrieved Content:\n" + "\n".join(
40 | [
41 | f"{i+1}. Topic: {doc['title']}\nContent: {doc['text']}"
42 | for i, doc in enumerate(docs)
43 | ]
44 | )
45 | instruction_data = {
46 | "instruction": f"Query: {question}\n\n{retrieved_content}",
47 | "input": "",
48 | "output": output,
49 | "system": system_prompt
50 | }
51 |
52 | new_data.append(instruction_data)
53 |
54 |
55 | with open(output_file, "w", encoding="utf-8") as f:
56 | json.dump(new_data, f, ensure_ascii=False, indent=4)
57 |
--------------------------------------------------------------------------------
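Each record that carries a gold `sequence` becomes one Alpaca-style instruction-tuning example for the reranker. An illustrative (hypothetical) output record, with the system prompt abbreviated, would look roughly like:

{
    "instruction": "Query: who wrote pride and prejudice\n\nRetrieved Content:\n1. Topic: Jane Austen\nContent: ...\n2. Topic: Pride and Prejudice\nContent: ...",
    "input": "",
    "output": "[2], [1]",
    "system": "You are an expert at dynamically generating document identifiers to answer a given query. ..."
}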
/src/index.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import pickle
9 | from typing import List, Tuple
10 |
11 | import faiss
12 | import numpy as np
13 | from tqdm import tqdm
14 |
15 | class Indexer(object):
16 |
17 | def __init__(self, vector_sz, n_subquantizers=0, n_bits=8):
18 | if n_subquantizers > 0:
19 | self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
20 | else:
21 | self.index = faiss.IndexFlatIP(vector_sz)
22 | #self.index_id_to_db_id = np.empty((0), dtype=np.int64)
23 | self.index_id_to_db_id = []
24 |
25 | def index_data(self, ids, embeddings):
26 | self._update_id_mapping(ids)
27 | embeddings = embeddings.astype('float32')
28 | if not self.index.is_trained:
29 | self.index.train(embeddings)
30 | self.index.add(embeddings)
31 |
32 | print(f'Total data indexed {len(self.index_id_to_db_id)}')
33 |
34 | def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 2048) -> List[Tuple[List[object], List[float]]]:
35 | query_vectors = query_vectors.astype('float32')
36 | result = []
37 | nbatch = (len(query_vectors)-1) // index_batch_size + 1
38 | for k in tqdm(range(nbatch)):
39 | start_idx = k*index_batch_size
40 | end_idx = min((k+1)*index_batch_size, len(query_vectors))
41 | q = query_vectors[start_idx: end_idx]
42 | scores, indexes = self.index.search(q, top_docs)
43 | # convert to external ids
44 | db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
45 | result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
46 | return result
47 |
48 | def serialize(self, dir_path):
49 | index_file = os.path.join(dir_path, 'index.faiss')
50 | meta_file = os.path.join(dir_path, 'index_meta.faiss')
51 | print(f'Serializing index to {index_file}, meta data to {meta_file}')
52 |
53 | faiss.write_index(self.index, index_file)
54 | with open(meta_file, mode='wb') as f:
55 | pickle.dump(self.index_id_to_db_id, f)
56 |
57 | def deserialize_from(self, dir_path):
58 | index_file = os.path.join(dir_path, 'index.faiss')
59 | meta_file = os.path.join(dir_path, 'index_meta.faiss')
60 | print(f'Loading index from {index_file}, meta data from {meta_file}')
61 |
62 | self.index = faiss.read_index(index_file)
63 |         print('Loaded index of type %s and size %d' % (type(self.index), self.index.ntotal))
64 |
65 | with open(meta_file, "rb") as reader:
66 | self.index_id_to_db_id = pickle.load(reader)
67 | assert len(
68 | self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'
69 |
70 | def _update_id_mapping(self, db_ids: List):
71 | #new_ids = np.array(db_ids, dtype=np.int64)
72 | #self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0)
73 | self.index_id_to_db_id.extend(db_ids)
--------------------------------------------------------------------------------
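A minimal sketch of how this Indexer is typically driven; the dimensionality, ids, and random embeddings below are illustrative stand-ins for the precomputed passage embeddings the retrieval script loads via --passages_embeddings:

import numpy as np
from src.index import Indexer

dim = 768
indexer = Indexer(vector_sz=dim)  # exact inner-product search (IndexFlatIP)

# Index a few fake passages keyed by external ids.
ids = ["doc1", "doc2", "doc3"]
embeddings = np.random.rand(3, dim).astype("float32")
indexer.index_data(ids, embeddings)

# Search with one query embedding; each result is a (db_ids, scores) pair.
queries = np.random.rand(1, dim).astype("float32")
top_ids_and_scores = indexer.search_knn(queries, top_docs=2)
print(top_ids_and_scores[0])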
/src/inbatch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import math
7 | import random
8 | import transformers
9 | import logging
10 | import torch.distributed as dist
11 |
12 | from src import contriever, dist_utils, utils
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class InBatch(nn.Module):
18 | def __init__(self, opt, retriever=None, tokenizer=None):
19 | super(InBatch, self).__init__()
20 |
21 | self.opt = opt
22 | self.norm_doc = opt.norm_doc
23 | self.norm_query = opt.norm_query
24 | self.label_smoothing = opt.label_smoothing
25 | if retriever is None or tokenizer is None:
26 | retriever, tokenizer = self._load_retriever(
27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
28 | )
29 | self.tokenizer = tokenizer
30 | self.encoder = retriever
31 |
32 | def _load_retriever(self, model_id, pooling, random_init):
33 | cfg = utils.load_hf(transformers.AutoConfig, model_id)
34 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
35 |
36 | if "xlm" in model_id:
37 | model_class = contriever.XLMRetriever
38 | else:
39 | model_class = contriever.Contriever
40 |
41 | if random_init:
42 | retriever = model_class(cfg)
43 | else:
44 | retriever = utils.load_hf(model_class, model_id)
45 |
46 | if "bert-" in model_id:
47 | if tokenizer.bos_token_id is None:
48 | tokenizer.bos_token = "[CLS]"
49 | if tokenizer.eos_token_id is None:
50 | tokenizer.eos_token = "[SEP]"
51 |
52 | retriever.config.pooling = pooling
53 |
54 | return retriever, tokenizer
55 |
56 | def get_encoder(self):
57 | return self.encoder
58 |
59 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
60 |
61 | bsz = len(q_tokens)
62 | labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device)
63 |
64 | qemb = self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
65 | kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
66 |
67 | gather_fn = dist_utils.gather
68 |
69 | gather_kemb = gather_fn(kemb)
70 |
71 | labels = labels + dist_utils.get_rank() * len(kemb)
72 |
73 | scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb)
74 |
75 | loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing)
76 |
77 | # log stats
78 | if len(stats_prefix) > 0:
79 | stats_prefix = stats_prefix + "/"
80 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
81 |
82 | predicted_idx = torch.argmax(scores, dim=-1)
83 | accuracy = 100 * (predicted_idx == labels).float().mean()
84 | stdq = torch.std(qemb, dim=0).mean().item()
85 | stdk = torch.std(kemb, dim=0).mean().item()
86 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
87 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
88 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
89 |
90 | return loss, iter_stats
91 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from metrics import metric_max_over_ground_truths, exact_match_score, match, rouge, f1
3 | import json
4 | from tqdm import tqdm
5 | from rouge import Rouge
6 |
7 | def load_file(file_path):
8 | """
9 | Load data from a JSON or JSONL file.
10 |
11 | Args:
12 | file_path (str): Path to the file to load.
13 |
14 | Returns:
15 | list: List of dictionaries loaded from the file.
16 |
17 | Raises:
18 | ValueError: If the file format is not supported.
19 | """
20 | try:
21 | with open(file_path, 'r', encoding='utf-8') as f:
22 | if file_path.endswith('.json'):
23 | data = json.load(f)
24 | if isinstance(data, dict):
25 | return [data]
26 | elif isinstance(data, list):
27 | return data
28 | else:
29 | raise ValueError("Unsupported JSON structure. Expecting list or dict.")
30 | elif file_path.endswith('.jsonl'):
31 | data = [json.loads(line.strip()) for line in f if line.strip()]
32 | return data
33 | else:
34 | raise ValueError(f"Unsupported file format: {file_path}")
35 | except Exception as e:
36 | print(f"Error loading file {file_path}: {e}")
37 | raise
38 |
39 |
40 | def get_args():
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--results_file", type=str, help="File containing results with both generated responses and ground truth answers")
43 | parser.add_argument("--metric", type=str, choices=["em", "accuracy", "match", "rouge", "f1"], help="Metric to use for evaluation")
44 | return parser.parse_args()
45 |
46 | def main():
47 | args = get_args()
48 |
49 | results = load_file(args.results_file)
50 |
51 | scores = []
52 | cnt=0
53 | rouge_score = Rouge()
54 | for item in tqdm(results):
55 | response = item['response'][0]
56 |
57 | question = item['question']
58 |
59 | if 'asqa' in args.results_file:
60 | answers= []
61 | for ans in item['qa_pairs']:
62 | answers.extend(ans['short_answers'])
63 | else:
64 | answers = item['answers']
65 | if not answers:
66 | print(f"Warning: No answers provided for ID {item.get('id', 'unknown')}")
67 | continue
68 |
69 | if args.metric == "em":
70 | metric_result = metric_max_over_ground_truths(
71 | exact_match_score, response, answers
72 | )
73 | elif args.metric == "accuracy":
74 | response = response.replace('\n','').strip()[0]
75 | answer = answers[0][0]
76 | if response==answer:
77 | metric_result = 1.0
78 | else:
79 | metric_result = 0.0
80 | elif args.metric == "match":
81 | metric_result = match(question, response, answers)
82 | elif args.metric == "rouge":
83 | metric_result = rouge(rouge_score, response, answers)
84 | elif args.metric == "f1":
85 |             metric_result = f1([response], [answers])
86 | else:
87 | raise NotImplementedError(f"Metric {args.metric} is not implemented.")
88 |
89 | scores.append(metric_result)
90 |
91 | if scores:
92 | print(f'Overall result: {sum(scores) / len(scores)}')
93 | else:
94 | print("No scores were calculated. Please check your input file.")
95 |
96 | if __name__ == "__main__":
97 | main()
98 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import string
3 | import re
4 | from collections import Counter
5 | import re
6 | from typing import List
7 |
8 | # nltk.download('punkt_tab')
9 |
10 | def convert_to_capitalized(s):
11 | if s.isupper():
12 | return s.capitalize()
13 | return s
14 |
15 | def exact_match_score(prediction, ground_truth):
16 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
17 |
18 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
19 | scores_for_ground_truths = []
20 | for ground_truth in ground_truths:
21 | score = metric_fn(prediction, ground_truth)
22 | scores_for_ground_truths.append(score)
23 | return max(scores_for_ground_truths)
24 |
25 | def accuracy(preds, labels):
26 | match_count = 0
27 |
28 | for pred, label in zip(preds, labels):
29 | target = label[0]
30 | if pred == target or pred[0]==target:
31 | match_count += 1
32 | if match_count == 0:
33 | print(repr(pred))
34 | print(target)
35 | return 100 * (match_count / len(preds))
36 |
37 |
38 | def f1(decoded_preds, decoded_labels):
39 | f1_all = []
40 | for prediction, answers in zip(decoded_preds, decoded_labels):
41 | if type(answers) == list:
42 | if len(answers) == 0:
43 | return 0
44 | f1_all.append(np.max([qa_f1_score(prediction, gt)
45 | for gt in answers]))
46 | else:
47 | f1_all.append(qa_f1_score(prediction, answers))
48 | return 100 * np.mean(f1_all)
49 |
50 |
51 | def qa_f1_score(prediction, ground_truth):
52 | prediction_tokens = normalize_answer(prediction).split()
53 | ground_truth_tokens = normalize_answer(ground_truth).split()
54 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
55 | num_same = sum(common.values())
56 | if num_same == 0:
57 | return 0
58 | precision = 1.0 * num_same / len(prediction_tokens)
59 | recall = 1.0 * num_same / len(ground_truth_tokens)
60 | f1 = (2 * precision * recall) / (precision + recall)
61 | return f1
62 |
63 |
64 | def normalize_answer(s):
65 | def remove_articles(text):
66 | return re.sub(r'\b(a|an|the)\b', ' ', text)
67 |
68 | def white_space_fix(text):
69 | return ' '.join(text.split())
70 |
71 | def remove_punc(text):
72 | exclude = set(string.punctuation)
73 | return ''.join(ch for ch in text if ch not in exclude)
74 |
75 | def lower(text):
76 | return text.lower()
77 | return white_space_fix(remove_articles(remove_punc(lower(s))))
78 |
79 | def find_entity_tags(sentence):
80 | entity_regex = r'(.+?)(?=\s<|$)'
81 | tag_regex = r'<(.+?)>'
82 | entity_names = re.findall(entity_regex, sentence)
83 | tags = re.findall(tag_regex, sentence)
84 |
85 | results = {}
86 | for entity, tag in zip(entity_names, tags):
87 | if "<" in entity:
88 | results[entity.split("> ")[1]] = tag
89 | else:
90 | results[entity] = tag
91 | return results
92 |
93 | def match(question, prediction, ground_truth):
94 | prediction = prediction.replace('\n','').strip().lower()
95 | for gt in ground_truth:
96 | if prediction in gt.lower():
97 | return 1
98 | return 0
99 |
100 |
101 | def rouge(rouge, prediction: str, ground_truths: List[str]) -> float:
102 | if prediction=="":
103 | return 0.0
104 |
105 | highest = 0.0
106 |
107 | for ref in ground_truths:
108 | if ref=="":
109 |             continue
110 | score = rouge.get_scores(ref, prediction)
111 | highest = max(highest, score[0]['rouge-l']['f'])
112 |
113 | return highest
--------------------------------------------------------------------------------
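A small usage sketch of the answer-matching helpers; the strings are illustrative:

from metrics import metric_max_over_ground_truths, exact_match_score, match

prediction = "Jane Austen"
ground_truths = ["Jane Austen", "Austen"]

print(metric_max_over_ground_truths(exact_match_score, prediction, ground_truths))  # True
print(match("who wrote pride and prejudice", prediction, ground_truths))            # 1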
/process_training_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from collections import Counter
4 | from tqdm import tqdm
5 |
6 |
7 | def load_jsonl(file_path):
8 | data = []
9 | with open(file_path, 'r', encoding='utf-8') as file:
10 | for line in file:
11 | data.append(json.loads(line))
12 | return data
13 |
14 |
15 | def save_jsonl(data, file_path):
16 | with open(file_path, 'w', encoding='utf-8') as file:
17 | for item in data:
18 | json.dump(item, file, ensure_ascii=False)
19 | file.write("\n")
20 |
21 |
22 | def process_jsonl(input_file, output_sequence_file, output_normal_file):
23 | result = []
24 | remaining_data = []
25 | length_counters = {length: 0 for length in range(15)}
26 | true_counts = []
27 |
28 | data = load_jsonl(input_file)
29 | unused_data = data.copy()
30 | total_items = len(data)
31 |
32 | with tqdm(total=total_items, desc="Processing", unit="item") as pbar:
33 | while sum(length_counters.values()) < 30000 and unused_data:
34 | item = unused_data.pop(0)
35 | used = False
36 |
37 | true_ids = []
38 | false_ids = []
39 |
40 | for idx, doc in enumerate(item['docs'], start=1):
41 | if doc.get('rerank_label') == "true":
42 | true_ids.append(idx)
43 | else:
44 | false_ids.append(idx)
45 |
46 | item['true_id'] = true_ids
47 | item['false_id'] = false_ids
48 | item['true_count'] = len(true_ids)
49 | true_counts.append(len(true_ids))
50 |
51 | for length in range(15):
52 | if length_counters[length] >= 2000:
53 | continue
54 |
55 | if len(true_ids) > length:
56 | continue
57 |
58 | shuffled_true_id = true_ids.copy()
59 | shuffled_false_id = false_ids.copy()
60 | random.shuffle(shuffled_true_id)
61 | random.shuffle(shuffled_false_id)
62 |
63 | sequence = shuffled_true_id
64 | remaining_length = length - len(shuffled_true_id)
65 | if remaining_length > 0:
66 | sequence += shuffled_false_id[:remaining_length]
67 |
68 | new_item = item.copy()
69 | new_item["sequence_length"] = length
70 | new_item["sequence"] = sequence
71 | result.append(new_item)
72 |
73 | length_counters[length] += 1
74 | used = True
75 |
76 | if sum(length_counters.values()) >= 30000:
77 | break
78 |
79 | if sum(length_counters.values()) >= 30000:
80 | break
81 |
82 | if not used:
83 | remaining_data.append(item)
84 |
85 | pbar.update(1)
86 |
87 | remaining_data.extend(unused_data)
88 |
89 | save_jsonl(result, output_sequence_file)
90 | save_jsonl(remaining_data, output_normal_file)
91 | print(f"Data generation complete. Total sequence items: {len(result)}")
92 | print(f"Remaining data items: {len(remaining_data)}")
93 |
94 | true_count_distribution = Counter(true_counts)
95 | print("True Count Distribution:")
96 | for count, freq in sorted(true_count_distribution.items()):
97 | print(f"True Count = {count}: Frequency = {freq}")
98 |
99 | # Main entry point
100 | if __name__ == "__main__":
101 | input_jsonl_file = 'training_data/retrieval_data_rerank.jsonl'
102 | output_sequence_file = 'training_data/retrieval_data_rerank_sequence.jsonl'
103 | output_normal_file = 'training_data/retrieval_data_rerank_normal.jsonl'
104 | process_jsonl(input_jsonl_file, output_sequence_file, output_normal_file)
105 |
--------------------------------------------------------------------------------
/src/dist_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 |
7 | class Gather(torch.autograd.Function):
8 | @staticmethod
9 | def forward(ctx, x: torch.tensor):
10 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
11 | dist.all_gather(output, x)
12 | return tuple(output)
13 |
14 | @staticmethod
15 | def backward(ctx, *grads):
16 | all_gradients = torch.stack(grads)
17 | dist.all_reduce(all_gradients)
18 | return all_gradients[dist.get_rank()]
19 |
20 |
21 | def gather(x: torch.tensor):
22 | if not dist.is_initialized():
23 | return x
24 | x_gather = Gather.apply(x)
25 | x_gather = torch.cat(x_gather, dim=0)
26 | return x_gather
27 |
28 |
29 | @torch.no_grad()
30 | def gather_nograd(x: torch.tensor):
31 | if not dist.is_initialized():
32 | return x
33 | x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())]
34 | dist.all_gather(x_gather, x, async_op=False)
35 |
36 | x_gather = torch.cat(x_gather, dim=0)
37 | return x_gather
38 |
39 |
40 | @torch.no_grad()
41 | def varsize_gather_nograd(x: torch.Tensor):
42 | """gather tensors of different sizes along the first dimension"""
43 | if not dist.is_initialized():
44 | return x
45 |
46 | # determine max size
47 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
48 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
49 | dist.all_gather(allsizes, size)
50 | max_size = max([size.cpu().max() for size in allsizes])
51 |
52 | padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
53 | padded[: x.shape[0]] = x
54 | output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
55 | dist.all_gather(output, padded)
56 |
57 | output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
58 | output = torch.cat(output, dim=0)
59 |
60 | return output
61 |
62 |
63 | @torch.no_grad()
64 | def get_varsize(x: torch.Tensor):
65 | """gather tensors of different sizes along the first dimension"""
66 | if not dist.is_initialized():
67 | return [x.shape[0]]
68 |
69 | # determine max size
70 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
71 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
72 | dist.all_gather(allsizes, size)
73 | allsizes = torch.cat(allsizes)
74 | return allsizes
75 |
76 |
77 | def get_rank():
78 | if not dist.is_available():
79 | return 0
80 | if not dist.is_initialized():
81 | return 0
82 | return dist.get_rank()
83 |
84 |
85 | def is_main():
86 | return get_rank() == 0
87 |
88 |
89 | def get_world_size():
90 | if not dist.is_initialized():
91 | return 1
92 | else:
93 | return dist.get_world_size()
94 |
95 |
96 | def barrier():
97 | if dist.is_initialized():
98 | dist.barrier()
99 |
100 |
101 | def average_main(x):
102 | if not dist.is_initialized():
103 | return x
104 | if dist.is_initialized() and dist.get_world_size() > 1:
105 | dist.reduce(x, 0, op=dist.ReduceOp.SUM)
106 | if is_main():
107 | x = x / dist.get_world_size()
108 | return x
109 |
110 |
111 | def sum_main(x):
112 | if not dist.is_initialized():
113 | return x
114 | if dist.is_initialized() and dist.get_world_size() > 1:
115 | dist.reduce(x, 0, op=dist.ReduceOp.SUM)
116 | return x
117 |
118 |
119 | def weighted_average(x, count):
120 | if not dist.is_initialized():
121 | if isinstance(x, torch.Tensor):
122 | x = x.item()
123 | return x, count
124 | t_loss = torch.tensor([x * count]).cuda()
125 | t_total = torch.tensor([count]).cuda()
126 | t_loss = sum_main(t_loss)
127 | t_total = sum_main(t_total)
128 | return (t_loss / t_total).item(), t_total.item()
129 |
--------------------------------------------------------------------------------
/src/slurm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from logging import getLogger
8 | import os
9 | import sys
10 | import torch
11 | import socket
12 | import signal
13 | import subprocess
14 |
15 |
16 | logger = getLogger()
17 |
18 | def sig_handler(signum, frame):
19 | logger.warning("Signal handler called with signal " + str(signum))
20 | prod_id = int(os.environ['SLURM_PROCID'])
21 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
22 | if prod_id == 0:
23 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID'])
24 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])
25 | else:
26 | logger.warning("Not the main process, no need to requeue.")
27 | sys.exit(-1)
28 |
29 |
30 | def term_handler(signum, frame):
31 | logger.warning("Signal handler called with signal " + str(signum))
32 | logger.warning("Bypassing SIGTERM.")
33 |
34 |
35 | def init_signal_handler():
36 | """
37 | Handle signals sent by SLURM for time limit / pre-emption.
38 | """
39 | signal.signal(signal.SIGUSR1, sig_handler)
40 | signal.signal(signal.SIGTERM, term_handler)
41 |
42 |
43 | def init_distributed_mode(params):
44 | """
45 | Handle single and multi-GPU / multi-node / SLURM jobs.
46 | Initialize the following variables:
47 | - local_rank
48 | - global_rank
49 | - world_size
50 | """
51 | is_slurm_job = 'SLURM_JOB_ID' in os.environ and not 'WORLD_SIZE' in os.environ
52 | has_local_rank = hasattr(params, 'local_rank')
53 |
54 | # SLURM job without torch.distributed.launch
55 | if is_slurm_job and has_local_rank:
56 |
57 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM
58 |
59 | # local rank on the current node / global rank
60 | params.local_rank = int(os.environ['SLURM_LOCALID'])
61 | params.global_rank = int(os.environ['SLURM_PROCID'])
62 | params.world_size = int(os.environ['SLURM_NTASKS'])
63 |
64 | # define master address and master port
65 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
66 | params.main_addr = hostnames.split()[0].decode('utf-8')
67 | assert 10001 <= params.main_port <= 20000 or params.world_size == 1
68 |
69 | # set environment variables for 'env://'
70 | os.environ['MASTER_ADDR'] = params.main_addr
71 | os.environ['MASTER_PORT'] = str(params.main_port)
72 | os.environ['WORLD_SIZE'] = str(params.world_size)
73 | os.environ['RANK'] = str(params.global_rank)
74 | is_distributed = True
75 |
76 |
77 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
78 | elif has_local_rank and params.local_rank != -1:
79 |
80 | assert params.main_port == -1
81 |
82 | # read environment variables
83 | params.global_rank = int(os.environ['RANK'])
84 | params.world_size = int(os.environ['WORLD_SIZE'])
85 |
86 | is_distributed = True
87 |
88 | # local job (single GPU)
89 | else:
90 | params.local_rank = 0
91 | params.global_rank = 0
92 | params.world_size = 1
93 | is_distributed = False
94 |
95 | # set GPU device
96 | torch.cuda.set_device(params.local_rank)
97 |
98 | # initialize multi-GPU
99 | if is_distributed:
100 |
101 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization
102 | # 'env://' will read these environment variables:
103 | # MASTER_PORT - required; has to be a free port on machine with rank 0
104 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node
105 | # WORLD_SIZE - required; can be set either here, or in a call to init function
106 | # RANK - required; can be set either here, or in a call to init function
107 |
108 | #print("Initializing PyTorch distributed ...")
109 | torch.distributed.init_process_group(
110 | init_method='env://',
111 | backend='nccl',
112 | #world_size=params.world_size,
113 | #rank=params.global_rank,
114 | )
--------------------------------------------------------------------------------
/construct_generator_sft.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 |
5 | def read_json(file_path):
6 | with open(file_path, 'r', encoding='utf-8') as file:
7 | return json.load(file)
8 |
9 | def format_data(data):
10 | error_data = 0
11 | formatted_data = []
12 |
13 | system_prompt = """You are an intelligent assistant that uses retrieved knowledge to answer user queries accurately and concisely. Follow these rules:
14 |
15 | 1. **Task**:
16 | - Use the provided `[Retrieved Content]` to generate responses.
17 | - If the Retrieved Content is None, you should generate answer based on your own knowledge.
18 | - If the information is insufficient or you don't know the answer, state, “I cannot fully answer based on the available information. Please provide more details.”
19 |
20 | 2. **Requirements**:
21 | - **Accuracy**: Base your answers on the retrieved content.
22 | - **Conciseness**: Keep answers brief and relevant.
23 | - **Context Awareness**: Ensure your responses align with the user’s query.
24 |
25 | 3. **Input Format**:
26 | - Query: `[User Query]`
27 | - Retrieved: `[Retrieved Content]`
28 |
29 | 4. **Output Format**:
30 | - A structured, clear response tailored to the query.
31 |
32 | Always prioritize clarity and reliability."""
33 | for item in data:
34 |
35 | item_id = item['id']
36 |
37 | if 'fever' in item_id:
38 | remember_prompt = 'Please answer the question with “SUPPORTS”, “REFUTES” or “NEI” based on what you know.'
39 | elif 'nq' in item_id:
40 | remember_prompt = 'Please answer the question with a short phrase.'
41 | elif 'hotpotqa' in item_id:
42 | remember_prompt = 'Please answer the question with a short phrase.'
43 | elif 'eli5' in item_id:
44 | remember_prompt = 'Please answer the question with a paragraph.'
45 | elif 'tc' in item_id:
46 | remember_prompt = 'Please answer the question with a short phrase.'
47 | elif 'asqa' in item_id:
48 | remember_prompt = 'Please answer the question with a short phrase.'
49 | else:
50 | remember_prompt = 'Please answer the question with a short phrase.'
51 |
52 | true_ids = item.get('true_id', [])[:15]
53 | true_docs = [item['docs'][i - 1] for i in true_ids if i <= len(item['docs'])]
54 |
55 | if len(true_docs) == 0:
56 | retrieved_content = "Retrieved Content: None"
57 | else:
58 | retrieved_content = "Retrieved Content:\n" + "\n".join(
59 | [
60 | f"{i+1}. Topic: {doc['title']}\nContent: {doc['text']}"
61 | for i, doc in enumerate(true_docs)
62 | ]
63 | )
64 |
65 | if 'tc' in item_id:
66 | filtered_answers = [item['answer'].strip()]
67 | if type(item['answer']) == list:
68 | filtered_answers = [ans.strip() for ans in item.get('answer', []) if ans.strip()]
69 | elif type(item['answer']) == str:
70 | filtered_answers = [item['answer'].strip()]
71 | else:
72 | error_data += 1
73 | continue
74 |
75 | if filtered_answers:
76 | first_answer = filtered_answers[0]
77 | else:
78 | error_data += 1
79 | continue
80 |
81 | instruction_data = {
82 | "instruction": f"{remember_prompt}\nQuery: {item['question']}\n\n{retrieved_content}",
83 | "input": "",
84 | "output": first_answer,
85 | "system": system_prompt
86 | }
87 |
88 | formatted_data.append(instruction_data)
89 |
90 | print(f"Error Data Count: {error_data}")
91 | return formatted_data
92 |
93 |
94 | def write_json(file_path, data):
95 | with open(file_path, 'w', encoding='utf-8') as file:
96 | json.dump(data, file, ensure_ascii=False, indent=4)
97 |
98 |
99 | def main(input_file, output_file):
100 | data = read_json(input_file)
101 | random.shuffle(data)
102 | formatted_data = format_data(data)
103 | print(len(formatted_data))
104 | write_json(output_file, formatted_data)
105 |
106 |
107 | if __name__ == "__main__":
108 | input_file = 'training_data/training_data_sft.json'
109 | output_file = 'training_data/generator_sft_training.json'
110 |
111 | main(input_file, output_file)
112 |
--------------------------------------------------------------------------------
/reranker.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from transformers import T5Tokenizer, T5ForConditionalGeneration
4 | from typing import List
5 | import argparse
6 | from tqdm import tqdm
7 |
8 |
9 | class Reranker:
10 | def __init__(self, model_name_or_path: str, device: str = "cuda"):
11 | """
12 | Initializes the Reranker with a specified model.
13 |
14 | Args:
15 | model_name_or_path (str): Path to the pretrained reranker model.
16 | device (str): Device to load the model on ("cuda" or "cpu").
17 | """
18 | self.device = device
19 | print(f"Loading reranker model from: {model_name_or_path}")
20 | self.tokenizer = T5Tokenizer.from_pretrained(model_name_or_path)
21 | self.model = T5ForConditionalGeneration.from_pretrained(model_name_or_path)
22 | self.model = self.model.to(self.device)
23 | self.model.eval()
24 | print("Reranker model loaded successfully.")
25 |
26 | def rerank(self, query: str, docs: List[dict]) -> List[str]:
27 | """
28 | Rerank a list of documents based on their relevance to the query.
29 |
30 | Args:
31 | query (str): The input query.
32 | docs (List[dict]): A list of documents, each represented as a dictionary with keys "id" and "text".
33 |
34 | Returns:
35 | List[str]: A list of relevance labels ("true" or "false") for the documents.
36 | """
37 | labels = []
38 | inputs = []
39 |
40 | # Prepare inputs for the model
41 | for doc in docs:
42 | combined_input = f"Query: {query} Document: {doc['text']} Relevant:"
43 | inputs.append(combined_input)
44 |
45 | # Tokenize inputs in batches
46 | batch_size = 128 # Adjust based on available memory
47 | for start in range(0, len(inputs), batch_size):
48 | batch = inputs[start:start + batch_size]
49 | encoded_batch = self.tokenizer.batch_encode_plus(
50 | batch,
51 | return_tensors="pt",
52 | padding=True,
53 | truncation=True,
54 | max_length=512
55 | )
56 | encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()}
57 |
58 | # Generate labels
59 | with torch.no_grad():
60 | outputs = self.model.generate(
61 | input_ids=encoded_batch["input_ids"],
62 | attention_mask=encoded_batch["attention_mask"],
63 | max_length=2 # Only need a short output like "true" or "false"
64 | )
65 |
66 | # Decode generated sequences to labels
67 | decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
68 | labels.extend([output.strip().lower() for output in decoded_outputs])
69 |
70 | return labels
71 |
72 |
73 | if __name__ == "__main__":
74 | # Define command-line arguments
75 | parser = argparse.ArgumentParser(description="Reranker Demo")
76 | parser.add_argument("--model_name_or_path", type=str, default="",
77 | help="Path to the pretrained reranker model.")
78 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
79 | help="Device to run the reranker on ('cuda' or 'cpu').")
80 | parser.add_argument("--input_file", type=str, required=True, help="Path to the input JSONL file.")
81 | parser.add_argument("--output_file", type=str, required=True, help="Path to save the output JSONL file.")
82 | args = parser.parse_args()
83 |
84 | # Load the reranker model
85 | reranker = Reranker(model_name_or_path=args.model_name_or_path, device=args.device)
86 |
87 | # Process JSONL file
88 | with open(args.input_file, "r", encoding='utf-8') as input_file, open(args.output_file, "a", encoding='utf-8') as output_file:
89 | for line in tqdm(input_file):
90 | if not line.strip():
91 | continue
92 | # Parse the JSON object
93 | example = json.loads(line.strip())
94 | query = example["question"]
95 | docs = example["docs"]
96 |
97 | # Get rerank labels
98 | rerank_labels = reranker.rerank(query=query, docs=docs)
99 |
100 | # Add rerank labels to each document
101 | for doc, label in zip(docs, rerank_labels):
102 | doc["rerank_label"] = label
103 |
104 | # Save the updated example to the output file as a JSONL entry
105 | output_file.write(json.dumps(example, ensure_ascii=False) + "\n")
106 |
107 | print(f"Reranking completed and saved to {args.output_file}.")
108 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from bert_score import score
2 | from rouge import Rouge
3 | import json
4 | import jsonlines
5 | import openai
6 |
7 |
8 | openai.api_key = "Your_API_Key"
9 |
10 | def compute_bert_score(reference: str, candidate: str):
11 | """
12 | Compute the BERTScore between two strings.
13 |
14 | Parameters:
15 | reference (str): The reference string.
16 | candidate (str): The candidate string to evaluate against the reference.
17 |
18 | Returns:
19 | float: The BERTScore (F1 score) between the reference and candidate.
20 | """
21 | if candidate == "" or reference == "":
22 | return 0
23 | try:
24 | precision, recall, f1 = score([candidate], [reference], lang="en", rescale_with_baseline=True) #
25 | except:
26 | return 0
27 | return f1[0].item()
28 |
29 | def compute_rouge_score(rouge_score, reference: str, candidate: str):
30 | """
31 | Compute the ROUGE score between two strings.
32 |
33 | Parameters:
34 | reference (str): The reference string.
35 | candidate (str): The candidate string to evaluate against the reference.
36 |
37 | Returns:
38 |         float: The ROUGE-L F1 score between the reference and candidate.
39 | """
40 | if candidate == "" or reference == "":
41 | return 0
42 | try:
43 | scores = rouge_score.get_scores(candidate, reference, avg=True)
44 | except:
45 | return 0
46 | return scores["rouge-l"]["f"]
47 |
48 | def compute_exact_match(reference: str, candidate: str):
49 | """
50 | Compute whether two strings are an exact match.
51 |
52 | Parameters:
53 | reference (str): The reference string.
54 | candidate (str): The candidate string to evaluate against the reference.
55 |
56 | Returns:
57 |         bool: True if either normalized string contains the other, False otherwise (0 if either string is empty).
58 | """
59 | if candidate == "" or reference == "":
60 | return 0
61 |     return reference.strip().lower() in candidate.strip().lower() or candidate.strip().lower() in reference.strip().lower()
62 |
63 | def inverse_reward(sentence):
64 | length = len(sentence)
65 | return 1 / (1 + length)
66 |
67 | def llm_eval(instruction: str, reference: str, candidate: str):
68 | prompt = f"""Use the following criteria to evaluate the quality of the model's response in a knowledge-intensive task, considering the provided ground-truth answer. Assign a score between 0-100 based on the overall quality, relevance, and correctness of the response:
69 |
70 | Relevance to the Prompt (20 points):
71 |
72 | Award up to 20 points if the response aligns well with the user's query, even if minor errors are present.
73 | Deduct points if the response lacks focus or deviates significantly from the query.
74 | Accuracy of Factual Information (20 points):
75 |
76 | Grant up to 20 points for correct factual details aligning with the ground-truth answer.
77 | Penalize for inaccuracies, missing essential elements, or presenting incorrect knowledge.
78 | Handling of Temporal and Logical Reasoning (20 points):
79 |
80 | Award up to 20 points for demonstrating correct temporal and logical reasoning.
81 | Deduct points if temporal reasoning is flawed or logical consistency is missing.
82 | Clarity and Coherence of Response (20 points):
83 |
84 | Assign up to 15 points for clear, coherent, and well-structured responses.
85 | Reduce points for ambiguity, confusion, or poor organization.
86 | Potential Misleading Nature or Misconceptions (20 points):
87 |
88 | Award up to 10 points if the response avoids being misleading.
89 | Penalize responses that could confuse or mislead the user, even if partially relevant.
90 | After evaluating the response based on these criteria, provide a total score in the format:
91 | “Score: points”.
92 |
93 | User: {instruction}
94 | Ground-Truth Answer: {reference}
95 | Model Response: {candidate}
96 |
97 | Score: """
98 | response = openai.ChatCompletion.create(
99 | model="gpt-4o",
100 | messages=[
101 | {"role": "system", "content": "You are a helpful Assistant."},
102 | {"role": "user", "content": prompt}
103 | ]
104 | )
105 |     return float(response.choices[0].message["content"].replace('Score:', '').replace('.', '').strip() or 0)  # parse the 0-100 score as a number so callers can weight it
106 |
107 |
108 |
109 | def load_jsonlines(file):
110 | with jsonlines.open(file, 'r') as jsonl_f:
111 | lst = [obj for obj in jsonl_f]
112 | return lst
113 |
114 | def load_file(input_fp):
115 | if input_fp.endswith(".json"):
116 | input_data = json.load(open(input_fp))
117 | else:
118 | input_data = load_jsonlines(input_fp)
119 | return input_data
120 |
121 | # Example usage
122 | # reference = "Jane Austen"
123 | # candidate = "The author of 'Pride and Prejudice' is Chris Paul."
124 |
125 | # reference = "Eiffel Tower in Paris."
126 | # candidate = "The Eiffel Tower is one of the most famous landmarks in Paris, and it is located in the Champ de Mars."
127 | # candidate = "Eiffel Tower is in the Paris."
128 |
129 | # # Compute BERTScore
130 | # bert_score = compute_bert_score(reference, candidate)
131 | # print(f"BERTScore: {bert_score}")
132 |
133 | # # Compute ROUGE Score
134 | # rouge_score = compute_rouge_score(reference, candidate)
135 | # print(f"ROUGE Score: {rouge_score}")
136 |
--------------------------------------------------------------------------------
/src/normalize_text.py:
--------------------------------------------------------------------------------
1 | """
2 | adapted from chemdataextractor.text.normalize
3 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 | Tools for normalizing text.
5 | https://github.com/mcs07/ChemDataExtractor
6 | :copyright: Copyright 2016 by Matt Swain.
7 | :license: MIT
8 |
9 | Permission is hereby granted, free of charge, to any person obtaining
10 | a copy of this software and associated documentation files (the
11 | 'Software'), to deal in the Software without restriction, including
12 | without limitation the rights to use, copy, modify, merge, publish,
13 | distribute, sublicense, and/or sell copies of the Software, and to
14 | permit persons to whom the Software is furnished to do so, subject to
15 | the following conditions:
16 |
17 | The above copyright notice and this permission notice shall be
18 | included in all copies or substantial portions of the Software.
19 |
20 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
21 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
23 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
24 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
25 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
26 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 | """
28 |
29 | #: Control characters.
30 | CONTROLS = {
31 | '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
32 | '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
33 | }
34 | # There are further control characters, but they are instead replaced with a space by unicode normalization
35 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f'
36 |
37 |
38 | #: Hyphen and dash characters.
39 | HYPHENS = {
40 | '-', # \u002d Hyphen-minus
41 | '‐', # \u2010 Hyphen
42 | '‑', # \u2011 Non-breaking hyphen
43 | '⁃', # \u2043 Hyphen bullet
44 | '‒', # \u2012 figure dash
45 | '–', # \u2013 en dash
46 | '—', # \u2014 em dash
47 | '―', # \u2015 horizontal bar
48 | }
49 |
50 | #: Minus characters.
51 | MINUSES = {
52 | '-', # \u002d Hyphen-minus
53 | '−', # \u2212 Minus
54 | '－', # \uff0d Full-width Hyphen-minus
55 | '⁻', # \u207b Superscript minus
56 | }
57 |
58 | #: Plus characters.
59 | PLUSES = {
60 | '+', # \u002b Plus
61 | '＋', # \uff0b Full-width Plus
62 | '⁺', # \u207a Superscript plus
63 | }
64 |
65 | #: Slash characters.
66 | SLASHES = {
67 | '/', # \u002f Solidus
68 | '⁄', # \u2044 Fraction slash
69 | '∕', # \u2215 Division slash
70 | }
71 |
72 | #: Tilde characters.
73 | TILDES = {
74 | '~', # \u007e Tilde
75 | '˜', # \u02dc Small tilde
76 | '⁓', # \u2053 Swung dash
77 | '∼', # \u223c Tilde operator #in mbert vocab
78 | '∽', # \u223d Reversed tilde
79 | '∿', # \u223f Sine wave
80 | '〜', # \u301c Wave dash #in mbert vocab
81 | '～', # \uff5e Full-width tilde #in mbert vocab
82 | }
83 |
84 | #: Apostrophe characters.
85 | APOSTROPHES = {
86 | "'", # \u0027
87 | '’', # \u2019
88 | '՚', # \u055a
89 | 'Ꞌ', # \ua78b
90 | 'ꞌ', # \ua78c
91 | '＇', # \uff07
92 | }
93 |
94 | #: Single quote characters.
95 | SINGLE_QUOTES = {
96 | "'", # \u0027
97 | '‘', # \u2018
98 | '’', # \u2019
99 | '‚', # \u201a
100 | '‛', # \u201b
101 |
102 | }
103 |
104 | #: Double quote characters.
105 | DOUBLE_QUOTES = {
106 | '"', # \u0022
107 | '“', # \u201c
108 | '”', # \u201d
109 | '„', # \u201e
110 | '‟', # \u201f
111 | }
112 |
113 | #: Accent characters.
114 | ACCENTS = {
115 | '`', # \u0060
116 | '´', # \u00b4
117 | }
118 |
119 | #: Prime characters.
120 | PRIMES = {
121 | '′', # \u2032
122 | '″', # \u2033
123 | '‴', # \u2034
124 | '‵', # \u2035
125 | '‶', # \u2036
126 | '‷', # \u2037
127 | '⁗', # \u2057
128 | }
129 |
130 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
131 | QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES
132 |
133 | def normalize(text):
134 | for control in CONTROLS:
135 | text = text.replace(control, '')
136 | text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')
137 |
138 | for hyphen in HYPHENS | MINUSES:
139 | text = text.replace(hyphen, '-')
140 | text = text.replace('\u00ad', '')
141 |
142 | for double_quote in DOUBLE_QUOTES:
143 | text = text.replace(double_quote, '"') # \u0022
144 | for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
145 | text = text.replace(single_quote, "'") # \u0027
146 | text = text.replace('′', "'") # \u2032 prime
147 | text = text.replace('‵', "'") # \u2035 reversed prime
148 | text = text.replace('″', "''") # \u2033 double prime
149 | text = text.replace('‶', "''") # \u2036 reversed double prime
150 | text = text.replace('‴', "'''") # \u2034 triple prime
151 | text = text.replace('‷', "'''") # \u2037 reversed triple prime
152 | text = text.replace('⁗', "''''") # \u2057 quadruple prime
153 |
154 | text = text.replace('…', '...').replace(' . . . ', ' ... ') # \u2026
155 |
156 | for slash in SLASHES:
157 | text = text.replace(slash, '/')
158 |
159 | #for tilde in TILDES:
160 | # text = text.replace(tilde, '~')
161 |
162 | return text
163 |
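164 | # Illustrative example of the normalization above (quotes, dashes, ellipses, primes):
165 | # normalize('“Smart” ‐ dashes… and ′primes′')
166 | # returns: "Smart" - dashes... and 'primes'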
--------------------------------------------------------------------------------
/src/moco.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import torch.nn as nn
5 | import logging
6 | import copy
7 | import transformers
8 |
9 | from src import contriever, dist_utils, utils
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class MoCo(nn.Module):
15 | def __init__(self, opt):
16 | super(MoCo, self).__init__()
17 |
18 | self.queue_size = opt.queue_size
19 | self.momentum = opt.momentum
20 | self.temperature = opt.temperature
21 | self.label_smoothing = opt.label_smoothing
22 | self.norm_doc = opt.norm_doc
23 | self.norm_query = opt.norm_query
24 | self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k # apply the encoder on keys in train mode
25 |
26 | retriever, tokenizer = self._load_retriever(
27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
28 | )
29 |
30 | self.tokenizer = tokenizer
31 | self.encoder_q = retriever
32 | self.encoder_k = copy.deepcopy(retriever)
33 |
34 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
35 | param_k.data.copy_(param_q.data)
36 | param_k.requires_grad = False
37 |
38 | # create the queue
39 | self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size))
40 | self.queue = nn.functional.normalize(self.queue, dim=0)
41 |
42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
43 |
44 | def _load_retriever(self, model_id, pooling, random_init):
45 | cfg = utils.load_hf(transformers.AutoConfig, model_id)
46 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
47 |
48 | if "xlm" in model_id:
49 | model_class = contriever.XLMRetriever
50 | else:
51 | model_class = contriever.Contriever
52 |
53 | if random_init:
54 | retriever = model_class(cfg)
55 | else:
56 | retriever = utils.load_hf(model_class, model_id)
57 |
58 | if "bert-" in model_id:
59 | if tokenizer.bos_token_id is None:
60 | tokenizer.bos_token = "[CLS]"
61 | if tokenizer.eos_token_id is None:
62 | tokenizer.eos_token = "[SEP]"
63 |
64 | retriever.config.pooling = pooling
65 |
66 | return retriever, tokenizer
67 |
68 | def get_encoder(self, return_encoder_k=False):
69 | if return_encoder_k:
70 | return self.encoder_k
71 | else:
72 | return self.encoder_q
73 |
74 | def _momentum_update_key_encoder(self):
75 | """
76 | Update of the key encoder
77 | """
78 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
79 | param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)
80 |
81 | @torch.no_grad()
82 | def _dequeue_and_enqueue(self, keys):
83 | # gather keys before updating queue
84 | keys = dist_utils.gather_nograd(keys.contiguous())
85 |
86 | batch_size = keys.shape[0]
87 |
88 | ptr = int(self.queue_ptr)
89 | assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}" # for simplicity
90 |
91 | # replace the keys at ptr (dequeue and enqueue)
92 | self.queue[:, ptr : ptr + batch_size] = keys.T
93 | ptr = (ptr + batch_size) % self.queue_size # move pointer
94 |
95 | self.queue_ptr[0] = ptr
96 |
97 | def _compute_logits(self, q, k):
98 | l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
99 | l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
100 |
101 | logits = torch.cat([l_pos, l_neg], dim=1)
102 | return logits
103 |
104 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
105 | bsz = q_tokens.size(0)
106 |
107 | q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
108 |
109 | # compute key features
110 | with torch.no_grad(): # no gradient to keys
111 | self._momentum_update_key_encoder() # update the key encoder
112 |
113 | if not self.encoder_k.training and not self.moco_train_mode_encoder_k:
114 | self.encoder_k.eval()
115 |
116 | k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
117 |
118 | logits = self._compute_logits(q, k) / self.temperature
119 |
120 | # labels: positive key indicators
121 | labels = torch.zeros(bsz, dtype=torch.long).cuda()
122 |
123 | loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
124 |
125 | self._dequeue_and_enqueue(k)
126 |
127 | # log stats
128 | if len(stats_prefix) > 0:
129 | stats_prefix = stats_prefix + "/"
130 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
131 |
132 | predicted_idx = torch.argmax(logits, dim=-1)
133 | accuracy = 100 * (predicted_idx == labels).float().mean()
134 | stdq = torch.std(q, dim=0).mean().item()
135 | stdk = torch.std(k, dim=0).mean().item()
136 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
137 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
138 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
139 |
140 | return loss, iter_stats
141 |
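142 | # Minimal training sketch (assumes an `opt` namespace with the fields read in __init__
143 | # and a batch produced by the pre-training/finetuning collators):
144 | # model = MoCo(opt).cuda()
145 | # loss, stats = model(batch["q_tokens"].cuda(), batch["q_mask"].cuda(),
146 | #                     batch["k_tokens"].cuda(), batch["k_mask"].cuda())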
--------------------------------------------------------------------------------
/src/contriever.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os
4 | import torch
5 | import transformers
6 | from transformers import BertModel, XLMRobertaModel
7 |
8 | from src import utils
9 |
10 |
11 | class Contriever(BertModel):
12 | def __init__(self, config, pooling="average", **kwargs):
13 | super().__init__(config, add_pooling_layer=False)
14 | if not hasattr(config, "pooling"):
15 | self.config.pooling = pooling
16 |
17 | def forward(
18 | self,
19 | input_ids=None,
20 | attention_mask=None,
21 | token_type_ids=None,
22 | position_ids=None,
23 | head_mask=None,
24 | inputs_embeds=None,
25 | encoder_hidden_states=None,
26 | encoder_attention_mask=None,
27 | output_attentions=None,
28 | output_hidden_states=None,
29 | normalize=False,
30 | ):
31 |
32 | model_output = super().forward(
33 | input_ids=input_ids,
34 | attention_mask=attention_mask,
35 | token_type_ids=token_type_ids,
36 | position_ids=position_ids,
37 | head_mask=head_mask,
38 | inputs_embeds=inputs_embeds,
39 | encoder_hidden_states=encoder_hidden_states,
40 | encoder_attention_mask=encoder_attention_mask,
41 | output_attentions=output_attentions,
42 | output_hidden_states=output_hidden_states,
43 | )
44 |
45 | last_hidden = model_output["last_hidden_state"]
46 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
47 |
48 | if self.config.pooling == "average":
49 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
50 | elif self.config.pooling == "cls":
51 | emb = last_hidden[:, 0]
52 |
53 | if normalize:
54 | emb = torch.nn.functional.normalize(emb, dim=-1)
55 | return emb
56 |
57 |
58 | class XLMRetriever(XLMRobertaModel):
59 | def __init__(self, config, pooling="average", **kwargs):
60 | super().__init__(config, add_pooling_layer=False)
61 | if not hasattr(config, "pooling"):
62 | self.config.pooling = pooling
63 |
64 | def forward(
65 | self,
66 | input_ids=None,
67 | attention_mask=None,
68 | token_type_ids=None,
69 | position_ids=None,
70 | head_mask=None,
71 | inputs_embeds=None,
72 | encoder_hidden_states=None,
73 | encoder_attention_mask=None,
74 | output_attentions=None,
75 | output_hidden_states=None,
76 | normalize=False,
77 | ):
78 |
79 | model_output = super().forward(
80 | input_ids=input_ids,
81 | attention_mask=attention_mask,
82 | token_type_ids=token_type_ids,
83 | position_ids=position_ids,
84 | head_mask=head_mask,
85 | inputs_embeds=inputs_embeds,
86 | encoder_hidden_states=encoder_hidden_states,
87 | encoder_attention_mask=encoder_attention_mask,
88 | output_attentions=output_attentions,
89 | output_hidden_states=output_hidden_states,
90 | )
91 |
92 | last_hidden = model_output["last_hidden_state"]
93 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
94 | if self.config.pooling == "average":
95 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
96 | elif self.config.pooling == "cls":
97 | emb = last_hidden[:, 0]
98 | if normalize:
99 | emb = torch.nn.functional.normalize(emb, dim=-1)
100 | return emb
101 |
102 |
103 | def load_retriever(model_path, pooling="average", random_init=False):
104 | # try: check if model exists locally
105 | path = os.path.join(model_path, "checkpoint.pth")
106 | if os.path.exists(path):
107 | pretrained_dict = torch.load(path, map_location="cpu")
108 | opt = pretrained_dict["opt"]
109 | if hasattr(opt, "retriever_model_id"):
110 | retriever_model_id = opt.retriever_model_id
111 | else:
112 | # retriever_model_id = "bert-base-uncased"
113 | retriever_model_id = "bert-base-multilingual-cased"
114 | tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id)
115 | cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id)
116 | if "xlm" in retriever_model_id:
117 | model_class = XLMRetriever
118 | else:
119 | model_class = Contriever
120 | retriever = model_class(cfg)
121 | pretrained_dict = pretrained_dict["model"]
122 |
123 | if any("encoder_q." in key for key in pretrained_dict.keys()): # test if model is defined with moco class
124 | pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k}
125 | elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class
126 | pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k}
127 | retriever.load_state_dict(pretrained_dict, strict=False)
128 | else:
129 | retriever_model_id = model_path
130 | if "xlm" in retriever_model_id:
131 | model_class = XLMRetriever
132 | else:
133 | model_class = Contriever
134 | cfg = utils.load_hf(transformers.AutoConfig, model_path)
135 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path)
136 | retriever = utils.load_hf(model_class, model_path)
137 |
138 | return retriever, tokenizer, retriever_model_id
139 |
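140 | # Illustrative usage (the model path is a placeholder; any supported checkpoint directory also works):
141 | # retriever, tokenizer, model_id = load_retriever("facebook/contriever")
142 | # enc = tokenizer(["what is a momentum queue?"], padding=True, return_tensors="pt")
143 | # emb = retriever(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], normalize=True)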
--------------------------------------------------------------------------------
/src/finetuning_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import random
5 | import json
6 | import sys
7 | import numpy as np
8 | from src import normalize_text
9 |
10 |
11 | class Dataset(torch.utils.data.Dataset):
12 | def __init__(
13 | self,
14 | datapaths,
15 | negative_ctxs=1,
16 | negative_hard_ratio=0.0,
17 | negative_hard_min_idx=0,
18 | training=False,
19 | global_rank=-1,
20 | world_size=-1,
21 | maxload=None,
22 | normalize=False,
23 | ):
24 | self.negative_ctxs = negative_ctxs
25 | self.negative_hard_ratio = negative_hard_ratio
26 | self.negative_hard_min_idx = negative_hard_min_idx
27 | self.training = training
28 | self.normalize_fn = normalize_text.normalize if normalize else lambda x: x
29 | self._load_data(datapaths, global_rank, world_size, maxload)
30 |
31 | def __len__(self):
32 | return len(self.data)
33 |
34 | def __getitem__(self, index):
35 | example = self.data[index]
36 | question = example["question"]
37 | if self.training:
38 | gold = random.choice(example["positive_ctxs"])
39 |
40 | n_hard_negatives, n_random_negatives = self.sample_n_hard_negatives(example)
41 | negatives = []
42 | if n_random_negatives > 0:
43 | random_negatives = random.sample(example["negative_ctxs"], n_random_negatives)
44 | negatives += random_negatives
45 | if n_hard_negatives > 0:
46 | hard_negatives = random.sample(
47 | example["hard_negative_ctxs"][self.negative_hard_min_idx :], n_hard_negatives
48 | )
49 | negatives += hard_negatives
50 | else:
51 | gold = example["positive_ctxs"][0]
52 | nidx = 0
53 | if "negative_ctxs" in example:
54 | negatives = [example["negative_ctxs"][nidx]]
55 | else:
56 | negatives = []
57 |
58 | gold = gold["title"] + " " + gold["text"] if "title" in gold and len(gold["title"]) > 0 else gold["text"]
59 |
60 | negatives = [
61 | n["title"] + " " + n["text"] if ("title" in n and len(n["title"]) > 0) else n["text"] for n in negatives
62 | ]
63 |
64 | example = {
65 | "query": self.normalize_fn(question),
66 | "gold": self.normalize_fn(gold),
67 | "negatives": [self.normalize_fn(n) for n in negatives],
68 | }
69 | return example
70 |
71 | def _load_data(self, datapaths, global_rank, world_size, maxload):
72 | counter = 0
73 | self.data = []
74 | for path in datapaths:
75 | path = str(path)
76 | if path.endswith(".jsonl"):
77 | file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload)
78 | elif path.endswith(".json"):
79 | file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload)
80 | self.data.extend(file_data)
81 | if maxload is not None and maxload > 0 and counter >= maxload:
82 | break
83 |
84 | def _load_data_json(self, path, global_rank, world_size, counter, maxload=None):
85 | examples = []
86 | with open(path, "r") as fin:
87 | data = json.load(fin)
88 | for example in data:
89 | counter += 1
90 | if global_rank > -1 and not counter % world_size == global_rank:
91 | continue
92 | examples.append(example)
93 | if maxload is not None and maxload > 0 and counter == maxload:
94 | break
95 |
96 | return examples, counter
97 |
98 | def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None):
99 | examples = []
100 | with open(path, "r") as fin:
101 | for line in fin:
102 | counter += 1
103 | if global_rank > -1 and not counter % world_size == global_rank:
104 | continue
105 | example = json.loads(line)
106 | examples.append(example)
107 | if maxload is not None and maxload > 0 and counter == maxload:
108 | break
109 |
110 | return examples, counter
111 |
112 | def sample_n_hard_negatives(self, ex):
113 |
114 | if "hard_negative_ctxs" in ex:
115 | n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)])
116 | n_hard_negatives = min(n_hard_negatives, len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :]))
117 | else:
118 | n_hard_negatives = 0
119 | n_random_negatives = self.negative_ctxs - n_hard_negatives
120 | if "negative_ctxs" in ex:
121 | n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"]))
122 | else:
123 | n_random_negatives = 0
124 | return n_hard_negatives, n_random_negatives
125 |
126 |
127 | class Collator(object):
128 | def __init__(self, tokenizer, passage_maxlength=200):
129 | self.tokenizer = tokenizer
130 | self.passage_maxlength = passage_maxlength
131 |
132 | def __call__(self, batch):
133 | queries = [ex["query"] for ex in batch]
134 | golds = [ex["gold"] for ex in batch]
135 | negs = [item for ex in batch for item in ex["negatives"]]
136 | allpassages = golds + negs
137 |
138 | qout = self.tokenizer.batch_encode_plus(
139 | queries,
140 | max_length=self.passage_maxlength,
141 | truncation=True,
142 | padding=True,
143 | add_special_tokens=True,
144 | return_tensors="pt",
145 | )
146 | kout = self.tokenizer.batch_encode_plus(
147 | allpassages,
148 | max_length=self.passage_maxlength,
149 | truncation=True,
150 | padding=True,
151 | add_special_tokens=True,
152 | return_tensors="pt",
153 | )
154 | q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool()
155 | k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool()
156 |
157 | g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)]
158 | n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :]
159 |
160 | batch = {
161 | "q_tokens": q_tokens,
162 | "q_mask": q_mask,
163 | "k_tokens": k_tokens,
164 | "k_mask": k_mask,
165 | "g_tokens": g_tokens,
166 | "g_mask": g_mask,
167 | "n_tokens": n_tokens,
168 | "n_mask": n_mask,
169 | }
170 |
171 | return batch
172 |
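173 | # Minimal sketch (the data path and tokenizer are placeholders):
174 | # dataset = Dataset(["train.jsonl"], negative_ctxs=1, training=True)
175 | # loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=Collator(tokenizer))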
--------------------------------------------------------------------------------
/sampling_dpo_trajectories.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import argparse
4 | from vllm import LLM, SamplingParams
5 | from tqdm import tqdm
6 |
7 | parser = argparse.ArgumentParser(description="Process JSON data using an LLM model.")
8 | parser.add_argument("--template", type=str, required=True, help="Path to the chat template file.")
9 | parser.add_argument("--llm-model", type=str, required=True, help="Path to the LLM model checkpoint.")
10 | parser.add_argument("--input-json", type=str, required=True, help="Path to the input JSON file.")
11 | parser.add_argument("--output-json", type=str, required=True, help="Path to the output JSON file.")
12 |
13 | args = parser.parse_args()
14 |
15 | with open(args.template, "r") as f:
16 | chat_template = f.read()
17 |
18 | llm = LLM(model=args.llm_model, tensor_parallel_size=1)
19 | sampling_params = SamplingParams(temperature=0.4, n=8, max_tokens=100)
20 | sampling_params_generate = SamplingParams(temperature=0, max_tokens=200)
21 |
22 | def get_prompt_docs(query, docs):
23 | retrieved_content = "Retrieved Content:\n" + "\n".join(
24 | [
25 | f"{i+1}. Topic: {doc['title']}\nContent: {doc['text']}"
26 | for i, doc in enumerate(docs)
27 | ]
28 | )
29 | return [
30 | {'role': 'system',
31 | 'content': ("You are an expert at dynamically generating document identifiers to answer a given query.\n"
32 | "I will provide you with a set of documents, each uniquely identified by a number within square brackets, e.g., [1], [2], etc.\n"
33 | f"Your task is to identify and generate only the identifiers of the documents that contain sufficient information to answer the query.\n"
34 | "Stop generating identifiers as soon as the selected documents collectively provide enough information to answer the query.\n"
35 | "If no documents are required to answer the query, output \"None\".\n"
36 | "Output the identifiers as a comma-separated list, e.g., [1], [2] or \"None\" if no documents are needed.\n"
37 | "Focus solely on providing the identifiers. Do not include any explanations, descriptions, or additional text.")},
38 | {'role': 'user',
39 | 'content': f"Query: {query}\n\n{retrieved_content}"}
40 | ]
41 |
42 | def get_prompt_answer(prompt, entry):
43 | system_prompt = """You are an intelligent assistant that uses retrieved knowledge to answer user queries accurately and concisely. Follow these rules:
44 | 1. **Task**:
45 | - Use the provided `[Retrieved Content]` to generate responses.
46 | - If the Retrieved Content is None, you should generate answer based on your own knowledge.
47 | - If the information is insufficient or you don't know the answer, state, “I cannot fully answer based on the available information. Please provide more details.”
48 | 2. **Requirements**:
49 | - **Accuracy**: Base your answers on the retrieved content.
50 | - **Conciseness**: Keep answers brief and relevant.
51 | - **Context Awareness**: Ensure your responses align with the user’s query.
52 | 3. **Input Format**:
53 | - Query: `[User Query]`
54 | - Retrieved: `[Retrieved Content]`
55 | 4. **Output Format**:
56 | - A structured, clear response tailored to the query.
57 | Always prioritize clarity and reliability."""
58 |
59 | item_id = entry['id']
60 | if 'fever' in item_id:
61 | remember_prompt = 'Please answer the question with “SUPPORTS”, “REFUTES” or “NEI” based on what you know.'
62 | elif 'nq' in item_id:
63 | remember_prompt = 'Please answer the question with a short phrase.'
64 | elif 'hotpotqa' in item_id:
65 | remember_prompt = 'Please answer the question with a short phrase.'
66 | elif 'eli5' in item_id:
67 | remember_prompt = 'Please answer the question with a paragraph.'
68 | elif 'tc' in item_id:
69 | remember_prompt = 'Please answer the question with a short phrase.'
70 | elif 'asqa' in item_id:
71 | remember_prompt = 'Please answer the question with a short phrase.'
72 | else:
73 | remember_prompt = 'Please answer the question with a short phrase.'
74 |
75 | return [
76 | {'role': 'system',
77 | 'content': system_prompt},
78 | {'role': 'user',
79 | 'content': (f"{remember_prompt}\n"
80 | f"Query: {entry['question']}\n\n"
81 | f"{prompt}")}
82 | ]
83 |
84 | def generate_docs(entry):
85 | messages = get_prompt_docs(entry['question'], entry['docs'][:40])
86 |
87 | outputs = llm.chat(messages,
88 | sampling_params=sampling_params,
89 | use_tqdm=False,
90 | chat_template=chat_template)
91 |
92 | return [output.text for output in outputs[0].outputs]
93 |
94 | def map_function(data_list):
95 | result = []
96 | for item in data_list:
97 | ids = [int(num) for num in re.findall(r'\[(\d+)\]', item)]
98 | result.append(ids)
99 | return result
100 |
101 | def convert(id_lists, docs):
102 | prompts = []
103 | for ids in id_lists:
104 | lines = []
105 | if len(ids) == 0:
106 | retrieved_content = "Retrieved Content: None"
107 | prompts.append(retrieved_content)
108 | continue
109 | for doc_id in ids:
110 | doc_title = docs[doc_id - 1]['title']
111 | doc_text = docs[doc_id - 1]['text']
112 | line = f"{doc_id}. Title: {doc_title}.\nContent: {doc_text}"
113 | lines.append(line)
114 |
115 | retrieved_content = "\n".join(lines)
116 | prompt = f"Retrieved Content:\n{retrieved_content}"
117 | prompts.append(prompt)
118 | return prompts
119 |
120 | def generate_answers(entry, prompts):
121 | results = []
122 | for prompt in prompts:
123 | messages = get_prompt_answer(prompt, entry)
124 |
125 | outputs = llm.chat(messages,
126 | sampling_params=sampling_params_generate,
127 | use_tqdm=False,
128 | chat_template=chat_template)
129 | results.append(outputs[0].outputs[0].text)
130 | return results
131 |
132 | def process_json(input_file, output_file):
133 | with open(input_file, 'r', encoding='utf-8') as infile:
134 | data = json.load(infile)
135 |
136 | with open(output_file, 'a', encoding='utf-8') as outfile:
137 | for entry in tqdm(data):
138 | generated_list = generate_docs(entry)
139 | id_lists = map_function(generated_list)
140 | try:
141 | prompts = convert(id_lists, entry['docs'])
142 | except IndexError:
143 | continue
144 | answers = generate_answers(entry, prompts)
145 |
146 | entry['responses'] = answers
147 | entry['id_lists'] = id_lists
148 |
149 | json.dump(entry, outfile, ensure_ascii=False)
150 | outfile.write('\n')
151 |
152 | if __name__ == "__main__":
153 | process_json(args.input_json, args.output_json)
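154 |
155 | # Example invocation (all paths are placeholders):
156 | # python sampling_dpo_trajectories.py --template templates/chat.jinja --llm-model /path/to/model \
157 | #     --input-json data/train.json --output-json output/dpo_trajectories.jsonl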
--------------------------------------------------------------------------------
/src/options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import argparse
4 | import os
5 |
6 |
7 | class Options:
8 | def __init__(self):
9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
10 | self.initialize()
11 |
12 | def initialize(self):
13 | # basic parameters
14 | self.parser.add_argument(
15 | "--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here"
16 | )
17 | self.parser.add_argument(
18 | "--train_data",
19 | nargs="+",
20 | default=[],
21 | help="Data used for training, passed as a list of directories split into tensor files.",
22 | )
23 | self.parser.add_argument(
24 | "--eval_data",
25 | nargs="+",
26 | default=[],
27 | help="Data used for evaluation during finetuning; this option is not used during contrastive pre-training.",
28 | )
29 | self.parser.add_argument(
30 | "--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format"
31 | )
32 | self.parser.add_argument(
33 | "--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored"
34 | )
35 | self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining")
36 | self.parser.add_argument("--continue_training", action="store_true")
37 | self.parser.add_argument("--num_workers", type=int, default=5)
38 |
39 | self.parser.add_argument("--chunk_length", type=int, default=256)
40 | self.parser.add_argument("--loading_mode", type=str, default="split")
41 | self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing")
42 | self.parser.add_argument(
43 | "--sampling_coefficient",
44 | type=float,
45 | default=0.0,
46 | help="coefficient used for sampling between different datasets during training, \
47 | by default sampling is uniform over datasets",
48 | )
49 | self.parser.add_argument("--augmentation", type=str, default="none")
50 | self.parser.add_argument("--prob_augmentation", type=float, default=0.0)
51 |
52 | self.parser.add_argument("--dropout", type=float, default=0.1)
53 | self.parser.add_argument("--rho", type=float, default=0.05)
54 |
55 | self.parser.add_argument("--contrastive_mode", type=str, default="moco")
56 | self.parser.add_argument("--queue_size", type=int, default=65536)
57 | self.parser.add_argument("--temperature", type=float, default=1.0)
58 | self.parser.add_argument("--momentum", type=float, default=0.999)
59 | self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true")
60 | self.parser.add_argument("--eval_normalize_text", action="store_true")
61 | self.parser.add_argument("--norm_query", action="store_true")
62 | self.parser.add_argument("--norm_doc", action="store_true")
63 | self.parser.add_argument("--projection_size", type=int, default=768)
64 |
65 | self.parser.add_argument("--ratio_min", type=float, default=0.1)
66 | self.parser.add_argument("--ratio_max", type=float, default=0.5)
67 | self.parser.add_argument("--score_function", type=str, default="dot")
68 | self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased")
69 | self.parser.add_argument("--pooling", type=str, default="average")
70 | self.parser.add_argument("--random_init", action="store_true", help="init model with random weights")
71 |
72 | # dataset parameters
73 | self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.")
74 | self.parser.add_argument(
75 | "--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU for evaluation."
76 | )
77 | self.parser.add_argument("--total_steps", type=int, default=1000)
78 | self.parser.add_argument("--warmup_steps", type=int, default=-1)
79 |
80 | self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
81 | self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)")
82 | self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization")
83 | # training parameters
84 | self.parser.add_argument("--optim", type=str, default="adamw")
85 | self.parser.add_argument("--scheduler", type=str, default="linear")
86 | self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
87 | self.parser.add_argument(
88 | "--lr_min_ratio",
89 | type=float,
90 | default=0.0,
91 | help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate",
92 | )
93 | self.parser.add_argument("--weight_decay", type=float, default=0.01, help="learning rate")
94 | self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1")
95 | self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2")
96 | self.parser.add_argument("--eps", type=float, default=1e-6, help="eps")
97 | self.parser.add_argument(
98 | "--log_freq", type=int, default=100, help="log train stats every steps during training"
99 | )
100 | self.parser.add_argument(
101 | "--eval_freq", type=int, default=500, help="evaluate model every steps during training"
102 | )
103 | self.parser.add_argument("--save_freq", type=int, default=50000)
104 | self.parser.add_argument("--maxload", type=int, default=None)
105 | self.parser.add_argument("--label_smoothing", type=float, default=0.0)
106 |
107 | # finetuning options
108 | self.parser.add_argument("--negative_ctxs", type=int, default=1)
109 | self.parser.add_argument("--negative_hard_min_idx", type=int, default=0)
110 | self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0)
111 |
112 | def print_options(self, opt):
113 | message = ""
114 | for k, v in sorted(vars(opt).items()):
115 | comment = ""
116 | default = self.parser.get_default(k)
117 | if v != default:
118 | comment = f"\t[default: {default}]"
119 | message += f"{str(k):>40}: {str(v):<40}{comment}\n"
120 | print(message, flush=True)
121 | model_dir = os.path.join(opt.output_dir, "models")
122 | if not os.path.exists(model_dir):
123 | os.makedirs(os.path.join(opt.output_dir, "models"))
124 | file_name = os.path.join(opt.output_dir, "opt.txt")
125 | with open(file_name, "wt") as opt_file:
126 | opt_file.write(message)
127 | opt_file.write("\n")
128 |
129 | def parse(self):
130 | opt, _ = self.parser.parse_known_args()
131 | # opt = self.parser.parse_args()
132 | return opt
133 |
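134 | # Typical usage (sketch):
135 | # options = Options()
136 | # opt = options.parse()
137 | # options.print_options(opt)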
--------------------------------------------------------------------------------
/src/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import collections
9 | import logging
10 | import regex
11 | import string
12 | import unicodedata
13 | from functools import partial
14 | from multiprocessing import Pool as ProcessPool
15 | from typing import Tuple, List, Dict
16 | import numpy as np
17 |
18 | """
19 | Evaluation code from DPR: https://github.com/facebookresearch/DPR
20 | """
21 |
22 | class SimpleTokenizer(object):
23 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
24 | NON_WS = r'[^\p{Z}\p{C}]'
25 |
26 | def __init__(self):
27 | """
28 | Args:
29 | annotators: None or empty set (only tokenizes).
30 | """
31 | self._regexp = regex.compile(
32 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
33 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
34 | )
35 |
36 | def tokenize(self, text, uncased=False):
37 | matches = [m for m in self._regexp.finditer(text)]
38 | if uncased:
39 | tokens = [m.group().lower() for m in matches]
40 | else:
41 | tokens = [m.group() for m in matches]
42 | return tokens
43 |
44 | logger = logging.getLogger(__name__)
45 |
46 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits'])
47 |
48 | def calculate_matches(data: List, workers_num: int):
49 | """
50 | Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of
51 | documents and results. It internally forks multiple sub-processes for evaluation and then merges results
52 | :param data: list of per-question dicts; each dict holds the gold 'answers' for a question
53 | and its retrieved 'ctxs', whose 'text' fields are checked for an answer match
54 | (ctxs are assumed to be ordered by retrieval rank, as used for the top-k hit counts)
55 | :param workers_num: amount of parallel processes used to check the matches
56 | (matching uses has_answer: normalized, tokenized string containment)
57 | :return: matching information tuple.
58 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of
59 | valid matches across an entire dataset.
60 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document
61 | """
62 |
63 | logger.info('Matching answers in top docs...')
64 |
65 | tokenizer = SimpleTokenizer()
66 | get_score_partial = partial(check_answer, tokenizer=tokenizer)
67 |
68 | processes = ProcessPool(processes=workers_num)
69 | scores = processes.map(get_score_partial, data)
70 |
71 | logger.info('Per question validation results len=%d', len(scores))
72 |
73 | n_docs = len(data[0]['ctxs'])
74 | top_k_hits = [0] * n_docs
75 | for question_hits in scores:
76 | best_hit = next((i for i, x in enumerate(question_hits) if x), None)
77 | if best_hit is not None:
78 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
79 |
80 | return QAMatchStats(top_k_hits, scores)
81 |
82 | def check_answer(example, tokenizer) -> List[bool]:
83 | """Search through all the top docs to see if they have any of the answers."""
84 | answers = example['answers']
85 | ctxs = example['ctxs']
86 |
87 | hits = []
88 |
89 | for i, doc in enumerate(ctxs):
90 | text = doc['text']
91 |
92 | if text is None: # cannot find the document for some reason
93 | logger.warning("no doc in db")
94 | hits.append(False)
95 | continue
96 |
97 | hits.append(has_answer(answers, text, tokenizer))
98 |
99 | return hits
100 |
101 | def has_answer(answers, text, tokenizer) -> bool:
102 | """Check if a document contains an answer string."""
103 | text = _normalize(text)
104 | text = tokenizer.tokenize(text, uncased=True)
105 |
106 | for answer in answers:
107 | answer = _normalize(answer)
108 | answer = tokenizer.tokenize(answer, uncased=True)
109 | for i in range(0, len(text) - len(answer) + 1):
110 | if answer == text[i: i + len(answer)]:
111 | return True
112 | return False
113 |
114 | #################################################
115 | ######## READER EVALUATION ########
116 | #################################################
117 |
118 | def _normalize(text):
119 | return unicodedata.normalize('NFD', text)
120 |
121 | #Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
122 | def normalize_answer(s):
123 | def remove_articles(text):
124 | return regex.sub(r'\b(a|an|the)\b', ' ', text)
125 |
126 | def white_space_fix(text):
127 | return ' '.join(text.split())
128 |
129 | def remove_punc(text):
130 | exclude = set(string.punctuation)
131 | return ''.join(ch for ch in text if ch not in exclude)
132 |
133 | def lower(text):
134 | return text.lower()
135 |
136 | return white_space_fix(remove_articles(remove_punc(lower(s))))
137 |
138 | def em(prediction, ground_truth):
139 | return normalize_answer(prediction) == normalize_answer(ground_truth)
140 |
141 | def f1(prediction, ground_truth):
142 | prediction_tokens = normalize_answer(prediction).split()
143 | ground_truth_tokens = normalize_answer(ground_truth).split()
144 | common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
145 | num_same = sum(common.values())
146 | if num_same == 0:
147 | return 0
148 | precision = 1.0 * num_same / len(prediction_tokens)
149 | recall = 1.0 * num_same / len(ground_truth_tokens)
150 | f1 = (2 * precision * recall) / (precision + recall)
151 | return f1
152 |
153 | def f1_score(prediction, ground_truths):
154 | return max([f1(prediction, gt) for gt in ground_truths])
155 |
156 | def exact_match_score(prediction, ground_truths):
157 | return max([em(prediction, gt) for gt in ground_truths])
158 |
159 | ####################################################
160 | ######## RETRIEVER EVALUATION ########
161 | ####################################################
162 |
163 | def eval_batch(scores, inversions, avg_topk, idx_topk):
164 | for k, s in enumerate(scores):
165 | s = s.cpu().numpy()
166 | sorted_idx = np.argsort(-s)
167 | score(sorted_idx, inversions, avg_topk, idx_topk)
168 |
169 | def count_inversions(arr):
170 | inv_count = 0
171 | lenarr = len(arr)
172 | for i in range(lenarr):
173 | for j in range(i + 1, lenarr):
174 | if (arr[i] > arr[j]):
175 | inv_count += 1
176 | return inv_count
177 |
178 | def score(x, inversions, avg_topk, idx_topk):
179 | x = np.array(x)
180 | inversions.append(count_inversions(x))
181 | for k in avg_topk:
182 | # ratio of passages in the predicted top-k that are
183 | # also in the topk given by gold score
184 | avg_pred_topk = (x[:k] < k).mean()
185 | avg_topk[k].append(avg_pred_topk)
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
160 | def update(self, vals) -> None:
161 | for key, (value, weight) in vals.items():
162 | self.raw_stats[key] += value * weight
163 | self.total_weights[key] += weight
164 |
165 | @property
166 | def stats(self) -> Dict[str, float]:
167 | return {x: self.raw_stats[x] / self.total_weights[x] for x in self.raw_stats.keys()}
168 |
169 | @property
170 | def tuple_stats(self) -> Dict[str, Tuple[float, float]]:
171 | return {x: (self.raw_stats[x] / self.total_weights[x], self.total_weights[x]) for x in self.raw_stats.keys()}
172 |
173 | def reset(self) -> None:
174 | self.raw_stats = defaultdict(float)
175 | self.total_weights = defaultdict(float)
176 |
177 | @property
178 | def average_stats(self) -> Dict[str, float]:
179 | keys = sorted(self.raw_stats.keys())
180 | if torch.distributed.is_initialized():
181 | torch.distributed.broadcast_object_list(keys, src=0)
182 | global_dict = {}
183 | for k in keys:
184 | if not k in self.total_weights:
185 | v = 0.0
186 | else:
187 | v = self.raw_stats[k] / self.total_weights[k]
188 | v, _ = dist_utils.weighted_average(v, self.total_weights[k])
189 | global_dict[k] = v
190 | return global_dict
191 |
192 |
193 | def load_hf(object_class, model_name):
194 | try:
195 | obj = object_class.from_pretrained(model_name, local_files_only=True)
196 | except:
197 | obj = object_class.from_pretrained(model_name, local_files_only=False)
198 | return obj
199 |
200 |
201 | def init_tb_logger(output_dir):
202 | try:
203 | from torch.utils import tensorboard
204 |
205 | if dist_utils.is_main():
206 | tb_logger = tensorboard.SummaryWriter(output_dir)
207 | else:
208 | tb_logger = None
209 | except:
210 | logger.warning("Tensorboard is not available.")
211 | tb_logger = None
212 |
213 | return tb_logger
214 |
--------------------------------------------------------------------------------
/src/beir_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os, glob
4 | from collections import defaultdict
5 | from typing import List, Dict
6 | import numpy as np
7 | import torch
8 | import torch.distributed as dist
9 |
10 | import beir.util
11 | from beir.datasets.data_loader import GenericDataLoader
12 | from beir.retrieval.evaluation import EvaluateRetrieval
13 | from beir.retrieval.search.dense import DenseRetrievalExactSearch
14 |
15 | from beir.reranking.models import CrossEncoder
16 | from beir.reranking import Rerank
17 |
18 | import src.dist_utils as dist_utils
19 | from src import normalize_text
20 |
21 |
22 | class DenseEncoderModel:
23 | def __init__(
24 | self,
25 | query_encoder,
26 | doc_encoder=None,
27 | tokenizer=None,
28 | max_length=512,
29 | add_special_tokens=True,
30 | norm_query=False,
31 | norm_doc=False,
32 | lower_case=False,
33 | normalize_text=False,
34 | **kwargs,
35 | ):
36 | self.query_encoder = query_encoder
37 | self.doc_encoder = doc_encoder
38 | self.tokenizer = tokenizer
39 | self.max_length = max_length
40 | self.add_special_tokens = add_special_tokens
41 | self.norm_query = norm_query
42 | self.norm_doc = norm_doc
43 | self.lower_case = lower_case
44 | self.normalize_text = normalize_text
45 |
46 | def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray:
47 |
48 | if dist.is_initialized():
49 | idx = np.array_split(range(len(queries)), dist.get_world_size())[dist.get_rank()]
50 | else:
51 | idx = range(len(queries))
52 |
53 | queries = [queries[i] for i in idx]
54 | if self.normalize_text:
55 | queries = [normalize_text.normalize(q) for q in queries]
56 | if self.lower_case:
57 | queries = [q.lower() for q in queries]
58 |
59 | allemb = []
60 | nbatch = (len(queries) - 1) // batch_size + 1
61 | with torch.no_grad():
62 | for k in range(nbatch):
63 | start_idx = k * batch_size
64 | end_idx = min((k + 1) * batch_size, len(queries))
65 |
66 | qencode = self.tokenizer.batch_encode_plus(
67 | queries[start_idx:end_idx],
68 | max_length=self.max_length,
69 | padding=True,
70 | truncation=True,
71 | add_special_tokens=self.add_special_tokens,
72 | return_tensors="pt",
73 | )
74 | qencode = {key: value.cuda() for key, value in qencode.items()}
75 | emb = self.query_encoder(**qencode, normalize=self.norm_query)
76 | allemb.append(emb.cpu())
77 |
78 | allemb = torch.cat(allemb, dim=0)
79 | allemb = allemb.cuda()
80 | if dist.is_initialized():
81 | allemb = dist_utils.varsize_gather_nograd(allemb)
82 | allemb = allemb.cpu().numpy()
83 | return allemb
84 |
85 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
86 |
87 | if dist.is_initialized():
88 | idx = np.array_split(range(len(corpus)), dist.get_world_size())[dist.get_rank()]
89 | else:
90 | idx = range(len(corpus))
91 | corpus = [corpus[i] for i in idx]
92 | corpus = [c["title"] + " " + c["text"] if len(c["title"]) > 0 else c["text"] for c in corpus]
93 | if self.normalize_text:
94 | corpus = [normalize_text.normalize(c) for c in corpus]
95 | if self.lower_case:
96 | corpus = [c.lower() for c in corpus]
97 |
98 | allemb = []
99 | nbatch = (len(corpus) - 1) // batch_size + 1
100 | with torch.no_grad():
101 | for k in range(nbatch):
102 | start_idx = k * batch_size
103 | end_idx = min((k + 1) * batch_size, len(corpus))
104 |
105 | cencode = self.tokenizer.batch_encode_plus(
106 | corpus[start_idx:end_idx],
107 | max_length=self.max_length,
108 | padding=True,
109 | truncation=True,
110 | add_special_tokens=self.add_special_tokens,
111 | return_tensors="pt",
112 | )
113 | cencode = {key: value.cuda() for key, value in cencode.items()}
114 | emb = self.doc_encoder(**cencode, normalize=self.norm_doc)
115 | allemb.append(emb.cpu())
116 |
117 | allemb = torch.cat(allemb, dim=0)
118 | allemb = allemb.cuda()
119 | if dist.is_initialized():
120 | allemb = dist_utils.varsize_gather_nograd(allemb)
121 | allemb = allemb.cpu().numpy()
122 | return allemb
123 |
124 |
125 | def evaluate_model(
126 | query_encoder,
127 | doc_encoder,
128 | tokenizer,
129 | dataset,
130 | batch_size=128,
131 | add_special_tokens=True,
132 | norm_query=False,
133 | norm_doc=False,
134 | is_main=True,
135 | split="test",
136 | score_function="dot",
137 | beir_dir="BEIR/datasets",
138 | save_results_path=None,
139 | lower_case=False,
140 | normalize_text=False,
141 | ):
142 |
143 | metrics = defaultdict(list) # store final results
144 |
145 | if hasattr(query_encoder, "module"):
146 | query_encoder = query_encoder.module
147 | query_encoder.eval()
148 |
149 | if doc_encoder is not None:
150 | if hasattr(doc_encoder, "module"):
151 | doc_encoder = doc_encoder.module
152 | doc_encoder.eval()
153 | else:
154 | doc_encoder = query_encoder
155 |
156 | dmodel = DenseRetrievalExactSearch(
157 | DenseEncoderModel(
158 | query_encoder=query_encoder,
159 | doc_encoder=doc_encoder,
160 | tokenizer=tokenizer,
161 | add_special_tokens=add_special_tokens,
162 | norm_query=norm_query,
163 | norm_doc=norm_doc,
164 | lower_case=lower_case,
165 | normalize_text=normalize_text,
166 | ),
167 | batch_size=batch_size,
168 | )
169 | retriever = EvaluateRetrieval(dmodel, score_function=score_function)
170 | data_path = os.path.join(beir_dir, dataset)
171 |
172 | if not os.path.isdir(data_path) and is_main:
173 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
174 | data_path = beir.util.download_and_unzip(url, beir_dir)
175 | dist_utils.barrier()
176 |
177 | if not dataset == "cqadupstack":
178 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
179 | results = retriever.retrieve(corpus, queries)
180 | if is_main:
181 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
182 | for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
183 | if isinstance(metric, str):
184 | metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric)
185 | for key, value in metric.items():
186 | metrics[key].append(value)
187 | if save_results_path is not None:
188 | torch.save(results, f"{save_results_path}")
189 | elif dataset == "cqadupstack": # compute macroaverage over datasets
190 | paths = glob.glob(os.path.join(data_path, "*"))
191 | for path in paths:
192 | corpus, queries, qrels = GenericDataLoader(data_folder=path).load(split=split)
193 | results = retriever.retrieve(corpus, queries)
194 | if is_main:
195 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
196 | for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
197 | if isinstance(metric, str):
198 | metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric)
199 | for key, value in metric.items():
200 | metrics[key].append(value)
201 | for key, value in metrics.items():
202 | assert (
203 | len(value) == 12
204 | ), f"cqadupstack includes 12 datasets, only {len(value)} values were compute for the {key} metric"
205 |
206 | metrics = {key: 100 * np.mean(value) for key, value in metrics.items()}
207 |
208 | return metrics
209 |
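210 | # Sketch: evaluate an encoder on a single BEIR dataset (the dataset name is only an example):
211 | # metrics = evaluate_model(query_encoder, None, tokenizer, dataset="nfcorpus")
212 | # print(metrics)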
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os
4 | import glob
5 | import torch
6 | import random
7 | import json
8 | import csv
9 | import numpy as np
10 | import numpy.random
11 | import logging
12 | from collections import defaultdict
13 | import torch.distributed as dist
14 |
15 | from src import dist_utils
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def load_data(opt, tokenizer):
21 | datasets = {}
22 | for path in opt.train_data:
23 | data = load_dataset(path, opt.loading_mode)
24 | if data is not None:
25 | datasets[path] = Dataset(data, opt.chunk_length, tokenizer, opt)
26 | dataset = MultiDataset(datasets)
27 | dataset.set_prob(coeff=opt.sampling_coefficient)
28 | return dataset
29 |
30 |
31 | def load_dataset(data_path, loading_mode):
32 | files = glob.glob(os.path.join(data_path, "*.p*"))
33 | files.sort()
34 | tensors = []
35 | if loading_mode == "split":
36 | files_split = list(np.array_split(files, dist_utils.get_world_size()))[dist_utils.get_rank()]
37 | for filepath in files_split:
38 | try:
39 | tensors.append(torch.load(filepath, map_location="cpu"))
40 | except:
41 | logger.warning(f"Unable to load file {filepath}")
42 | elif loading_mode == "full":
43 | for fin in files:
44 | tensors.append(torch.load(fin, map_location="cpu"))
45 | elif loading_mode == "single":
46 | tensors.append(torch.load(files[0], map_location="cpu"))
47 | if len(tensors) == 0:
48 | return None
49 | tensor = torch.cat(tensors)
50 | return tensor
51 |
52 |
53 | class MultiDataset(torch.utils.data.Dataset):
54 | def __init__(self, datasets):
55 |
56 | self.datasets = datasets
57 | self.prob = [1 / len(self.datasets) for _ in self.datasets]
58 | self.dataset_ids = list(self.datasets.keys())
59 |
60 | def __len__(self):
61 | return sum([len(dataset) for dataset in self.datasets.values()])
62 |
63 | def __getitem__(self, index):
64 | dataset_idx = numpy.random.choice(range(len(self.prob)), 1, p=self.prob)[0]
65 | did = self.dataset_ids[dataset_idx]
66 | index = random.randint(0, len(self.datasets[did]) - 1)
67 | sample = self.datasets[did][index]
68 | sample["dataset_id"] = did
69 | return sample
70 |
71 | def generate_offset(self):
72 | for dataset in self.datasets.values():
73 | dataset.generate_offset()
74 |
75 | def set_prob(self, coeff=0.0):
76 |
77 | prob = np.array([float(len(dataset)) for _, dataset in self.datasets.items()])
78 | prob /= prob.sum()
79 | prob = np.array([p**coeff for p in prob])
80 | prob /= prob.sum()
81 | self.prob = prob
82 |
83 |
84 | class Dataset(torch.utils.data.Dataset):
85 | """Monolingual dataset based on a list of paths"""
86 |
87 | def __init__(self, data, chunk_length, tokenizer, opt):
88 |
89 | self.data = data
90 | self.chunk_length = chunk_length
91 | self.tokenizer = tokenizer
92 | self.opt = opt
93 | self.generate_offset()
94 |
95 | def __len__(self):
96 | return (self.data.size(0) - self.offset) // self.chunk_length
97 |
98 | def __getitem__(self, index):
99 | start_idx = self.offset + index * self.chunk_length
100 | end_idx = start_idx + self.chunk_length
101 | tokens = self.data[start_idx:end_idx]
102 | q_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
103 | k_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
104 | q_tokens = apply_augmentation(q_tokens, self.opt)
105 | q_tokens = add_bos_eos(q_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
106 | k_tokens = apply_augmentation(k_tokens, self.opt)
107 | k_tokens = add_bos_eos(k_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
108 |
109 | return {"q_tokens": q_tokens, "k_tokens": k_tokens}
110 |
111 | def generate_offset(self):
112 | self.offset = random.randint(0, self.chunk_length - 1)
113 |
114 |
115 | class Collator(object):
116 | def __init__(self, opt):
117 | self.opt = opt
118 |
119 | def __call__(self, batch_examples):
120 |
121 | batch = defaultdict(list)
122 | for example in batch_examples:
123 | for k, v in example.items():
124 | batch[k].append(v)
125 |
126 | q_tokens, q_mask = build_mask(batch["q_tokens"])
127 | k_tokens, k_mask = build_mask(batch["k_tokens"])
128 |
129 | batch["q_tokens"] = q_tokens
130 | batch["q_mask"] = q_mask
131 | batch["k_tokens"] = k_tokens
132 | batch["k_mask"] = k_mask
133 |
134 | return batch
135 |
136 |
137 | def randomcrop(x, ratio_min, ratio_max):
138 |
139 | ratio = random.uniform(ratio_min, ratio_max)
140 | length = int(len(x) * ratio)
141 | start = random.randint(0, len(x) - length)
142 | end = start + length
143 | crop = x[start:end].clone()
144 | return crop
145 |
146 |
147 | def build_mask(tensors):
148 | shapes = [x.shape for x in tensors]
149 | maxlength = max([len(x) for x in tensors])
150 | returnmasks = []
151 | ids = []
152 | for k, x in enumerate(tensors):
153 | returnmasks.append(torch.tensor([1] * len(x) + [0] * (maxlength - len(x))))
154 | ids.append(torch.cat((x, torch.tensor([0] * (maxlength - len(x))))))
155 | ids = torch.stack(ids, dim=0).long()
156 | returnmasks = torch.stack(returnmasks, dim=0).bool()
157 | return ids, returnmasks
158 |
159 |
160 | def add_token(x, token):
161 | x = torch.cat((torch.tensor([token]), x))
162 | return x
163 |
164 |
165 | def deleteword(x, p=0.1):
166 | mask = np.random.rand(len(x))
167 | x = [e for e, m in zip(x, mask) if m > p]
168 | return x
169 |
170 |
171 | def replaceword(x, min_random, max_random, p=0.1):
172 | mask = np.random.rand(len(x))
173 | x = [e if m > p else random.randint(min_random, max_random) for e, m in zip(x, mask)]
174 | return x
175 |
176 |
177 | def maskword(x, mask_id, p=0.1):
178 | mask = np.random.rand(len(x))
179 | x = [e if m > p else mask_id for e, m in zip(x, mask)]
180 | return x
181 |
182 |
183 | def shuffleword(x, p=0.1):
184 | """Shuffles any n number of values in a list"""
185 | count = (np.random.rand(len(x)) < p).sum()
186 | indices_to_shuffle = random.sample(range(len(x)), k=count)
187 | to_shuffle = [x[i] for i in indices_to_shuffle]
188 | random.shuffle(to_shuffle)
189 | for index, value in enumerate(to_shuffle):
190 | old_index = indices_to_shuffle[index]
191 | x[old_index] = value
192 | return x
193 |
194 |
195 | def apply_augmentation(x, opt):
196 | if opt.augmentation == "mask":
197 | return torch.tensor(maskword(x, mask_id=opt.mask_id, p=opt.prob_augmentation))
198 | elif opt.augmentation == "replace":
199 | return torch.tensor(
200 | replaceword(x, min_random=opt.start_id, max_random=opt.vocab_size - 1, p=opt.prob_augmentation)
201 | )
202 | elif opt.augmentation == "delete":
203 | return torch.tensor(deleteword(x, p=opt.prob_augmentation))
204 | elif opt.augmentation == "shuffle":
205 | return torch.tensor(shuffleword(x, p=opt.prob_augmentation))
206 | else:
207 | if not isinstance(x, torch.Tensor):
208 | x = torch.Tensor(x)
209 | return x
210 |
211 |
212 | def add_bos_eos(x, bos_token_id, eos_token_id):
213 | if not isinstance(x, torch.Tensor):
214 | x = torch.Tensor(x)
215 | if bos_token_id is None and eos_token_id is not None:
216 | x = torch.cat([x.clone().detach(), torch.tensor([eos_token_id])])
217 | elif bos_token_id is not None and eos_token_id is None:
218 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach()])
219 | elif bos_token_id is None and eos_token_id is None:
220 | pass
221 | else:
222 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach(), torch.tensor([eos_token_id])])
223 | return x
224 |
225 |
226 | # Used for passage retrieval
227 | def load_passages(path):
228 | if not os.path.exists(path):
229 | logger.info(f"{path} does not exist")
230 | return
231 | logger.info(f"Loading passages from: {path}")
232 | passages = []
233 | with open(path) as fin:
234 | if path.endswith(".jsonl"):
235 | for k, line in enumerate(fin):
236 | ex = json.loads(line)
237 | passages.append(ex)
238 | else:
239 | reader = csv.reader(fin, delimiter="\t")
240 | for k, row in enumerate(reader):
241 | if not row[0] == "id":
242 | ex = {"id": row[0], "title": row[2], "text": row[1]}
243 | passages.append(ex)
244 | return passages
245 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from tqdm import tqdm
4 | from vllm import LLM, SamplingParams
5 | import argparse
6 | import os
7 | import gc
8 | import torch
9 | from vllm.distributed.parallel_state import destroy_model_parallel
10 |
11 | parser = argparse.ArgumentParser(description="Process JSONL data using an LLM model.")
12 | parser.add_argument("--template", type=str, required=True, help="Path to the chat template file.")
13 | parser.add_argument("--llm-model", type=str, required=True, help="Path to the LLM model checkpoint.")
14 | parser.add_argument("--input-jsonl", type=str, required=True, help="Path to the input JSONL file.")
15 | parser.add_argument("--output-json", type=str, required=True, help="Path to the output JSON file.")
16 | parser.add_argument("--remain-output-json", type=str, required=True, help="Path to the output JSON file for entries that could not be processed.")
17 | args = parser.parse_args()
18 |
19 |
20 |
21 | def get_prompt_docs(query, docs):
22 | retrieved_content = "Retrieved Content:\n" + "\n".join(
23 | [
24 | f"{i+1}. Topic: {doc['title']}\nContent: {doc['text']}"
25 | for i, doc in enumerate(docs)
26 | ]
27 | )
28 | return [
29 | {'role': 'system',
30 | 'content': ("You are an expert at dynamically generating document identifiers to answer a given query.\n"
31 | "I will provide you with a set of documents, each uniquely identified by a number within square brackets, e.g., [1], [2], etc.\n"
32 | f"Your task is to identify and generate only the identifiers of the documents that contain sufficient information to answer the query.\n"
33 | "Stop generating identifiers as soon as the selected documents collectively provide enough information to answer the query.\n"
34 | "If no documents are required to answer the query, output \"None\".\n"
35 | "Output the identifiers as a comma-separated list, e.g., [1], [2] or \"None\" if no documents are needed.\n"
36 | "Focus solely on providing the identifiers. Do not include any explanations, descriptions, or additional text.")},
37 | {'role': 'user',
38 | 'content': f"Query: {query}\n{retrieved_content}"}
39 | ]
40 |
41 |
42 | def get_prompt_answer(input_file, prompt, entry):
43 | system_prompt = """You are an intelligent assistant that uses retrieved knowledge to answer user queries accurately and concisely. Follow these rules:
44 | 1. **Task**:
45 | - Use the provided `[Retrieved Content]` to generate responses.
46 | - If the Retrieved Content is None, you should generate answer based on your own knowledge.
47 | - If the information is insufficient or you don't know the answer, state, “I cannot fully answer based on the available information. Please provide more details.”
48 | 2. **Requirements**:
49 | - **Accuracy**: Base your answers on the retrieved content.
50 | - **Conciseness**: Keep answers brief and relevant.
51 | - **Context Awareness**: Ensure your responses align with the user’s query.
52 | 3. **Input Format**:
53 | - Query: `[User Query]`
54 | - Retrieved: `[Retrieved Content]`
55 | 4. **Output Format**:
56 | - A structured, clear response tailored to the query.
57 | Always prioritize clarity and reliability."""
58 |
59 | if 'fever' in input_file:
60 | remember_prompt = 'Please answer the question with “SUPPORTS”, “REFUTES” or “NEI” based on what you know.'
61 | elif 'nq' in input_file:
62 | remember_prompt = 'Please answer the question with a short phrase.'
63 | elif 'hotpotqa' in input_file:
64 | remember_prompt = 'Please answer the question with a short phrase.'
65 | elif 'eli5' in input_file:
66 | remember_prompt = 'Please answer the question with a paragraph.'
67 | elif 'triviaqa' in input_file:
68 | remember_prompt = 'Please answer the question with a short phrase.'
69 | elif '2wikimqa' in input_file:
70 | remember_prompt = 'Please answer the question with a short phrase.'
71 | elif 'arc' in input_file:
72 | remember_prompt = 'Please answer the following questions and directly output the answer options.'
73 | elif 'asqa' in input_file:
74 | remember_prompt = 'Please answer the question with a short phrase.'
75 | elif 'popqa' in input_file:
76 | remember_prompt = 'Please answer the question with a short phrase.'
77 | else:
78 | remember_prompt = 'Please answer the question with a short phrase.'
79 |
80 | return [
81 | {'role': 'system',
82 | 'content': system_prompt},
83 | {'role': 'user',
84 | 'content': (f"{remember_prompt}\n\n"
85 | f"Query: {entry['question']}\n\n"
86 | f"{prompt}")}
87 | ]
88 |
89 |
90 | def generate_docs(llm, entry):
91 | messages = get_prompt_docs(entry['question'], entry['ctxs'])
92 |
93 | outputs = llm.chat(messages,
94 | sampling_params=sampling_params,
95 | use_tqdm=False,
96 | chat_template=chat_template)
97 |
98 | return [output.text for output in outputs[0].outputs]
99 |
100 |
101 | def map_function(data_list):
102 | result = []
103 | for item in data_list:
104 | ids = [int(num) for num in re.findall(r'\[(\d+)\]', item)]
105 | result.append(ids)
106 | return result
107 |
108 |
109 | def convert(id_lists, docs):
110 | prompts = []
111 | for ids in id_lists:
112 | lines = []
113 | if len(ids) == 0:
114 | retrieved_content = "Retrieved Content: None"
115 | prompts.append(retrieved_content)
116 | continue
117 | for doc_id in ids:
118 | doc_title = docs[doc_id - 1]['title']
119 | doc_text = docs[doc_id - 1]['text']
120 | line = f"{doc_id}. Title: {doc_title}.\nContent: {doc_text}"
121 | lines.append(line)
122 |
123 | retrieved_content = "\n".join(lines)
124 | prompt = f"Retrieved Content:\n{retrieved_content}"
125 | prompts.append(prompt)
126 | return prompts
127 |
128 |
129 | def generate_answers(input_file, llm, entry, prompts):
130 | results = []
131 | for prompt in prompts:
132 | messages = get_prompt_answer(input_file, prompt, entry)
133 |
134 | outputs = llm.chat(messages,
135 | sampling_params=sampling_params_generate,
136 | use_tqdm=False,
137 | chat_template=chat_template)
138 | results.append(outputs[0].outputs[0].text)
139 | return results
140 |
141 | def process_jsonl(llm, input_file, output_file, remain_output_file):
142 | """
143 | Main processing pipeline to read input JSONL, process data, and write output JSON and remain JSONL.
144 | """
145 | result = []
146 | result_remain = []
147 |
148 | with open(input_file, 'r', encoding='utf-8') as infile:
149 | datas = [json.loads(line.strip()) for line in infile]
150 |
151 | for entry in tqdm(datas, total=len(datas), desc="Processing"):
152 | if 'asqa' in input_file:
153 | entry['ctxs'] = entry['docs']
154 | # for 4k window, entry['ctxs'][:20]
155 | # for 8k window, entry['ctxs'][:40]
156 | entry['ctxs'] = entry['ctxs'][:40]
157 |
158 |
159 | ranking_string = generate_docs(llm, entry)
160 | id_list = map_function(ranking_string)
161 |
162 | try:
163 | prompt = convert(id_list, entry['ctxs'])
164 | except IndexError:
165 | result_remain.append(entry)
166 | continue
167 |
168 | answer = generate_answers(input_file, llm, entry, prompt)
169 | entry['response'] = answer
170 | entry['id_list'] = id_list
171 | result.append(entry)
172 |
173 | output_dir = os.path.dirname(output_file)
174 | remain_dir = os.path.dirname(remain_output_file)
175 | if output_dir:
176 | os.makedirs(output_dir, exist_ok=True)
177 | if remain_dir:
178 | os.makedirs(remain_dir, exist_ok=True)
179 |
180 | with open(output_file, 'w', encoding='utf-8') as outfile:
181 | json.dump(result, outfile, indent=4)
182 |
183 |
184 | with open(remain_output_file, 'w', encoding='utf-8') as outfile:
185 | json.dump(result_remain, outfile, indent=4)
186 |
187 |
188 |
189 | if __name__ == "__main__":
190 |
191 | with open(args.template, "r") as f:
192 | chat_template = f.read()
193 |
194 | llm = LLM(model=args.llm_model, tensor_parallel_size=8)
195 | sampling_params = SamplingParams(temperature=0.4, max_tokens=100)
196 | sampling_params_generate = SamplingParams(temperature=0, max_tokens=200)
197 | process_jsonl(llm, args.input_jsonl, args.output_json, args.remain_output_json)
198 | destroy_model_parallel()
199 | del llm
200 | gc.collect()
201 | torch.cuda.empty_cache()
202 | torch.distributed.destroy_process_group()
203 |     print("Successfully deleted the LLM pipeline and freed GPU memory!")
204 |
205 |
206 |
--------------------------------------------------------------------------------
/top_inference.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from tqdm import tqdm
4 | from vllm import LLM, SamplingParams
5 | import argparse
6 | import os
7 | import gc
8 | import torch
9 | from vllm.distributed.parallel_state import destroy_model_parallel
10 |
11 | def get_prompt_docs(query, docs):
12 | retrieved_content = "Retrieved Content:\n" + "\n".join(
13 | [
14 | f"{i+1}. Topic: {doc['title']}\nContent: {doc['text']}"
15 | for i, doc in enumerate(docs)
16 | ]
17 | )
18 | return [
19 | {'role': 'system',
20 | 'content': ("You are an expert at dynamically generating document identifiers to answer a given query.\n"
21 | "I will provide you with a set of documents, each uniquely identified by a number within square brackets, e.g., [1], [2], etc.\n"
22 | f"Your task is to identify and generate only the identifiers of the documents that contain sufficient information to answer the query.\n"
23 | "Stop generating identifiers as soon as the selected documents collectively provide enough information to answer the query.\n"
24 | "If no documents are required to answer the query, output \"None\".\n"
25 | "Output the identifiers as a comma-separated list, e.g., [1], [2] or \"None\" if no documents are needed.\n"
26 | "Focus solely on providing the identifiers. Do not include any explanations, descriptions, or additional text.")},
27 | {'role': 'user',
28 | 'content': f"Query: {query}\n{retrieved_content}"}
29 | ]
30 |
31 | def map_function(data_list):
32 | result = []
33 | for item in data_list:
34 | ids = [int(num) for num in re.findall(r'\[(\d+)\]', item)]
35 | result.append(ids)
36 | return result
37 |
38 | def convert(id_lists, docs):
39 | prompts = []
40 | for ids in id_lists:
41 | lines = []
42 | if len(ids) == 0:
43 | retrieved_content = "Retrieved Content: None"
44 | prompts.append(retrieved_content)
45 | continue
46 | for doc_id in ids:
47 | doc_title = docs[doc_id - 1]['title']
48 | doc_text = docs[doc_id - 1]['text']
49 | line = f"{doc_id}. Title: {doc_title}.\nContent: {doc_text}"
50 | lines.append(line)
51 |
52 | retrieved_content = "\n".join(lines)
53 | prompt = f"Retrieved Content:\n{retrieved_content}"
54 | prompts.append(prompt)
55 | return prompts
56 |
57 | def generate_docs(llm, query, docs, sampling_params, chat_template):
58 | messages = get_prompt_docs(query, docs)
59 | outputs = llm.chat(messages,
60 | sampling_params=sampling_params,
61 | use_tqdm=False,
62 | chat_template=chat_template)
63 | return [output.text for output in outputs[0].outputs]
64 |
65 | def generate_final_docs(llm, query, all_docs, max_context_window, sampling_params, chat_template):
66 | """
67 | Process all documents in batches, reduce to a manageable size (<= max_context_window), and return the final set of docs.
68 | """
69 | current_docs = []
70 | for i in range(0, len(all_docs), max_context_window):
71 | batch_docs = all_docs[i:i + max_context_window]
72 | ranking_strings = generate_docs(llm, query, batch_docs, sampling_params, chat_template)
73 | id_list = map_function(ranking_strings)
74 | for ids in id_list:
75 | for doc_id in ids:
76 | if all_docs[doc_id - 1] not in current_docs:
77 | current_docs.append(all_docs[doc_id - 1])
78 |
79 | while len(current_docs) > max_context_window:
80 | reduced_docs = []
81 | for i in range(0, len(current_docs), max_context_window):
82 | batch_docs = current_docs[i:i + max_context_window]
83 | ranking_strings = generate_docs(llm, query, batch_docs, sampling_params, chat_template)
84 | id_list = map_function(ranking_strings)
85 |
86 | for ids in id_list:
87 | for doc_id in ids:
88 | if doc_id - 1 < len(batch_docs):
89 | reduced_docs.append(batch_docs[doc_id - 1])
90 |
91 |         current_docs = [doc for i, doc in enumerate(reduced_docs) if doc not in reduced_docs[:i]]  # dicts are not hashable; dedupe while preserving order
92 |
93 | if len(current_docs) > max_context_window and len(reduced_docs) == 0:
94 | raise ValueError("Failed to reduce document count below the context window. Check the LLM output.")
95 |
96 | return current_docs
97 |
98 | def generate_answers(llm, query, final_docs, sampling_params_generate, chat_template):
99 | prompt = convert([[i + 1 for i in range(len(final_docs))]], final_docs)[0]
100 | system_prompt = """You are an intelligent assistant that uses retrieved knowledge to answer user queries accurately and concisely. Follow these rules:
101 | 1. **Task**:
102 | - Use the provided `[Retrieved Content]` to generate responses.
103 | - If the Retrieved Content is None, you should generate answer based on your own knowledge.
104 | - If the information is insufficient or you don't know the answer, state, “I cannot fully answer based on the available information. Please provide more details.”
105 | 2. **Requirements**:
106 | - **Accuracy**: Base your answers on the retrieved content.
107 | - **Conciseness**: Keep answers brief and relevant.
108 | - **Context Awareness**: Ensure your responses align with the user’s query.
109 | 3. **Input Format**:
110 | - Query: `[User Query]`
111 | - Retrieved: `[Retrieved Content]`
112 | 4. **Output Format**:
113 | - A structured, clear response tailored to the query.
114 | Always prioritize clarity and reliability."""
115 | remember_prompt = 'Please answer the question with a short phrase.'
116 | messages = [
117 | {'role': 'system',
118 | 'content': system_prompt},
119 | {'role': 'user',
120 | 'content': (f"{remember_prompt}\n\n"
121 | f"Query: {query}\n\n"
122 | f"{prompt}")}
123 | ]
124 | outputs = llm.chat(messages,
125 | sampling_params=sampling_params_generate,
126 | use_tqdm=False,
127 | chat_template=chat_template)
128 | return outputs[0].outputs[0].text
129 |
130 | def process_jsonl(llm, input_file, output_file, remain_output_file, max_context_window):
131 | result = []
132 | result_remain = []
133 | with open(input_file, 'r', encoding='utf-8') as infile:
134 | datas = [json.loads(line.strip()) for line in infile]
135 |
136 | for entry in tqdm(datas, total=len(datas), desc="Processing"):
137 | try:
138 | entry['ctxs'] = entry['ctxs'][:args.topn]
139 | final_docs = generate_final_docs(llm, entry['question'], entry['ctxs'], max_context_window, sampling_params, chat_template)
140 | answer = generate_answers(llm, entry['question'], final_docs, sampling_params_generate, chat_template)
141 | entry['response'] = answer
142 | entry['final_docs'] = final_docs
143 | result.append(entry)
144 | except Exception as e:
145 | entry['error'] = str(e)
146 | result_remain.append(entry)
147 |
148 | with open(output_file, 'w', encoding='utf-8') as outfile:
149 | json.dump(result, outfile, indent=4)
150 |
151 | with open(remain_output_file, 'w', encoding='utf-8') as outfile:
152 | json.dump(result_remain, outfile, indent=4)
153 |
154 | if __name__ == "__main__":
155 | parser = argparse.ArgumentParser(description="Process JSONL data using an LLM model.")
156 | parser.add_argument("--template", type=str, required=True, help="Path to the chat template file.")
157 | parser.add_argument("--llm-model", type=str, required=True, help="Path to the LLM model checkpoint.")
158 | parser.add_argument("--input-jsonl", type=str, required=True, help="Path to the input JSONL file.")
159 | parser.add_argument("--output-json", type=str, required=True, help="Path to the output JSON file.")
160 |     parser.add_argument("--remain-output-json", type=str, required=True, help="Path to the output JSON file for entries that could not be processed.")
161 | parser.add_argument("--max-context-window", type=int, default=40, help="Maximum number of documents per context window.")
162 |     parser.add_argument("--topn", type=int, default=300, help="Number of top retrieved documents to consider before iterative reduction.")
163 | args = parser.parse_args()
164 |
165 | with open(args.template, "r") as f:
166 | chat_template = f.read()
167 |
168 | llm = LLM(model=args.llm_model, tensor_parallel_size=8)
169 | sampling_params = SamplingParams(temperature=0.4, max_tokens=100)
170 | sampling_params_generate = SamplingParams(temperature=0, max_tokens=200)
171 |
172 | process_jsonl(llm, args.input_jsonl, args.output_json, args.remain_output_json, args.max_context_window)
173 |
174 | destroy_model_parallel()
175 | del llm
176 | gc.collect()
177 | torch.cuda.empty_cache()
178 | torch.distributed.destroy_process_group()
179 | print("Successfully deleted the LLM pipeline and freed GPU memory!")
180 |
--------------------------------------------------------------------------------
/retriever.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import argparse
9 | import json
10 | import pickle
11 | import time
12 | import glob
13 | from pathlib import Path
14 | from tqdm import tqdm
15 | import numpy as np
16 | import torch
17 | import transformers
18 |
19 | import src.index
20 | import src.contriever
21 | import src.utils
22 | import src.slurm
23 | import src.data
24 | from src.evaluation import calculate_matches
25 | import src.normalize_text
26 |
27 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
28 |
29 |
30 | class Retriever:
31 |     def __init__(self, args, model=None, tokenizer=None):
32 | self.args = args
33 | self.model = model
34 | self.tokenizer = tokenizer
35 |
36 | def embed_queries(self, args, queries):
37 | embeddings, batch_question = [], []
38 | with torch.no_grad():
39 | for k, q in enumerate(queries):
40 | if args.lowercase:
41 | q = q.lower()
42 | if args.normalize_text:
43 | q = src.normalize_text.normalize(q)
44 | batch_question.append(q)
45 |
46 | if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
47 |
48 | encoded_batch = self.tokenizer.batch_encode_plus(
49 | batch_question,
50 | return_tensors="pt",
51 | max_length=args.question_maxlength,
52 | padding=True,
53 | truncation=True,
54 | )
55 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
56 | output = self.model(**encoded_batch)
57 | embeddings.append(output.cpu())
58 |
59 | batch_question = []
60 |
61 | embeddings = torch.cat(embeddings, dim=0)
62 | print(f"Questions embeddings shape: {embeddings.size()}")
63 |
64 | return embeddings.numpy()
65 |
66 |
67 | def embed_queries_demo(self, queries):
68 | embeddings, batch_question = [], []
69 | with torch.no_grad():
70 | for k, q in enumerate(queries):
71 | batch_question.append(q)
72 |
73 | if len(batch_question) == 16 or k == len(queries) - 1:
74 |
75 | encoded_batch = self.tokenizer.batch_encode_plus(
76 | batch_question,
77 | return_tensors="pt",
78 | max_length=200,
79 | padding=True,
80 | truncation=True,
81 | )
82 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
83 | output = self.model(**encoded_batch)
84 | embeddings.append(output.cpu())
85 |
86 | batch_question = []
87 |
88 | embeddings = torch.cat(embeddings, dim=0)
89 | print(f"Questions embeddings shape: {embeddings.size()}")
90 |
91 | return embeddings.numpy()
92 |
93 | def index_encoded_data(self, index, embedding_files, indexing_batch_size):
94 | allids = []
95 | allembeddings = np.array([])
96 | for i, file_path in enumerate(embedding_files):
97 | print(f"Loading file {file_path}")
98 | with open(file_path, "rb") as fin:
99 | ids, embeddings = pickle.load(fin)
100 |
101 | allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
102 | allids.extend(ids)
103 | while allembeddings.shape[0] > indexing_batch_size:
104 | allembeddings, allids = self.add_embeddings(index, allembeddings, allids, indexing_batch_size)
105 |
106 | while allembeddings.shape[0] > 0:
107 | allembeddings, allids = self.add_embeddings(index, allembeddings, allids, indexing_batch_size)
108 |
109 | print("Data indexing completed.")
110 |
111 |
112 | def add_embeddings(self, index, embeddings, ids, indexing_batch_size):
113 | end_idx = min(indexing_batch_size, embeddings.shape[0])
114 | ids_toadd = ids[:end_idx]
115 | embeddings_toadd = embeddings[:end_idx]
116 | ids = ids[end_idx:]
117 | embeddings = embeddings[end_idx:]
118 | index.index_data(ids_toadd, embeddings_toadd)
119 | return embeddings, ids
120 |
121 |
122 | def add_passages(self, passages, top_passages_and_scores):
123 | # add passages to original data
124 | docs = [passages[doc_id] for doc_id in top_passages_and_scores[0][0]]
125 | return docs
126 |
127 | def setup_retriever(self):
128 | print(f"Loading model from: {self.args.model_name_or_path}")
129 | self.model, self.tokenizer, _ = src.contriever.load_retriever(self.args.model_name_or_path)
130 | self.model.eval()
131 | self.model = self.model.cuda()
132 | if not self.args.no_fp16:
133 | self.model = self.model.half()
134 |
135 | self.index = src.index.Indexer(self.args.projection_size, self.args.n_subquantizers, self.args.n_bits)
136 |
137 | # index all passages
138 | input_paths = glob.glob(self.args.passages_embeddings)
139 | input_paths = sorted(input_paths)
140 | embeddings_dir = os.path.dirname(input_paths[0])
141 | index_path = os.path.join(embeddings_dir, "index.faiss")
142 | if self.args.save_or_load_index and os.path.exists(index_path):
143 | self.index.deserialize_from(embeddings_dir)
144 | else:
145 | print(f"Indexing passages from files {input_paths}")
146 | start_time_indexing = time.time()
147 | self.index_encoded_data(self.index, input_paths, self.args.indexing_batch_size)
148 | print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
149 | if self.args.save_or_load_index:
150 | self.index.serialize(embeddings_dir)
151 |
152 | # load passages
153 | print("loading passages")
154 | self.passages = src.data.load_passages(self.args.passages)
155 | self.passage_id_map = {x["id"]: x for x in self.passages}
156 | print("passages have been loaded")
157 |
158 | def search_document(self, query, top_n=10):
159 | questions_embedding = self.embed_queries(self.args, [query])
160 |
161 | # get top k results
162 | start_time_retrieval = time.time()
163 | top_ids_and_scores = self.index.search_knn(questions_embedding, self.args.n_docs)
164 | print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
165 |
166 | return self.add_passages(self.passage_id_map, top_ids_and_scores)[:top_n]
167 |
168 | def search_document_demo(self, query, n_docs=10):
169 | questions_embedding = self.embed_queries_demo([query])
170 |
171 | # get top k results
172 | start_time_retrieval = time.time()
173 | top_ids_and_scores = self.index.search_knn(questions_embedding, n_docs)
174 | print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
175 |
176 | return self.add_passages(self.passage_id_map, top_ids_and_scores)[:n_docs]
177 |
178 | def setup_retriever_demo(self, model_name_or_path, passages, passages_embeddings, n_docs=5, save_or_load_index=False):
179 | print(f"Loading model from: {model_name_or_path}")
180 | self.model, self.tokenizer, _ = src.contriever.load_retriever(model_name_or_path)
181 | self.model.eval()
182 | self.model = self.model.cuda()
183 |
184 | self.index = src.index.Indexer(768, 0, 8)
185 |
186 | # index all passages
187 | input_paths = glob.glob(passages_embeddings)
188 | input_paths = sorted(input_paths)
189 | embeddings_dir = os.path.dirname(input_paths[0])
190 | index_path = os.path.join(embeddings_dir, "index.faiss")
191 | if save_or_load_index and os.path.exists(index_path):
192 | self.index.deserialize_from(embeddings_dir)
193 | else:
194 | print(f"Indexing passages from files {input_paths}")
195 | start_time_indexing = time.time()
196 | self.index_encoded_data(self.index, input_paths, 1000000)
197 | print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
198 |
199 | # load passages
200 | print("loading passages")
201 | self.passages = src.data.load_passages(passages)
202 | self.passage_id_map = {x["id"]: x for x in self.passages}
203 | print("passages have been loaded")
204 |
205 | def add_hasanswer(data, hasanswer):
206 | # add hasanswer to data
207 | for i, ex in enumerate(data):
208 | for k, d in enumerate(ex["ctxs"]):
209 | d["hasanswer"] = hasanswer[i][k]
210 |
211 |
212 | def load_data(data_path):
213 | if data_path.endswith(".json"):
214 | with open(data_path, "r") as fin:
215 | data = json.load(fin)
216 | elif data_path.endswith(".jsonl"):
217 | data = []
218 | with open(data_path, "r") as fin:
219 | for k, example in enumerate(fin):
220 | example = json.loads(example)
221 | data.append(example)
222 | return data
223 |
224 |
225 | def main(args):
226 | data = load_data(args.query)
227 |
228 | retriever = Retriever(args)
229 | retriever.setup_retriever()
230 |
231 |     # 3. Run retrieval for each question and attach the retrieved passages
232 | with open(args.output_dir, "a", encoding="utf-8") as fout:
233 | for item in tqdm(data):
234 | question = item["retrieved_question"]
235 | docs = retriever.search_document(question, args.n_docs)
236 | item["ctxs"] = docs
237 |
238 | fout.write(json.dumps(item, ensure_ascii=False) + "\n")
239 |
240 |
241 | if __name__ == "__main__":
242 | parser = argparse.ArgumentParser()
243 |
244 | parser.add_argument(
245 | "--query",
246 | type=str,
247 | default=None,
248 | help=".json file containing question and answers, similar format to reader data",
249 | )
250 | parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
251 | parser.add_argument("--passages_embeddings", type=str, default=None, help="Glob path to encoded passages")
252 | parser.add_argument(
253 |         "--output_dir", type=str, default=None, help="Output file path; retrieved results are appended to it in JSONL format"
254 | )
255 |     parser.add_argument("--n_docs", type=int, default=100, help="Number of documents to retrieve per question")
256 | parser.add_argument(
257 | "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
258 | )
259 | parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
260 | parser.add_argument(
261 | "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
262 | )
263 | parser.add_argument(
264 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
265 | )
266 | parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
267 | parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
268 | parser.add_argument(
269 |         "--indexing_batch_size", type=int, default=1000000, help="Number of passages indexed per batch"
270 | )
271 | parser.add_argument("--projection_size", type=int, default=768)
272 | parser.add_argument(
273 | "--n_subquantizers",
274 | type=int,
275 | default=0,
276 |         help="Number of subquantizers used for vector quantization; if 0, a flat index is used",
277 | )
278 | parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
279 | parser.add_argument("--lang", nargs="+")
280 | parser.add_argument("--dataset", type=str, default="none")
281 | parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
282 | parser.add_argument("--normalize_text", action="store_true", help="normalize text")
283 |
284 | args = parser.parse_args()
285 | src.slurm.init_distributed_mode(args)
286 | main(args)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # DynamicRAG: Leveraging Outputs of Large Language Model as Feedback for Dynamic Reranking in Retrieval-Augmented Generation
2 |
3 |
26 |
27 | **DynamicRAG** is an innovative framework for Retrieval-Augmented Generation (RAG) that dynamically adjusts both the **order** and **number** of retrieved documents per query. A reinforcement learning (RL) agent serves as the reranker, optimizing document retrieval based on feedback from a **Large Language Model (LLM)**. The training process is divided into two main stages:
28 |
29 | 1. **Supervised Fine-Tuning (SFT) via Behavior Cloning**:
30 | - Trains the reranker with expert trajectories.
31 | - Simplifies the action space and establishes a baseline.
32 | 2. **Reinforcement Learning (RL) with LLM Feedback**:
33 | - Uses interactive feedback from the generator.
34 | - Explores improved trajectories and further optimizes the reranker.
35 |
36 | ## How to cite
37 | If you extend or use this work, please cite the [paper](https://arxiv.org/abs/2505.07233) where it was introduced:
38 |
39 | ```
40 | @misc{sun2025dynamicragleveragingoutputslarge,
41 | title={DynamicRAG: Leveraging Outputs of Large Language Model as Feedback for Dynamic Reranking in Retrieval-Augmented Generation},
42 | author={Jiashuo Sun and Xianrui Zhong and Sizhe Zhou and Jiawei Han},
43 | year={2025},
44 | eprint={2505.07233},
45 | archivePrefix={arXiv},
46 | primaryClass={cs.CL},
47 | url={https://arxiv.org/abs/2505.07233},
48 | }
49 | ```
50 |
51 | ## 🔥 Update
52 | * [2025-09-18]: 🚀 Our paper is accepted by NeurIPS 2025! See you in San Diego! 🥳🥳🥳
53 | * [2025-07-13]: 🚀 We release the training data of DynamicRAG: [DynamicRAG_Training_Data_150k](https://huggingface.co/datasets/gasolsun/DynamicRAG_Training_Data_150k)
54 | * [2025-05-13]: 🚀 We release the paper: [https://arxiv.org/abs/2505.07233](https://arxiv.org/abs/2505.07233)
55 | * [2025-05-07]: 🚀 We release the [DynamicRAG-7B](https://huggingface.co/gasolsun/DynamicRAG-7B) and [DynamicRAG-8B](https://huggingface.co/gasolsun/DynamicRAG-8B) and [eval-datas](https://huggingface.co/datasets/gasolsun/DynamicRAG-Eval-Data).
56 | * [2025-05-05]: 🚀 We release the code for training and evaluation.
57 |
58 |
59 |
60 | ## Table of Contents
61 |
62 | - [DynamicRAG Overview](#dynamicrag-overview)
63 | - [Project Visualizations](#project-visualizations)
64 | - [📌 Data Processing Pipeline](#-data-processing-pipeline)
65 | - [🎯 Supervised Fine-Tuning (SFT) Training](#-supervised-fine-tuning-sft-training)
66 | - [🤖 Interactive Data Collection](#-interactive-data-collection)
67 | - [📈 Direct Preference Optimization (DPO) Training](#-direct-preference-optimization-dpo-training)
68 | - [🔍 Inference and Evaluation](#-inference-and-evaluation)
69 | - [📄 Licensing and Claims](#-licensing-and-claims)
70 |
71 | ---
72 |
73 | ## DynamicRAG Overview
74 |
75 | DynamicRAG adjusts the retrieval process on-the-fly by:
76 | - Dynamically reordering and selecting the number of documents per query.
77 | - Leveraging a reranker trained with RL and LLM feedback to improve retrieval quality.
78 |
79 | ---
80 |
81 | ## 💡 Preliminaries
82 | You should install the environment with `pip install -r requirements.txt`, and then run:
83 | ```bash
84 | apt-get update
85 | apt-get install libtiff5
86 | ```
87 | Moreover, you need to configure the retriever corpus, e.g., the official 2018 English Wikipedia embeddings. We use exactly the same configuration as [Self-RAG](https://github.com/AkariAsai/self-rag); see their Retriever Setup section for details.
88 |
89 |
90 | ## 📌 Data Processing Pipeline
91 | Example: Training LLaMA3-8B with Top-40 Documents
92 |
93 | ### **1. Prepare BC Data Pipeline**
94 | #### **Step 1: Retrieve Top-40 Documents**
95 | Run the retrieval script:
96 | ```bash
97 | #!/bin/bash
98 |
99 | NUM_GPUS=8
100 | INPUT_FILE="data/rag_training_data.json"
101 | SPLIT_DIR="data/splits"
102 |
103 | python split_data.py --input_file $INPUT_FILE --output_dir $SPLIT_DIR --num_splits $NUM_GPUS
104 |
105 | for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
106 | SPLIT_FILE="${SPLIT_DIR}/split_${GPU_ID}.json"
107 | OUTPUT_FILE="output/retrieval_split_${GPU_ID}.json"
108 | log_file="logs/retriever_split_${GPU_ID}.log"
109 | CUDA_VISIBLE_DEVICES=$GPU_ID python retriever.py \
110 | --model_name_or_path models/retriever \
111 | --passages data/psgs_w100.tsv \
112 | --passages_embeddings "data/wikipedia_embeddings/*" \
113 | --query $SPLIT_FILE \
114 | --output_dir $OUTPUT_FILE \
115 | --n_docs 50 \
116 | 1>"$log_file" 2>&1 &
117 |
118 | echo "Started process on GPU $GPU_ID with input $SPLIT_FILE"
119 | done
120 |
121 | wait
122 | echo "All processes completed."
123 |
124 | ```
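
`split_data.py` is not reproduced in this snippet. A minimal sketch of the sharding it is expected to perform, assuming `rag_training_data.json` is a JSON list and the shards are named `split_<GPU_ID>.json` as consumed by the loop above:

```python
import argparse
import json
import os

parser = argparse.ArgumentParser(description="Split a JSON list of examples into N shards.")
parser.add_argument("--input_file", required=True)
parser.add_argument("--output_dir", required=True)
parser.add_argument("--num_splits", type=int, required=True)
args = parser.parse_args()

with open(args.input_file, "r", encoding="utf-8") as f:
    data = json.load(f)

os.makedirs(args.output_dir, exist_ok=True)
for i in range(args.num_splits):
    # Shard i takes every num_splits-th example, so shards stay balanced.
    shard = data[i::args.num_splits]
    with open(os.path.join(args.output_dir, f"split_{i}.json"), "w", encoding="utf-8") as f:
        json.dump(shard, f, ensure_ascii=False, indent=2)
```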
125 |
126 | #### **Step 2: Aggregate Retrieved Data**
127 | ```bash
128 | python aggregate.py
129 | ```
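
`aggregate.py` itself is not shown here; conceptually it merges the per-GPU shards `output/retrieval_split_*.json` (written one JSON object per line by `retriever.py`) into the single `output/retrieval_data.jsonl` used in Step 3. A minimal sketch under those assumptions:

```python
import glob
import json

# Merge per-GPU retrieval shards (one JSON object per line) into a single JSONL file.
merged = []
for path in sorted(glob.glob("output/retrieval_split_*.json")):
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if line:
                merged.append(json.loads(line))

with open("output/retrieval_data.jsonl", "w", encoding="utf-8") as fout:
    for item in merged:
        fout.write(json.dumps(item, ensure_ascii=False) + "\n")
```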
130 |
131 | #### **Step 3: Rerank Documents**
132 | ```bash
133 | python reranker.py --model_name_or_path models/reranker/monot5 \
134 | --input_file output/retrieval_data.jsonl \
135 | --output_file output/retrieval_data_rerank.jsonl \
136 | --device cuda
137 | ```
138 | Outputs: `retrieval_data_rerank.jsonl`
139 |
140 | > 💡 If the above command runs slowly, consider running it on multiple GPUs (as with the retriever) and then combining the results.
141 |
142 |
143 | #### **Step 4: Compute True/False in Reranking**
144 | ```bash
145 | python process_training_data.py
146 | ```
147 | Outputs:
148 | - `retrieval_data_rerank_sequence.json` (for Reranker BC training)
149 | - `retrieval_data_rerank_normal.json` (for SFT & DPO training)
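
`process_training_data.py` is not shown above, but "True/False" most plausibly refers to whether each reranked passage contains a gold answer. A simplified sketch of that labeling step, assuming entries carry `ctxs` (as in `inference.py`) and an `answers` list of gold answers:

```python
import json

def has_answer(text, answers):
    # True if any gold answer string occurs in the passage (case-insensitive substring match).
    text = text.lower()
    return any(ans.lower() in text for ans in answers)

with open("output/retrieval_data_rerank.jsonl", "r", encoding="utf-8") as fin:
    data = [json.loads(line) for line in fin]

for entry in data:
    for doc in entry["ctxs"]:
        doc["hasanswer"] = has_answer(doc["text"], entry["answers"])
# The two output files listed above would then be derived from these labels.
```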
150 |
151 | #### **Step 5: Convert Reranker Data for Training**
152 | ```bash
153 | python reranker_sequence.py
154 | ```
155 | Output: `reranker_bc_data.json` (formatted for **LLaMA-Factory**)
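
As a rough sketch, this conversion maps each reranking example into the instruction / input / output / system format expected by LLaMA-Factory (see `dataset_info.json` below). The intermediate field names `reranker_input` / `reranker_output` are hypothetical placeholders:

```python
import json

with open("output/retrieval_data_rerank_sequence.json", "r", encoding="utf-8") as fin:
    data = json.load(fin)

records = []
for entry in data:
    records.append({
        "system": "You are an expert at dynamically generating document identifiers to answer a given query.",
        "instruction": "Select the identifiers of the documents that are sufficient to answer the query.",
        "input": entry["reranker_input"],    # query plus numbered candidate documents (hypothetical field)
        "output": entry["reranker_output"],  # expert identifier sequence, e.g. "[3], [1]" (hypothetical field)
    })

with open("reranker_bc_data.json", "w", encoding="utf-8") as fout:
    json.dump(records, fout, ensure_ascii=False, indent=2)
```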
156 |
157 | #### **Step 6: Split SFT & DPO Data**
158 | ```bash
159 | python split_for_sft_dpo.py
160 | ```
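
A minimal sketch of one way this split could work. The 50/50 proportion and the SFT file name are assumptions; `training_data/training_data_dpo.jsonl` matches the sampling command used later:

```python
import json
import os
import random

random.seed(42)
with open("output/retrieval_data_rerank_normal.json", "r", encoding="utf-8") as fin:
    data = json.load(fin)

random.shuffle(data)
half = len(data) // 2  # 50/50 split is an assumption
os.makedirs("training_data", exist_ok=True)

with open("training_data/training_data_sft.json", "w", encoding="utf-8") as fout:
    json.dump(data[:half], fout, ensure_ascii=False, indent=2)

with open("training_data/training_data_dpo.jsonl", "w", encoding="utf-8") as fout:
    for item in data[half:]:
        fout.write(json.dumps(item, ensure_ascii=False) + "\n")
```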
161 |
162 | #### **Step 7: Construct Generator SFT Data**
163 | ```bash
164 | python construct_generator_sft.py
165 | ```
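
A hedged sketch of the generator SFT construction, reusing the generator prompts from `inference.py` and writing the `generator_sft_training.json` file referenced in `dataset_info.json` below. The passage filtering rule and the use of the first gold answer as the target are assumptions:

```python
import json

SYSTEM_PROMPT = ("You are an intelligent assistant that uses retrieved knowledge "
                 "to answer user queries accurately and concisely.")

with open("output/retrieval_data_rerank_normal.json", "r", encoding="utf-8") as fin:
    data = json.load(fin)

records = []
for entry in data:
    retrieved = "\n".join(
        f"{i + 1}. Title: {doc['title']}.\nContent: {doc['text']}"
        for i, doc in enumerate(entry["ctxs"])
        if doc.get("hasanswer")  # keep only passages labeled useful in Step 4 (assumption)
    ) or "None"
    records.append({
        "system": SYSTEM_PROMPT,
        "instruction": "Please answer the question with a short phrase.",
        "input": f"Query: {entry['question']}\n\nRetrieved Content:\n{retrieved}",
        "output": entry["answers"][0],  # first gold answer as the SFT target (assumption)
    })

with open("generator_sft_training.json", "w", encoding="utf-8") as fout:
    json.dump(records, fout, ensure_ascii=False, indent=2)
```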
166 |
167 | ---
168 |
169 | ## 🎯 Supervised Fine-Tuning (SFT) Training
170 | We use **LLaMA-Factory** as the training framework. Install it from [here](https://github.com/hiyouga/LLaMA-Factory).
171 |
172 | ### **1. Configure `dataset_info.json`**
173 | Modify `LLaMA-Factory/data/dataset_info.json`:
174 | ```json
175 | {
176 | "generator_sft": {
177 | "file_name": "generator_sft_training.json",
178 | "columns": {"prompt": "instruction", "query": "input", "response": "output", "system": "system"}
179 | },
180 |
181 | "reranker_bc": {
182 | "file_name": "reranker_bc_training.json",
183 | "columns": {"prompt": "instruction", "query": "input", "response": "output", "system": "system"}
184 | },
185 |
186 | "alpaca_data": {
187 | "file_name": "alpaca_data_cleaned_system.json",
188 | "columns": {"prompt": "instruction", "query": "input", "response": "output", "system": "system"}
189 | }
190 | }
191 | ```
192 |
193 | ### **2. Train the Model**
194 | Modify `llama8b.yaml` and run:
195 | ```bash
196 | llamafactory-cli train examples/train_full/llama8b.yaml
197 | ```
198 | > 🛠️ Requires at least **8 A100-80G GPUs**.
199 |
200 | ---
201 |
202 | ## 🤖 Interactive Data Collection
203 | We use **vLLM** for faster sampling.
204 |
205 | ### **1. Sample Interaction Trajectories**
206 | ```bash
207 | python sampling_dpo_trajectories.py \
208 | --template template/llama3.jinja \
209 | --llm-model DynamicRAG_llama3_8b \
210 | --input-jsonl training_data/training_data_dpo.jsonl \
211 | --output-json results/training_data_dpo_sampling.json
212 |
213 | ```
214 |
215 | ### **2. Collect Rewards for Trajectories**
216 | ```bash
217 | python reward_trajectories.py \
218 | --input_file results/training_data_dpo_sampling.json \
219 | --output_file training_data/llama3_8b_output_dpo.jsonl
220 | ```
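
`reward_trajectories.py` is not shown in this snippet. The sketch below scores each sampled trajectory's answer against the gold answers with a simple exact-match term; this is only a simplified stand-in for the LLM-feedback reward described in the paper. It assumes each entry carries the sampled generations in `response` (as written by the inference code) and a gold `answers` list:

```python
import json

def exact_match(prediction, answers):
    # 1.0 if any gold answer appears verbatim (case-insensitive) in the prediction, else 0.0.
    pred = prediction.lower()
    return float(any(ans.lower() in pred for ans in answers))

with open("results/training_data_dpo_sampling.json", "r", encoding="utf-8") as fin:
    data = json.load(fin)

with open("training_data/llama3_8b_output_dpo.jsonl", "w", encoding="utf-8") as fout:
    for entry in data:
        # One reward per sampled trajectory; entry["response"] holds the sampled generations.
        entry["rewards"] = [exact_match(resp, entry["answers"]) for resp in entry["response"]]
        fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
```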
221 |
222 | ### **3. Construct DPO Training Data**
223 | ```bash
224 | python construct_dpo.py
225 | ```
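
A sketch of the generator-side pairing only (the reranker-side file `llama3_8b_reranker_dpo.json` would be built analogously from the `id_list` trajectories). Field names follow the DPO entries in `dataset_info.json` below; the `rewards` and `answers` fields are assumptions carried over from the previous step:

```python
import json

with open("training_data/llama3_8b_output_dpo.jsonl", "r", encoding="utf-8") as fin:
    data = [json.loads(line) for line in fin]

pairs = []
for entry in data:
    scored = sorted(zip(entry["rewards"], entry["response"]))
    if scored[0][0] == scored[-1][0]:
        continue  # no preference signal if every trajectory earned the same reward
    pairs.append({
        "instruction": "Please answer the question with a short phrase.",
        "input": entry["question"],
        "chosen": scored[-1][1],   # highest-reward generation
        "rejected": scored[0][1],  # lowest-reward generation
    })

with open("llama3_8b_generator_dpo.json", "w", encoding="utf-8") as fout:
    json.dump(pairs, fout, ensure_ascii=False, indent=2)
```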
226 |
227 | ---
228 |
229 | ## 📈 Direct Preference Optimization (DPO) Training
230 |
231 | ### **1. Configure `dataset_info.json`**
232 | ```json
233 | {
234 | "llama3_generator_dpo": {
235 | "file_name": "llama3_8b_generator_dpo.json",
236 | "ranking": true,
237 | "columns": {"prompt": "instruction", "query": "input", "chosen": "chosen", "rejected": "rejected"}
238 | },
239 |
240 | "llama3_reranker_dpo": {
241 | "file_name": "llama3_8b_reranker_dpo.json",
242 | "ranking": true,
243 | "columns": {"prompt": "instruction", "query": "input", "chosen": "chosen", "rejected": "rejected"}
244 | }
245 | }
246 | ```
247 |
248 | ### **2. Train the Model**
249 | ```bash
250 | llamafactory-cli train examples/train_full/llama8b_dpo.yaml
251 | ```
252 | > 🛠️ Requires at least **8 A100-80G GPUs**.
253 |
254 | ---
255 |
256 | ## 🔍 Inference and Evaluation
257 | We use **vLLM** for efficient inference.
258 |
259 | ### **1. Run Inference**
260 | ```bash
261 | #!/bin/bash
262 |
263 |
264 | LOG_DIR="eval_logs"
265 | mkdir -p $LOG_DIR
266 |
267 | run_inference() {
268 | local input_file=$1
269 | local output_file=$2
270 | local remain_output_file=$3
271 |
272 | echo "Running inference for $input_file..."
273 | python inference.py \
274 | --template template/llama3.jinja \
275 | --llm-model DynamicRAG_llama3_8b \
276 | --input-jsonl $input_file \
277 | --output-json $output_file \
278 | --remain-output-json $remain_output_file \
279 | >> $LOG_DIR/$(basename $output_file .json)_log.txt 2>&1
280 |
281 | sleep 5
282 | }
283 |
284 |
285 | run_inference "eval_data/triviaqa.jsonl" \
286 | "results/llama3_8b_triviaqa.json" \
287 | "results/llama3_8b_triviaqa_remain.json"
288 |
289 | run_inference "eval_data/nq.jsonl" \
290 | "results/llama3_8b_nq.json" \
291 | "results/llama3_8b_nq_remain.json"
292 |
293 | run_inference "eval_data/hotpotqa.jsonl" \
294 | "results/llama3_8b_hotpotqa.json" \
295 | "results/llama3_8b_hotpotqa_remain.json"
296 |
297 | run_inference "eval_data/2wikimqa.jsonl" \
298 | "results/llama3_8b_2wikimqa.json" \
299 | "results/llama3_8b_2wikimqa_remain.json"
300 |
301 | run_inference "eval_data/fever.jsonl" \
302 | "results/llama3_8b_fever.json" \
303 | "results/llama3_8b_fever_remain.json"
304 |
305 | run_inference "eval_data/eli5.jsonl" \
306 | "results/llama3_8b_eli5.json" \
307 | "results/llama3_8b_eli5_remain.json"
308 |
309 | run_inference "eval_data/asqa_eval_gtr_top100.jsonl" \
310 | "results/llama3_8b_asqa.json" \
311 | "results/llama3_8b_asqa_remain.json"
312 |
313 | echo "All tasks completed. Logs are available in $LOG_DIR."
314 |
315 | ```
316 | The inference script above evaluates **7 different benchmarks**.
317 |
318 | ### **2. Evaluate Performance**
319 | ```bash
320 | # install nltk, rouge_score, spacy
321 | # python -m spacy download en_core_web_sm
322 |
323 | # for example, when we evaluate nq
324 | python evaluate.py \
325 | --results_file results/llama3_8b_nq.json \
326 | --metric match
327 | ```
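
`evaluate.py` is not included above. As a rough sketch, the `match` metric typically checks whether any normalized gold answer appears in the normalized prediction. The field names `response` (written by `inference.py`) and `answers` (assumed) are used here:

```python
import json
import re
import string

def normalize(text):
    # SQuAD-style normalization: lowercase, drop punctuation and articles, squeeze whitespace.
    text = "".join(ch for ch in text.lower() if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def match(prediction, answers):
    pred = normalize(prediction)
    return float(any(normalize(ans) in pred for ans in answers))

with open("results/llama3_8b_nq.json", encoding="utf-8") as fin:
    results = json.load(fin)

scores = [max(match(resp, entry["answers"]) for resp in entry["response"]) for entry in results]
print(f"match: {100 * sum(scores) / len(scores):.2f}")
```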
328 |
329 | ### **3. Run DynamicRAG on 500+ Documents**
330 | ```bash
331 | #!/bin/bash
332 |
333 | TEMPLATE="template/llama3.jinja"
334 | LLM_MODEL="DynamicRAG_llama3_8b"
335 | INPUT_JSONL="eval_data/nq_top500.jsonl"
336 | MAX_CONTEXT_WINDOW=40
337 |
338 | TOPN_VALUES=(50 100 150 200 300 500)
339 |
340 | for TOPN in "${TOPN_VALUES[@]}"; do
341 | LOG_FILE="top_logs/llama3_8b_nq_top_${TOPN}.log"
342 |
343 | python top_inference.py \
344 | --template "$TEMPLATE" \
345 | --llm-model "$LLM_MODEL" \
346 | --input-jsonl "$INPUT_JSONL" \
347 | --output-json "results/llama3_8b_top_${TOPN}_nq.json" \
348 | --remain-output-json "results/llama3_8b_top_${TOPN}_nq_remain.json" \
349 | --max-context-window "$MAX_CONTEXT_WINDOW" \
350 | --topn "$TOPN" >> "$LOG_FILE" 2>&1
351 |
352 | sleep 3
353 | done
354 |
355 | ```
356 |
357 | ---
358 |
359 |
360 |
361 | ## Project Visualizations
362 |
363 | Explore the key components and performance of DynamicRAG through the following images:
364 |
365 | - **Introduction of DynamicRAG:**
366 |
367 |
372 |
373 |
374 | - **Pipeline of DynamicRAG:**
375 |
376 |
381 |
382 | - **Generator Experiment:**
383 |
384 |
389 |
390 |
391 | - **Reranker Experiment:**
392 |
393 |
398 |
399 |
400 |
401 | - **Efficiency of DynamicRAG:**
402 |
403 |
408 |
409 | - **Case Study:**
410 |
411 |
416 |
421 |
422 |
423 | ---
424 |
425 |
426 | ## 📄 Licensing and Claims
427 | This project is licensed under the Apache 2.0 license. The project assumes no legal responsibility for any output generated by the models and will not be held liable for any damages resulting from the use of the provided resources and outputs.
428 |
429 |
--------------------------------------------------------------------------------