├── src
│   ├── __init__.py
│   ├── index.py
│   ├── inbatch.py
│   ├── dist_utils.py
│   ├── slurm.py
│   ├── normalize_text.py
│   ├── moco.py
│   ├── contriever.py
│   ├── finetuning_data.py
│   ├── options.py
│   ├── evaluation.py
│   ├── utils.py
│   ├── beir_utils.py
│   └── data.py
├── requirements.txt
├── evaluate.sh
├── run_reranker.sh
├── template
│   ├── mistral.jinja
│   ├── llama2.jinja
│   └── llama3.jinja
├── aggregate.py
├── run_retriever.sh
├── llama8b.yaml
├── llama8b_dpo.yaml
├── split_for_sft_dpo.py
├── split_data.py
├── inference.sh
├── reward_trajectories.py
├── reranker_sequence.py
├── evaluate.py
├── metrics.py
├── process_training_data.py
├── construct_generator_sft.py
├── reranker.py
├── utils.py
├── sampling_dpo_trajectories.py
├── constrcut_dpo.py
├── inference.py
├── top_inference.py
├── retriever.py
└── readme.md
/src/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
rouge
bert_score
faiss-gpu
tqdm
openai
vllm
deepspeed
flash-attn
llamafactory-cli
--------------------------------------------------------------------------------
/evaluate.sh:
--------------------------------------------------------------------------------
# install nltk, rouge_score, spacy
# python -m spacy download en_core_web_sm


python evaluate.py \
    --results_file results/llama3_8b_nq.json \
    --metric match
--------------------------------------------------------------------------------
/run_reranker.sh:
--------------------------------------------------------------------------------
python reranker.py --model_name_or_path models/reranker/monot5 \
    --input_file output/retrieval_data.jsonl \
    --output_file output/retrieval_data_rerank.jsonl \
    --device cuda
--------------------------------------------------------------------------------
/template/mistral.jinja:
--------------------------------------------------------------------------------
{% if messages[0]['role'] == 'system' %}
{% set loop_messages = messages[1:] %}
{% set system_message = messages[0]['content'] %}
{% else %}
{% set loop_messages = messages %}
{% endif %}

{% if system_message is defined %}
{{ '<s>' }}{{ system_message }}
{% endif %}

{% for message in loop_messages %}
{% set content = message['content'] %}

{% if message['role'] == 'user' %}
{{ '[INST] ' + content + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ content + '</s>' }}
{% endif %}
{% endfor %}
--------------------------------------------------------------------------------
/template/llama2.jinja:
--------------------------------------------------------------------------------
{% if messages[0]['role'] == 'system' %}
{% set loop_messages = messages[1:] %}
{% set system_message = messages[0]['content'] %}
{% else %}
{% set loop_messages = messages %}
{% endif %}

{% for message in loop_messages %}
{% set content = message['content'] %}

{% if loop.index0 == 0 and system_message is defined %}
{% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}
{% endif %}

{% if message['role'] == 'user' %}
{{ '<s>' + '[INST] ' + content + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ content + '</s>' }}
{% endif %}
{% endfor %}

--------------------------------------------------------------------------------
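# --- Usage sketch (not part of the repository; the message list is illustrative) ---
# The chat templates above are passed to inference.py via --template; they can also
# be rendered standalone with jinja2 as a quick sanity check.
from jinja2 import Template

with open("template/llama3.jinja", encoding="utf-8") as f:
    template = Template(f.read())

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Who wrote Pride and Prejudice?"},
]
print(template.render(messages=messages))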
/aggregate.py:
--------------------------------------------------------------------------------
import json
import os
from tqdm import tqdm

input_files = [f"retrieval_split_{i}.json" for i in range(32)]
output_file = "retrieval_data.jsonl"


data = []
id_set = set()


for file in tqdm(input_files, desc="Processing files"):
    with open(file, 'r', encoding='utf-8') as f:
        for line in f:
            record = json.loads(line)
            record_id = record.get("id")
            if record_id not in id_set:
                id_set.add(record_id)
                data.append(record)


with open(output_file, 'w', encoding='utf-8') as f:
    for record in tqdm(data, desc="Writing to file"):
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

--------------------------------------------------------------------------------
/template/llama3.jinja:
--------------------------------------------------------------------------------
{{ '<|begin_of_text|>' }}

{% if messages[0]['role'] == 'system' %}
{% set loop_messages = messages[1:] %}
{% set system_message = messages[0]['content'] %}
{% else %}
{% set loop_messages = messages %}
{% endif %}

{% if system_message is defined %}
{{ '<|start_header_id|>system<|end_header_id|>' }}
{{ system_message }}
{{ '<|eot_id|>' }}
{% endif %}

{% for message in loop_messages %}
{% set content = message['content'] %}

{% if message['role'] == 'user' %}
{{ '<|start_header_id|>user<|end_header_id|>' }}
{{ content }}
{{ '<|eot_id|>' }}
{{ '<|start_header_id|>assistant<|end_header_id|>' }}
{% elif message['role'] == 'assistant' %}
{{ content }}
{{ '<|eot_id|>' }}
{% endif %}
{% endfor %}
--------------------------------------------------------------------------------
/run_retriever.sh:
--------------------------------------------------------------------------------
#!/bin/bash

NUM_GPUS=8
INPUT_FILE="data/rag_training_data.json"
SPLIT_DIR="data/splits"

python split_data.py --input_file $INPUT_FILE --output_dir $SPLIT_DIR --num_splits $NUM_GPUS

for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
    SPLIT_FILE="${SPLIT_DIR}/split_${GPU_ID}.json"
    OUTPUT_FILE="output/retrieval_split_${GPU_ID}.json"
    log_file="logs/retriever_split_${GPU_ID}.log"
    CUDA_VISIBLE_DEVICES=$GPU_ID python retriever.py \
        --model_name_or_path models/retriever \
        --passages data/psgs_w100.tsv \
        --passages_embeddings "data/wikipedia_embeddings/*" \
        --query $SPLIT_FILE \
        --output_dir $OUTPUT_FILE \
        --n_docs 50 \
        1>"$log_file" 2>&1 &

    echo "Started process on GPU $GPU_ID with input $SPLIT_FILE"
done

wait
echo "All processes completed."

--------------------------------------------------------------------------------
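# --- Sanity-check sketch (not part of the repository) ---
# The per-GPU shards written above are merged and de-duplicated by aggregate.py;
# its hard-coded input pattern (retrieval_split_{i}.json for i in range(32)) and
# working directory may need adjusting to match OUTPUT_FILE. Downstream stages
# such as reranker.py expect each merged record to carry an "id", a "question"
# and a "docs" list with "title"/"text" fields, which this snippet checks.
import json

with open("output/retrieval_data.jsonl", encoding="utf-8") as f:
    first = json.loads(next(f))

assert {"id", "question", "docs"} <= first.keys()
assert all({"title", "text"} <= doc.keys() for doc in first["docs"])
print(f"{first['id']}: {len(first['docs'])} retrieved passages")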
/llama8b.yaml:
--------------------------------------------------------------------------------
### model
model_name_or_path: dynamicRAG/models/generator/llama3_8b

### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset: alpaca_data, reranker_bc, generator_sft
template: llama3
cutoff_len: 8192
# max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/dynamicrag_llama3_8b/
logging_steps: 1
save_steps: 2000
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
eval_dataset: alpaca_en_demo
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 1000000000
--------------------------------------------------------------------------------
/llama8b_dpo.yaml:
--------------------------------------------------------------------------------
### model
model_name_or_path: LLaMA-Factory/saves/dynamicrag_llama3_8b

### method
stage: dpo
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset: llama3_generator_dpo, llama3_reranker_dpo
template: llama3
cutoff_len: 8192
# max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
#
output_dir: saves/dynamicrag_llama3_8b_dpo/
logging_steps: 1
save_steps: 200
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 5.0e-6
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
eval_dataset: llama3_reranker_dpo
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 1000000000
--------------------------------------------------------------------------------
/split_for_sft_dpo.py:
--------------------------------------------------------------------------------
import json
import random

def shuffle_and_split_jsonl_to_json(input_file, n, output_file_1, output_file_2):
    data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line.strip()))
    print(len(data))
    if not isinstance(data, list):
        raise ValueError("The input JSONL must contain a list of items.")

    random.shuffle(data)

    data_part_1 = data[:n]
    data_part_2 = data[n:]

    with open(output_file_1, 'w', encoding='utf-8') as f:
        json.dump(data_part_1, f, ensure_ascii=False, indent=4)

    with open(output_file_2, 'w', encoding='utf-8') as f:
        json.dump(data_part_2, f, ensure_ascii=False, indent=4)

    print(f"Shuffled and split data saved to {output_file_1} and {output_file_2}")

input_file = 'training_data/retrieval_data_rerank_normal.jsonl'
n = 70000
output_file_1 = 'training_data/training_data_sft.json'
output_file_2 = 'training_data/training_data_dpo.json'

shuffle_and_split_jsonl_to_json(input_file, n, output_file_1, output_file_2)

--------------------------------------------------------------------------------
/split_data.py:
--------------------------------------------------------------------------------
import os
import json
import argparse
import math

def split_data(input_file, output_dir, num_splits):
    with open(input_file, 'r', encoding='utf-8') as f:
        data = [json.loads(line.strip()) for line in f]

    total_items = len(data)
    items_per_split = math.ceil(total_items / num_splits)

    os.makedirs(output_dir, exist_ok=True)

    for i in range(num_splits):
        split_data = data[i * items_per_split:(i + 1) * items_per_split]
        split_file = os.path.join(output_dir, f"split_{i}.json")
        with open(split_file, 'w', encoding='utf-8') as out_f:
            json.dump(split_data, out_f, indent=4, ensure_ascii=False)

    print(f"Data split into {num_splits} parts and saved in {output_dir}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Path to input JSONL file")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save splits")
    parser.add_argument("--num_splits", type=int, required=True, help="Number of splits")

    args = parser.parse_args()
    split_data(args.input_file, args.output_dir, args.num_splits)

--------------------------------------------------------------------------------
/inference.sh:
--------------------------------------------------------------------------------
#!/bin/bash


LOG_DIR="eval_logs"
mkdir -p $LOG_DIR

run_inference() {
    local input_file=$1
    local output_file=$2
    local remain_output_file=$3

    echo "Running inference for $input_file..."
    python inference.py \
        --template template/llama3.jinja \
        --llm-model DynamicRAG_llama3_8b \
        --input-json $input_file \
        --output-json $output_file \
        --remain-output-json $remain_output_file \
        >> $LOG_DIR/$(basename $output_file .json)_log.txt 2>&1

    sleep 5
}


run_inference "eval_data/triviaqa.jsonl" \
    "results/llama3_8b_triviaqa.json" \
    "results/llama3_8b_triviaqa_remain.json"

run_inference "eval_data/nq.jsonl" \
    "results/llama3_8b_nq.json" \
    "results/llama3_8b_nq_remain.json"

run_inference "eval_data/hotpotqa.jsonl" \
    "results/llama3_8b_hotpotqa.json" \
    "results/llama3_8b_hotpotqa_remain.json"

run_inference "eval_data/2wikimqa.jsonl" \
    "results/llama3_8b_2wikimqa.json" \
    "results/llama3_8b_2wikimqa_remain.json"

run_inference "eval_data/fever.jsonl" \
    "results/llama3_8b_fever.json" \
    "results/llama3_8b_fever_remain.json"

run_inference "eval_data/eli5.jsonl" \
    "results/llama3_8b_eli5.json" \
    "results/llama3_8b_eli5_remain.json"

run_inference "eval_data/asqa_eval_gtr_top100.jsonl" \
    "results/llama3_8b_asqa.json" \
    "results/llama3_8b_asqa_remain.json"

echo "All tasks completed. Logs are available in $LOG_DIR."
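# --- Usage sketch (not part of the repository; the strings below are illustrative) ---
# Each results file written above is scored with evaluate.py (see evaluate.sh); the
# underlying functions from metrics.py can also be called directly to spot-check a
# single prediction against its gold answers.
from metrics import metric_max_over_ground_truths, exact_match_score, match

prediction = "Jane Austen"
gold_answers = ["Jane Austen", "Austen"]
print(metric_max_over_ground_truths(exact_match_score, prediction, gold_answers))  # True
print(match("Who wrote Pride and Prejudice?", prediction, gold_answers))           # 1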
54 | -------------------------------------------------------------------------------- /reward_trajectories.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from tqdm import tqdm 4 | from rouge import Rouge 5 | from utils import compute_bert_score, compute_rouge_score, compute_exact_match, inverse_reward, llm_eval 6 | 7 | def process_data(input_file, output_file): 8 | with open(input_file, 'r', encoding='utf-8') as f: 9 | data = json.load(f) 10 | 11 | rouge_score = Rouge() 12 | 13 | weights = [0.2, 0.2, 0.2, 0.2, 0.2] 14 | 15 | for item in tqdm(data): 16 | instruction = item['question'] 17 | answers = [ans for ans in item.get("answer", []) if ans.strip()] 18 | responses = item["responses"] 19 | id_list = item.get("id_lists", []) 20 | 21 | scores = [] 22 | if len(set(responses)) == 1: 23 | scores = [0] * len(responses) 24 | else: 25 | for response in responses: 26 | response_scores = [ 27 | weights[0] * compute_bert_score(answer, response) + 28 | weights[1] * compute_rouge_score(rouge_score, answer, response) + 29 | weights[2] * compute_exact_match(answer, response) + 30 | weights[3] * inverse_reward(id_list) + 31 | weights[4] * llm_eval(instruction, answer, response) 32 | for answer in answers 33 | ] 34 | scores.append(sum(response_scores) / len(response_scores) if response_scores else 0) 35 | 36 | item["response_scores"] = scores 37 | 38 | with open(output_file, 'w', encoding='utf-8') as f: 39 | json.dump(data, f, indent=4, ensure_ascii=False) 40 | 41 | print(f"Processed data saved to {output_file}") 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--input_file", type=str, required=True, help="Path to input JSON file") 46 | parser.add_argument("--output_file", type=str, required=True, help="Path to save processed JSON file") 47 | 48 | args = parser.parse_args() 49 | process_data(args.input_file, args.output_file) 50 | -------------------------------------------------------------------------------- /reranker_sequence.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | 4 | 5 | input_file = 'training_data/retrieval_data_rerank_sequence.jsonl' 6 | output_file = 'training_data/reranker_bc_training.json' 7 | 8 | 9 | data = [] 10 | with open(input_file, "r", encoding="utf-8") as f: 11 | for line in f: 12 | data.append(json.loads(line.strip())) 13 | 14 | 15 | new_data = [] 16 | 17 | for item in tqdm(data, desc="Processing items"): 18 | sequence = item.get("sequence", []) 19 | docs = item.get("docs", []) 20 | question = item.get("question", "") 21 | 22 | query = question 23 | 24 | system_prompt = ( 25 | f"You are an expert at dynamically generating document identifiers to answer a given query.\n" 26 | f"I will provide you with a set of documents, each uniquely identified by a number within square brackets, e.g., [1], [2], etc.\n" 27 | f"Your task is to identify and generate only the identifiers of the documents that contain sufficient information to answer the query.\n" 28 | f"Stop generating identifiers as soon as the selected documents collectively provide enough information to answer the query.\n" 29 | f"If no documents are required to answer the query, output \"None\".\n" 30 | f"Output the identifiers as a comma-separated list, e.g., [1], [2] or \"None\" if no documents are needed.\n" 31 | f"Focus solely on providing the identifiers. 
Do not include any explanations, descriptions, or additional text." 32 | ) 33 | 34 | if not sequence: 35 | output = "None" 36 | else: 37 | output = ", ".join(f"[{seq}]" for seq in sequence) 38 | 39 | retrieved_content = "Retrieved Content:\n" + "\n".join( 40 | [ 41 | f"{i+1}. Topic: {doc['title']}\nContent: {doc['text']}" 42 | for i, doc in enumerate(docs) 43 | ] 44 | ) 45 | instruction_data = { 46 | "instruction": f"Query: {question}\n\n{retrieved_content}", 47 | "input": "", 48 | "output": output, 49 | "system": system_prompt 50 | } 51 | 52 | new_data.append(instruction_data) 53 | 54 | 55 | with open(output_file, "w", encoding="utf-8") as f: 56 | json.dump(new_data, f, ensure_ascii=False, indent=4) 57 | -------------------------------------------------------------------------------- /src/index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import pickle 9 | from typing import List, Tuple 10 | 11 | import faiss 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | class Indexer(object): 16 | 17 | def __init__(self, vector_sz, n_subquantizers=0, n_bits=8): 18 | if n_subquantizers > 0: 19 | self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT) 20 | else: 21 | self.index = faiss.IndexFlatIP(vector_sz) 22 | #self.index_id_to_db_id = np.empty((0), dtype=np.int64) 23 | self.index_id_to_db_id = [] 24 | 25 | def index_data(self, ids, embeddings): 26 | self._update_id_mapping(ids) 27 | embeddings = embeddings.astype('float32') 28 | if not self.index.is_trained: 29 | self.index.train(embeddings) 30 | self.index.add(embeddings) 31 | 32 | print(f'Total data indexed {len(self.index_id_to_db_id)}') 33 | 34 | def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 2048) -> List[Tuple[List[object], List[float]]]: 35 | query_vectors = query_vectors.astype('float32') 36 | result = [] 37 | nbatch = (len(query_vectors)-1) // index_batch_size + 1 38 | for k in tqdm(range(nbatch)): 39 | start_idx = k*index_batch_size 40 | end_idx = min((k+1)*index_batch_size, len(query_vectors)) 41 | q = query_vectors[start_idx: end_idx] 42 | scores, indexes = self.index.search(q, top_docs) 43 | # convert to external ids 44 | db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes] 45 | result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))]) 46 | return result 47 | 48 | def serialize(self, dir_path): 49 | index_file = os.path.join(dir_path, 'index.faiss') 50 | meta_file = os.path.join(dir_path, 'index_meta.faiss') 51 | print(f'Serializing index to {index_file}, meta data to {meta_file}') 52 | 53 | faiss.write_index(self.index, index_file) 54 | with open(meta_file, mode='wb') as f: 55 | pickle.dump(self.index_id_to_db_id, f) 56 | 57 | def deserialize_from(self, dir_path): 58 | index_file = os.path.join(dir_path, 'index.faiss') 59 | meta_file = os.path.join(dir_path, 'index_meta.faiss') 60 | print(f'Loading index from {index_file}, meta data from {meta_file}') 61 | 62 | self.index = faiss.read_index(index_file) 63 | print('Loaded index of type %s and size %d', type(self.index), self.index.ntotal) 64 | 65 | with open(meta_file, "rb") as reader: 66 | self.index_id_to_db_id = pickle.load(reader) 67 | assert len( 68 | 
self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size' 69 | 70 | def _update_id_mapping(self, db_ids: List): 71 | #new_ids = np.array(db_ids, dtype=np.int64) 72 | #self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0) 73 | self.index_id_to_db_id.extend(db_ids) -------------------------------------------------------------------------------- /src/inbatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import math 7 | import random 8 | import transformers 9 | import logging 10 | import torch.distributed as dist 11 | 12 | from src import contriever, dist_utils, utils 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class InBatch(nn.Module): 18 | def __init__(self, opt, retriever=None, tokenizer=None): 19 | super(InBatch, self).__init__() 20 | 21 | self.opt = opt 22 | self.norm_doc = opt.norm_doc 23 | self.norm_query = opt.norm_query 24 | self.label_smoothing = opt.label_smoothing 25 | if retriever is None or tokenizer is None: 26 | retriever, tokenizer = self._load_retriever( 27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 28 | ) 29 | self.tokenizer = tokenizer 30 | self.encoder = retriever 31 | 32 | def _load_retriever(self, model_id, pooling, random_init): 33 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 34 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 35 | 36 | if "xlm" in model_id: 37 | model_class = contriever.XLMRetriever 38 | else: 39 | model_class = contriever.Contriever 40 | 41 | if random_init: 42 | retriever = model_class(cfg) 43 | else: 44 | retriever = utils.load_hf(model_class, model_id) 45 | 46 | if "bert-" in model_id: 47 | if tokenizer.bos_token_id is None: 48 | tokenizer.bos_token = "[CLS]" 49 | if tokenizer.eos_token_id is None: 50 | tokenizer.eos_token = "[SEP]" 51 | 52 | retriever.config.pooling = pooling 53 | 54 | return retriever, tokenizer 55 | 56 | def get_encoder(self): 57 | return self.encoder 58 | 59 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs): 60 | 61 | bsz = len(q_tokens) 62 | labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device) 63 | 64 | qemb = self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 65 | kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 66 | 67 | gather_fn = dist_utils.gather 68 | 69 | gather_kemb = gather_fn(kemb) 70 | 71 | labels = labels + dist_utils.get_rank() * len(kemb) 72 | 73 | scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb) 74 | 75 | loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing) 76 | 77 | # log stats 78 | if len(stats_prefix) > 0: 79 | stats_prefix = stats_prefix + "/" 80 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 81 | 82 | predicted_idx = torch.argmax(scores, dim=-1) 83 | accuracy = 100 * (predicted_idx == labels).float().mean() 84 | stdq = torch.std(qemb, dim=0).mean().item() 85 | stdk = torch.std(kemb, dim=0).mean().item() 86 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 87 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 88 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 89 | 90 | return loss, iter_stats 91 | 
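# --- Usage sketch (not part of the repository; ids and vectors are illustrative) ---
# Typical use of the Indexer defined in src/index.py above: add pre-computed passage
# embeddings, then query with encoded questions.
import numpy as np
from src.index import Indexer

index = Indexer(vector_sz=768)                    # flat inner-product index
ids = ["doc0", "doc1", "doc2"]
embeddings = np.random.rand(3, 768).astype("float32")
index.index_data(ids, embeddings)

queries = np.random.rand(2, 768).astype("float32")
results = index.search_knn(queries, top_docs=2)   # [(db_ids, scores), ...] per query
print(results)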
-------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from metrics import metric_max_over_ground_truths, exact_match_score, match, rouge, f1 3 | import json 4 | from tqdm import tqdm 5 | from rouge import Rouge 6 | 7 | def load_file(file_path): 8 | """ 9 | Load data from a JSON or JSONL file. 10 | 11 | Args: 12 | file_path (str): Path to the file to load. 13 | 14 | Returns: 15 | list: List of dictionaries loaded from the file. 16 | 17 | Raises: 18 | ValueError: If the file format is not supported. 19 | """ 20 | try: 21 | with open(file_path, 'r', encoding='utf-8') as f: 22 | if file_path.endswith('.json'): 23 | data = json.load(f) 24 | if isinstance(data, dict): 25 | return [data] 26 | elif isinstance(data, list): 27 | return data 28 | else: 29 | raise ValueError("Unsupported JSON structure. Expecting list or dict.") 30 | elif file_path.endswith('.jsonl'): 31 | data = [json.loads(line.strip()) for line in f if line.strip()] 32 | return data 33 | else: 34 | raise ValueError(f"Unsupported file format: {file_path}") 35 | except Exception as e: 36 | print(f"Error loading file {file_path}: {e}") 37 | raise 38 | 39 | 40 | def get_args(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--results_file", type=str, help="File containing results with both generated responses and ground truth answers") 43 | parser.add_argument("--metric", type=str, choices=["em", "accuracy", "match", "rouge", "f1"], help="Metric to use for evaluation") 44 | return parser.parse_args() 45 | 46 | def main(): 47 | args = get_args() 48 | 49 | results = load_file(args.results_file) 50 | 51 | scores = [] 52 | cnt=0 53 | rouge_score = Rouge() 54 | for item in tqdm(results): 55 | response = item['response'][0] 56 | 57 | question = item['question'] 58 | 59 | if 'asqa' in args.results_file: 60 | answers= [] 61 | for ans in item['qa_pairs']: 62 | answers.extend(ans['short_answers']) 63 | else: 64 | answers = item['answers'] 65 | if not answers: 66 | print(f"Warning: No answers provided for ID {item.get('id', 'unknown')}") 67 | continue 68 | 69 | if args.metric == "em": 70 | metric_result = metric_max_over_ground_truths( 71 | exact_match_score, response, answers 72 | ) 73 | elif args.metric == "accuracy": 74 | response = response.replace('\n','').strip()[0] 75 | answer = answers[0][0] 76 | if response==answer: 77 | metric_result = 1.0 78 | else: 79 | metric_result = 0.0 80 | elif args.metric == "match": 81 | metric_result = match(question, response, answers) 82 | elif args.metric == "rouge": 83 | metric_result = rouge(rouge_score, response, answers) 84 | elif args.metric == "f1": 85 | metric_result = f1(response, answers) 86 | else: 87 | raise NotImplementedError(f"Metric {args.metric} is not implemented.") 88 | 89 | scores.append(metric_result) 90 | 91 | if scores: 92 | print(f'Overall result: {sum(scores) / len(scores)}') 93 | else: 94 | print("No scores were calculated. 
Please check your input file.") 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import string 3 | import re 4 | from collections import Counter 5 | import re 6 | from typing import List 7 | 8 | # nltk.download('punkt_tab') 9 | 10 | def convert_to_capitalized(s): 11 | if s.isupper(): 12 | return s.capitalize() 13 | return s 14 | 15 | def exact_match_score(prediction, ground_truth): 16 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 17 | 18 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 19 | scores_for_ground_truths = [] 20 | for ground_truth in ground_truths: 21 | score = metric_fn(prediction, ground_truth) 22 | scores_for_ground_truths.append(score) 23 | return max(scores_for_ground_truths) 24 | 25 | def accuracy(preds, labels): 26 | match_count = 0 27 | 28 | for pred, label in zip(preds, labels): 29 | target = label[0] 30 | if pred == target or pred[0]==target: 31 | match_count += 1 32 | if match_count == 0: 33 | print(repr(pred)) 34 | print(target) 35 | return 100 * (match_count / len(preds)) 36 | 37 | 38 | def f1(decoded_preds, decoded_labels): 39 | f1_all = [] 40 | for prediction, answers in zip(decoded_preds, decoded_labels): 41 | if type(answers) == list: 42 | if len(answers) == 0: 43 | return 0 44 | f1_all.append(np.max([qa_f1_score(prediction, gt) 45 | for gt in answers])) 46 | else: 47 | f1_all.append(qa_f1_score(prediction, answers)) 48 | return 100 * np.mean(f1_all) 49 | 50 | 51 | def qa_f1_score(prediction, ground_truth): 52 | prediction_tokens = normalize_answer(prediction).split() 53 | ground_truth_tokens = normalize_answer(ground_truth).split() 54 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 55 | num_same = sum(common.values()) 56 | if num_same == 0: 57 | return 0 58 | precision = 1.0 * num_same / len(prediction_tokens) 59 | recall = 1.0 * num_same / len(ground_truth_tokens) 60 | f1 = (2 * precision * recall) / (precision + recall) 61 | return f1 62 | 63 | 64 | def normalize_answer(s): 65 | def remove_articles(text): 66 | return re.sub(r'\b(a|an|the)\b', ' ', text) 67 | 68 | def white_space_fix(text): 69 | return ' '.join(text.split()) 70 | 71 | def remove_punc(text): 72 | exclude = set(string.punctuation) 73 | return ''.join(ch for ch in text if ch not in exclude) 74 | 75 | def lower(text): 76 | return text.lower() 77 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 78 | 79 | def find_entity_tags(sentence): 80 | entity_regex = r'(.+?)(?=\s<|$)' 81 | tag_regex = r'<(.+?)>' 82 | entity_names = re.findall(entity_regex, sentence) 83 | tags = re.findall(tag_regex, sentence) 84 | 85 | results = {} 86 | for entity, tag in zip(entity_names, tags): 87 | if "<" in entity: 88 | results[entity.split("> ")[1]] = tag 89 | else: 90 | results[entity] = tag 91 | return results 92 | 93 | def match(question, prediction, ground_truth): 94 | prediction = prediction.replace('\n','').strip().lower() 95 | for gt in ground_truth: 96 | if prediction in gt.lower(): 97 | return 1 98 | return 0 99 | 100 | 101 | def rouge(rouge, prediction: str, ground_truths: List[str]) -> float: 102 | if prediction=="": 103 | return 0.0 104 | 105 | highest = 0.0 106 | 107 | for ref in ground_truths: 108 | if ref=="": 109 | return 0.0 110 | score = rouge.get_scores(ref, prediction) 111 | highest = max(highest, 
score[0]['rouge-l']['f']) 112 | 113 | return highest -------------------------------------------------------------------------------- /process_training_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from collections import Counter 4 | from tqdm import tqdm 5 | 6 | 7 | def load_jsonl(file_path): 8 | data = [] 9 | with open(file_path, 'r', encoding='utf-8') as file: 10 | for line in file: 11 | data.append(json.loads(line)) 12 | return data 13 | 14 | 15 | def save_jsonl(data, file_path): 16 | with open(file_path, 'w', encoding='utf-8') as file: 17 | for item in data: 18 | json.dump(item, file, ensure_ascii=False) 19 | file.write("\n") 20 | 21 | 22 | def process_jsonl(input_file, output_sequence_file, output_normal_file): 23 | result = [] 24 | remaining_data = [] 25 | length_counters = {length: 0 for length in range(15)} 26 | true_counts = [] 27 | 28 | data = load_jsonl(input_file) 29 | unused_data = data.copy() 30 | total_items = len(data) 31 | 32 | with tqdm(total=total_items, desc="Processing", unit="item") as pbar: 33 | while sum(length_counters.values()) < 30000 and unused_data: 34 | item = unused_data.pop(0) 35 | used = False 36 | 37 | true_ids = [] 38 | false_ids = [] 39 | 40 | for idx, doc in enumerate(item['docs'], start=1): 41 | if doc.get('rerank_label') == "true": 42 | true_ids.append(idx) 43 | else: 44 | false_ids.append(idx) 45 | 46 | item['true_id'] = true_ids 47 | item['false_id'] = false_ids 48 | item['true_count'] = len(true_ids) 49 | true_counts.append(len(true_ids)) 50 | 51 | for length in range(15): 52 | if length_counters[length] >= 2000: 53 | continue 54 | 55 | if len(true_ids) > length: 56 | continue 57 | 58 | shuffled_true_id = true_ids.copy() 59 | shuffled_false_id = false_ids.copy() 60 | random.shuffle(shuffled_true_id) 61 | random.shuffle(shuffled_false_id) 62 | 63 | sequence = shuffled_true_id 64 | remaining_length = length - len(shuffled_true_id) 65 | if remaining_length > 0: 66 | sequence += shuffled_false_id[:remaining_length] 67 | 68 | new_item = item.copy() 69 | new_item["sequence_length"] = length 70 | new_item["sequence"] = sequence 71 | result.append(new_item) 72 | 73 | length_counters[length] += 1 74 | used = True 75 | 76 | if sum(length_counters.values()) >= 30000: 77 | break 78 | 79 | if sum(length_counters.values()) >= 30000: 80 | break 81 | 82 | if not used: 83 | remaining_data.append(item) 84 | 85 | pbar.update(1) 86 | 87 | remaining_data.extend(unused_data) 88 | 89 | save_jsonl(result, output_sequence_file) 90 | save_jsonl(remaining_data, output_normal_file) 91 | print(f"Data generation complete. 
Total sequence items: {len(result)}") 92 | print(f"Remaining data items: {len(remaining_data)}") 93 | 94 | true_count_distribution = Counter(true_counts) 95 | print("True Count Distribution:") 96 | for count, freq in sorted(true_count_distribution.items()): 97 | print(f"True Count = {count}: Frequency = {freq}") 98 | 99 | # 主流程 100 | if __name__ == "__main__": 101 | input_jsonl_file = 'training_data/retrieval_data_rerank.jsonl' 102 | output_sequence_file = 'training_data/retrieval_data_rerank_sequence.jsonl' 103 | output_normal_file = 'training_data/retrieval_data_rerank_normal.jsonl' 104 | process_jsonl(input_jsonl_file, output_sequence_file, output_normal_file) 105 | -------------------------------------------------------------------------------- /src/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class Gather(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x: torch.tensor): 10 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 11 | dist.all_gather(output, x) 12 | return tuple(output) 13 | 14 | @staticmethod 15 | def backward(ctx, *grads): 16 | all_gradients = torch.stack(grads) 17 | dist.all_reduce(all_gradients) 18 | return all_gradients[dist.get_rank()] 19 | 20 | 21 | def gather(x: torch.tensor): 22 | if not dist.is_initialized(): 23 | return x 24 | x_gather = Gather.apply(x) 25 | x_gather = torch.cat(x_gather, dim=0) 26 | return x_gather 27 | 28 | 29 | @torch.no_grad() 30 | def gather_nograd(x: torch.tensor): 31 | if not dist.is_initialized(): 32 | return x 33 | x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())] 34 | dist.all_gather(x_gather, x, async_op=False) 35 | 36 | x_gather = torch.cat(x_gather, dim=0) 37 | return x_gather 38 | 39 | 40 | @torch.no_grad() 41 | def varsize_gather_nograd(x: torch.Tensor): 42 | """gather tensors of different sizes along the first dimension""" 43 | if not dist.is_initialized(): 44 | return x 45 | 46 | # determine max size 47 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 48 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 49 | dist.all_gather(allsizes, size) 50 | max_size = max([size.cpu().max() for size in allsizes]) 51 | 52 | padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device) 53 | padded[: x.shape[0]] = x 54 | output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())] 55 | dist.all_gather(output, padded) 56 | 57 | output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)] 58 | output = torch.cat(output, dim=0) 59 | 60 | return output 61 | 62 | 63 | @torch.no_grad() 64 | def get_varsize(x: torch.Tensor): 65 | """gather tensors of different sizes along the first dimension""" 66 | if not dist.is_initialized(): 67 | return [x.shape[0]] 68 | 69 | # determine max size 70 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 71 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 72 | dist.all_gather(allsizes, size) 73 | allsizes = torch.cat(allsizes) 74 | return allsizes 75 | 76 | 77 | def get_rank(): 78 | if not dist.is_available(): 79 | return 0 80 | if not dist.is_initialized(): 81 | return 0 82 | return dist.get_rank() 83 | 84 | 85 | def is_main(): 86 | return get_rank() == 0 87 | 88 | 89 | def get_world_size(): 90 | if not dist.is_initialized(): 91 | return 1 92 | 
else: 93 | return dist.get_world_size() 94 | 95 | 96 | def barrier(): 97 | if dist.is_initialized(): 98 | dist.barrier() 99 | 100 | 101 | def average_main(x): 102 | if not dist.is_initialized(): 103 | return x 104 | if dist.is_initialized() and dist.get_world_size() > 1: 105 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 106 | if is_main(): 107 | x = x / dist.get_world_size() 108 | return x 109 | 110 | 111 | def sum_main(x): 112 | if not dist.is_initialized(): 113 | return x 114 | if dist.is_initialized() and dist.get_world_size() > 1: 115 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 116 | return x 117 | 118 | 119 | def weighted_average(x, count): 120 | if not dist.is_initialized(): 121 | if isinstance(x, torch.Tensor): 122 | x = x.item() 123 | return x, count 124 | t_loss = torch.tensor([x * count]).cuda() 125 | t_total = torch.tensor([count]).cuda() 126 | t_loss = sum_main(t_loss) 127 | t_total = sum_main(t_total) 128 | return (t_loss / t_total).item(), t_total.item() 129 | -------------------------------------------------------------------------------- /src/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from logging import getLogger 8 | import os 9 | import sys 10 | import torch 11 | import socket 12 | import signal 13 | import subprocess 14 | 15 | 16 | logger = getLogger() 17 | 18 | def sig_handler(signum, frame): 19 | logger.warning("Signal handler called with signal " + str(signum)) 20 | prod_id = int(os.environ['SLURM_PROCID']) 21 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 22 | if prod_id == 0: 23 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 24 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 25 | else: 26 | logger.warning("Not the main process, no need to requeue.") 27 | sys.exit(-1) 28 | 29 | 30 | def term_handler(signum, frame): 31 | logger.warning("Signal handler called with signal " + str(signum)) 32 | logger.warning("Bypassing SIGTERM.") 33 | 34 | 35 | def init_signal_handler(): 36 | """ 37 | Handle signals sent by SLURM for time limit / pre-emption. 38 | """ 39 | signal.signal(signal.SIGUSR1, sig_handler) 40 | signal.signal(signal.SIGTERM, term_handler) 41 | 42 | 43 | def init_distributed_mode(params): 44 | """ 45 | Handle single and multi-GPU / multi-node / SLURM jobs. 
46 | Initialize the following variables: 47 | - local_rank 48 | - global_rank 49 | - world_size 50 | """ 51 | is_slurm_job = 'SLURM_JOB_ID' in os.environ and not 'WORLD_SIZE' in os.environ 52 | has_local_rank = hasattr(params, 'local_rank') 53 | 54 | # SLURM job without torch.distributed.launch 55 | if is_slurm_job and has_local_rank: 56 | 57 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 58 | 59 | # local rank on the current node / global rank 60 | params.local_rank = int(os.environ['SLURM_LOCALID']) 61 | params.global_rank = int(os.environ['SLURM_PROCID']) 62 | params.world_size = int(os.environ['SLURM_NTASKS']) 63 | 64 | # define master address and master port 65 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 66 | params.main_addr = hostnames.split()[0].decode('utf-8') 67 | assert 10001 <= params.main_port <= 20000 or params.world_size == 1 68 | 69 | # set environment variables for 'env://' 70 | os.environ['MASTER_ADDR'] = params.main_addr 71 | os.environ['MASTER_PORT'] = str(params.main_port) 72 | os.environ['WORLD_SIZE'] = str(params.world_size) 73 | os.environ['RANK'] = str(params.global_rank) 74 | is_distributed = True 75 | 76 | 77 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 78 | elif has_local_rank and params.local_rank != -1: 79 | 80 | assert params.main_port == -1 81 | 82 | # read environment variables 83 | params.global_rank = int(os.environ['RANK']) 84 | params.world_size = int(os.environ['WORLD_SIZE']) 85 | 86 | is_distributed = True 87 | 88 | # local job (single GPU) 89 | else: 90 | params.local_rank = 0 91 | params.global_rank = 0 92 | params.world_size = 1 93 | is_distributed = False 94 | 95 | # set GPU device 96 | torch.cuda.set_device(params.local_rank) 97 | 98 | # initialize multi-GPU 99 | if is_distributed: 100 | 101 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 102 | # 'env://' will read these environment variables: 103 | # MASTER_PORT - required; has to be a free port on machine with rank 0 104 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 105 | # WORLD_SIZE - required; can be set either here, or in a call to init function 106 | # RANK - required; can be set either here, or in a call to init function 107 | 108 | #print("Initializing PyTorch distributed ...") 109 | torch.distributed.init_process_group( 110 | init_method='env://', 111 | backend='nccl', 112 | #world_size=params.world_size, 113 | #rank=params.global_rank, 114 | ) -------------------------------------------------------------------------------- /construct_generator_sft.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | 5 | def read_json(file_path): 6 | with open(file_path, 'r', encoding='utf-8') as file: 7 | return json.load(file) 8 | 9 | def format_data(data): 10 | error_data = 0 11 | formatted_data = [] 12 | 13 | system_prompt = """You are an intelligent assistant that uses retrieved knowledge to answer user queries accurately and concisely. Follow these rules: 14 | 15 | 1. **Task**: 16 | - Use the provided `[Retrieved Content]` to generate responses. 17 | - If the Retrieved Content is None, you should generate answer based on your own knowledge. 18 | - If the information is insufficient or you don't know the answer, state, “I cannot fully answer based on the available information. Please provide more details.” 19 | 20 | 2. 
**Requirements**: 21 | - **Accuracy**: Base your answers on the retrieved content. 22 | - **Conciseness**: Keep answers brief and relevant. 23 | - **Context Awareness**: Ensure your responses align with the user’s query. 24 | 25 | 3. **Input Format**: 26 | - Query: `[User Query]` 27 | - Retrieved: `[Retrieved Content]` 28 | 29 | 4. **Output Format**: 30 | - A structured, clear response tailored to the query. 31 | 32 | Always prioritize clarity and reliability.""" 33 | for item in data: 34 | 35 | item_id = item['id'] 36 | 37 | if 'fever' in item_id: 38 | remember_prompt = 'Please answer the question with “SUPPORTS”, “REFUTES” or “NEI” based on what you know.' 39 | elif 'nq' in item_id: 40 | remember_prompt = 'Please answer the question with a short phrase.' 41 | elif 'hotpotqa' in item_id: 42 | remember_prompt = 'Please answer the question with a short phrase.' 43 | elif 'eli5' in item_id: 44 | remember_prompt = 'Please answer the question with a paragraph.' 45 | elif 'tc' in item_id: 46 | remember_prompt = 'Please answer the question with a short phrase.' 47 | elif 'asqa' in item_id: 48 | remember_prompt = 'Please answer the question with a short phrase.' 49 | else: 50 | remember_prompt = 'Please answer the question with a short phrase.' 51 | 52 | true_ids = item.get('true_id', [])[:15] 53 | true_docs = [item['docs'][i - 1] for i in true_ids if i <= len(item['docs'])] 54 | 55 | if len(true_docs) == 0: 56 | retrieved_content = "Retrieved Content: None" 57 | else: 58 | retrieved_content = "Retrieved Content:\n" + "\n".join( 59 | [ 60 | f"{i+1}. Topic: {doc['title']}\nContent: {doc['text']}" 61 | for i, doc in enumerate(true_docs) 62 | ] 63 | ) 64 | 65 | if 'tc' in item_id: 66 | filtered_answers = [item['answer'].strip()] 67 | if type(item['answer']) == list: 68 | filtered_answers = [ans.strip() for ans in item.get('answer', []) if ans.strip()] 69 | elif type(item['answer']) == str: 70 | filtered_answers = [item['answer'].strip()] 71 | else: 72 | error_data += 1 73 | continue 74 | 75 | if filtered_answers: 76 | first_answer = filtered_answers[0] 77 | else: 78 | error_data += 1 79 | continue 80 | 81 | instruction_data = { 82 | "instruction": f"{remember_prompt}\nQuery: {item['question']}\n\n{retrieved_content}", 83 | "input": "", 84 | "output": first_answer, 85 | "system": system_prompt 86 | } 87 | 88 | formatted_data.append(instruction_data) 89 | 90 | print(f"Error Data Count: {error_data}") 91 | return formatted_data 92 | 93 | 94 | def write_json(file_path, data): 95 | with open(file_path, 'w', encoding='utf-8') as file: 96 | json.dump(data, file, ensure_ascii=False, indent=4) 97 | 98 | 99 | def main(input_file, output_file): 100 | data = read_json(input_file) 101 | random.shuffle(data) 102 | formatted_data = format_data(data) 103 | print(len(formatted_data)) 104 | write_json(output_file, formatted_data) 105 | 106 | 107 | if __name__ == "__main__": 108 | input_file = 'training_data/training_data_sft.json' 109 | output_file = 'training_data/generator_sft_training.json' 110 | 111 | main(input_file, output_file) 112 | -------------------------------------------------------------------------------- /reranker.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from transformers import T5Tokenizer, T5ForConditionalGeneration 4 | from typing import List 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | 9 | class Reranker: 10 | def __init__(self, model_name_or_path: str, device: str = "cuda"): 11 | """ 12 | Initializes the 
Reranker with a specified model. 13 | 14 | Args: 15 | model_name_or_path (str): Path to the pretrained reranker model. 16 | device (str): Device to load the model on ("cuda" or "cpu"). 17 | """ 18 | self.device = device 19 | print(f"Loading reranker model from: {model_name_or_path}") 20 | self.tokenizer = T5Tokenizer.from_pretrained(model_name_or_path) 21 | self.model = T5ForConditionalGeneration.from_pretrained(model_name_or_path) 22 | self.model = self.model.to(self.device) 23 | self.model.eval() 24 | print("Reranker model loaded successfully.") 25 | 26 | def rerank(self, query: str, docs: List[dict]) -> List[str]: 27 | """ 28 | Rerank a list of documents based on their relevance to the query. 29 | 30 | Args: 31 | query (str): The input query. 32 | docs (List[dict]): A list of documents, each represented as a dictionary with keys "id" and "text". 33 | 34 | Returns: 35 | List[str]: A list of relevance labels ("true" or "false") for the documents. 36 | """ 37 | labels = [] 38 | inputs = [] 39 | 40 | # Prepare inputs for the model 41 | for doc in docs: 42 | combined_input = f"Query: {query} Document: {doc['text']} Relevant:" 43 | inputs.append(combined_input) 44 | 45 | # Tokenize inputs in batches 46 | batch_size = 128 # Adjust based on available memory 47 | for start in range(0, len(inputs), batch_size): 48 | batch = inputs[start:start + batch_size] 49 | encoded_batch = self.tokenizer.batch_encode_plus( 50 | batch, 51 | return_tensors="pt", 52 | padding=True, 53 | truncation=True, 54 | max_length=512 55 | ) 56 | encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()} 57 | 58 | # Generate labels 59 | with torch.no_grad(): 60 | outputs = self.model.generate( 61 | input_ids=encoded_batch["input_ids"], 62 | attention_mask=encoded_batch["attention_mask"], 63 | max_length=2 # Only need a short output like "true" or "false" 64 | ) 65 | 66 | # Decode generated sequences to labels 67 | decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 68 | labels.extend([output.strip().lower() for output in decoded_outputs]) 69 | 70 | return labels 71 | 72 | 73 | if __name__ == "__main__": 74 | # Define command-line arguments 75 | parser = argparse.ArgumentParser(description="Reranker Demo") 76 | parser.add_argument("--model_name_or_path", type=str, default="", 77 | help="Path to the pretrained reranker model.") 78 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", 79 | help="Device to run the reranker on ('cuda' or 'cpu').") 80 | parser.add_argument("--input_file", type=str, required=True, help="Path to the input JSONL file.") 81 | parser.add_argument("--output_file", type=str, required=True, help="Path to save the output JSONL file.") 82 | args = parser.parse_args() 83 | 84 | # Load the reranker model 85 | reranker = Reranker(model_name_or_path=args.model_name_or_path, device=args.device) 86 | 87 | # Process JSONL file 88 | with open(args.input_file, "r", encoding='utf-8') as input_file, open(args.output_file, "a", encoding='utf-8') as output_file: 89 | for line in tqdm(input_file): 90 | if not line.strip(): 91 | continue 92 | # Parse the JSON object 93 | example = json.loads(line.strip()) 94 | query = example["question"] 95 | docs = example["docs"] 96 | 97 | # Get rerank labels 98 | rerank_labels = reranker.rerank(query=query, docs=docs) 99 | 100 | # Add rerank labels to each document 101 | for doc, label in zip(docs, rerank_labels): 102 | doc["rerank_label"] = label 103 | 104 | # Save the updated example to the 
output file as a JSONL entry 105 | output_file.write(json.dumps(example, ensure_ascii=False) + "\n") 106 | 107 | print(f"Reranking completed and saved to {args.output_file}.") 108 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from bert_score import score 2 | from rouge import Rouge 3 | import json 4 | import jsonlines 5 | import openai 6 | 7 | 8 | openai.api_key = "Your_API_Key" 9 | 10 | def compute_bert_score(reference: str, candidate: str): 11 | """ 12 | Compute the BERTScore between two strings. 13 | 14 | Parameters: 15 | reference (str): The reference string. 16 | candidate (str): The candidate string to evaluate against the reference. 17 | 18 | Returns: 19 | float: The BERTScore (F1 score) between the reference and candidate. 20 | """ 21 | if candidate == "" or reference == "": 22 | return 0 23 | try: 24 | precision, recall, f1 = score([candidate], [reference], lang="en", rescale_with_baseline=True) # 25 | except: 26 | return 0 27 | return f1[0].item() 28 | 29 | def compute_rouge_score(rouge_score, reference: str, candidate: str): 30 | """ 31 | Compute the ROUGE score between two strings. 32 | 33 | Parameters: 34 | reference (str): The reference string. 35 | candidate (str): The candidate string to evaluate against the reference. 36 | 37 | Returns: 38 | dict: A dictionary containing ROUGE-1, ROUGE-2, and ROUGE-L scores (F1 scores). 39 | """ 40 | if candidate == "" or reference == "": 41 | return 0 42 | try: 43 | scores = rouge_score.get_scores(candidate, reference, avg=True) 44 | except: 45 | return 0 46 | return scores["rouge-l"]["f"] 47 | 48 | def compute_exact_match(reference: str, candidate: str): 49 | """ 50 | Compute whether two strings are an exact match. 51 | 52 | Parameters: 53 | reference (str): The reference string. 54 | candidate (str): The candidate string to evaluate against the reference. 55 | 56 | Returns: 57 | bool: True if the strings are an exact match, False otherwise. 58 | """ 59 | if candidate == "" or reference == "": 60 | return 0 61 | return reference.strip().lower() in candidate.strip().lower() or candidate.strip().lower() in reference.strip().lower() or reference.strip().lower() in candidate.strip().lower() 62 | 63 | def inverse_reward(sentence): 64 | length = len(sentence) 65 | return 1 / (1 + length) 66 | 67 | def llm_eval(instruction: str, reference: str, candidate: str): 68 | prompt = f"""Use the following criteria to evaluate the quality of the model's response in a knowledge-intensive task, considering the provided ground-truth answer. Assign a score between 0-100 based on the overall quality, relevance, and correctness of the response: 69 | 70 | Relevance to the Prompt (20 points): 71 | 72 | Award up to 20 points if the response aligns well with the user's query, even if minor errors are present. 73 | Deduct points if the response lacks focus or deviates significantly from the query. 74 | Accuracy of Factual Information (20 points): 75 | 76 | Grant up to 20 points for correct factual details aligning with the ground-truth answer. 77 | Penalize for inaccuracies, missing essential elements, or presenting incorrect knowledge. 78 | Handling of Temporal and Logical Reasoning (20 points): 79 | 80 | Award up to 20 points for demonstrating correct temporal and logical reasoning. 81 | Deduct points if temporal reasoning is flawed or logical consistency is missing. 
82 | Clarity and Coherence of Response (20 points): 83 | 84 | Assign up to 15 points for clear, coherent, and well-structured responses. 85 | Reduce points for ambiguity, confusion, or poor organization. 86 | Potential Misleading Nature or Misconceptions (20 points): 87 | 88 | Award up to 10 points if the response avoids being misleading. 89 | Penalize responses that could confuse or mislead the user, even if partially relevant. 90 | After evaluating the response based on these criteria, provide a total score in the format: 91 | “Score: points”. 92 | 93 | User: {instruction} 94 | Ground-Truth Answer: {reference} 95 | Model Response: {candidate} 96 | 97 | Score: """ 98 | response = openai.ChatCompletion.create( 99 | model="gpt-4o", 100 | messages=[ 101 | {"role": "system", "content": "You are a helpful Assistant."}, 102 | {"role": "user", "content": prompt} 103 | ] 104 | ) 105 | return response.choices[0].message["content"].replace('.','').strip() 106 | 107 | 108 | 109 | def load_jsonlines(file): 110 | with jsonlines.open(file, 'r') as jsonl_f: 111 | lst = [obj for obj in jsonl_f] 112 | return lst 113 | 114 | def load_file(input_fp): 115 | if input_fp.endswith(".json"): 116 | input_data = json.load(open(input_fp)) 117 | else: 118 | input_data = load_jsonlines(input_fp) 119 | return input_data 120 | 121 | # Example usage 122 | # reference = "Jane Austen" 123 | # candidate = "The author of 'Pride and Prejudice' is Chris Paul." 124 | 125 | # reference = "Eiffel Tower in Paris." 126 | # candidate = "The Eiffel Tower is one of the most famous landmarks in Paris, and it is located in the Champ de Mars." 127 | # candidate = "Eiffel Tower is in the Paris." 128 | 129 | # # Compute BERTScore 130 | # bert_score = compute_bert_score(reference, candidate) 131 | # print(f"BERTScore: {bert_score}") 132 | 133 | # # Compute ROUGE Score 134 | # rouge_score = compute_rouge_score(reference, candidate) 135 | # print(f"ROUGE Score: {rouge_score}") 136 | -------------------------------------------------------------------------------- /src/normalize_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | adapted from chemdataextractor.text.normalize 3 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 4 | Tools for normalizing text. 5 | https://github.com/mcs07/ChemDataExtractor 6 | :copyright: Copyright 2016 by Matt Swain. 7 | :license: MIT 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining 10 | a copy of this software and associated documentation files (the 11 | 'Software'), to deal in the Software without restriction, including 12 | without limitation the rights to use, copy, modify, merge, publish, 13 | distribute, sublicense, and/or sell copies of the Software, and to 14 | permit persons to whom the Software is furnished to do so, subject to 15 | the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be 18 | included in all copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, 21 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 22 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 23 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 24 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 25 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 26 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 27 | """ 28 | 29 | #: Control characters. 
30 | CONTROLS = { 31 | '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011', 32 | '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b', 33 | } 34 | # There are further control characters, but they are instead replaced with a space by unicode normalization 35 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f' 36 | 37 | 38 | #: Hyphen and dash characters. 39 | HYPHENS = { 40 | '-', # \u002d Hyphen-minus 41 | '‐', # \u2010 Hyphen 42 | '‑', # \u2011 Non-breaking hyphen 43 | '⁃', # \u2043 Hyphen bullet 44 | '‒', # \u2012 figure dash 45 | '–', # \u2013 en dash 46 | '—', # \u2014 em dash 47 | '―', # \u2015 horizontal bar 48 | } 49 | 50 | #: Minus characters. 51 | MINUSES = { 52 | '-', # \u002d Hyphen-minus 53 | '−', # \u2212 Minus 54 | '-', # \uff0d Full-width Hyphen-minus 55 | '⁻', # \u207b Superscript minus 56 | } 57 | 58 | #: Plus characters. 59 | PLUSES = { 60 | '+', # \u002b Plus 61 | '+', # \uff0b Full-width Plus 62 | '⁺', # \u207a Superscript plus 63 | } 64 | 65 | #: Slash characters. 66 | SLASHES = { 67 | '/', # \u002f Solidus 68 | '⁄', # \u2044 Fraction slash 69 | '∕', # \u2215 Division slash 70 | } 71 | 72 | #: Tilde characters. 73 | TILDES = { 74 | '~', # \u007e Tilde 75 | '˜', # \u02dc Small tilde 76 | '⁓', # \u2053 Swung dash 77 | '∼', # \u223c Tilde operator #in mbert vocab 78 | '∽', # \u223d Reversed tilde 79 | '∿', # \u223f Sine wave 80 | '〜', # \u301c Wave dash #in mbert vocab 81 | '~', # \uff5e Full-width tilde #in mbert vocab 82 | } 83 | 84 | #: Apostrophe characters. 85 | APOSTROPHES = { 86 | "'", # \u0027 87 | '’', # \u2019 88 | '՚', # \u055a 89 | 'Ꞌ', # \ua78b 90 | 'ꞌ', # \ua78c 91 | ''', # \uff07 92 | } 93 | 94 | #: Single quote characters. 95 | SINGLE_QUOTES = { 96 | "'", # \u0027 97 | '‘', # \u2018 98 | '’', # \u2019 99 | '‚', # \u201a 100 | '‛', # \u201b 101 | 102 | } 103 | 104 | #: Double quote characters. 105 | DOUBLE_QUOTES = { 106 | '"', # \u0022 107 | '“', # \u201c 108 | '”', # \u201d 109 | '„', # \u201e 110 | '‟', # \u201f 111 | } 112 | 113 | #: Accent characters. 114 | ACCENTS = { 115 | '`', # \u0060 116 | '´', # \u00b4 117 | } 118 | 119 | #: Prime characters. 120 | PRIMES = { 121 | '′', # \u2032 122 | '″', # \u2033 123 | '‴', # \u2034 124 | '‵', # \u2035 125 | '‶', # \u2036 126 | '‷', # \u2037 127 | '⁗', # \u2057 128 | } 129 | 130 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes. 
131 | QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES 132 | 133 | def normalize(text): 134 | for control in CONTROLS: 135 | text = text.replace(control, '') 136 | text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ') 137 | 138 | for hyphen in HYPHENS | MINUSES: 139 | text = text.replace(hyphen, '-') 140 | text = text.replace('\u00ad', '') 141 | 142 | for double_quote in DOUBLE_QUOTES: 143 | text = text.replace(double_quote, '"') # \u0022 144 | for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS): 145 | text = text.replace(single_quote, "'") # \u0027 146 | text = text.replace('′', "'") # \u2032 prime 147 | text = text.replace('‵', "'") # \u2035 reversed prime 148 | text = text.replace('″', "''") # \u2033 double prime 149 | text = text.replace('‶', "''") # \u2036 reversed double prime 150 | text = text.replace('‴', "'''") # \u2034 triple prime 151 | text = text.replace('‷', "'''") # \u2037 reversed triple prime 152 | text = text.replace('⁗', "''''") # \u2057 quadruple prime 153 | 154 | text = text.replace('…', '...').replace(' . . . ', ' ... ') # \u2026 155 | 156 | for slash in SLASHES: 157 | text = text.replace(slash, '/') 158 | 159 | #for tilde in TILDES: 160 | # text = text.replace(tilde, '~') 161 | 162 | return text 163 | -------------------------------------------------------------------------------- /src/moco.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import logging 6 | import copy 7 | import transformers 8 | 9 | from src import contriever, dist_utils, utils 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class MoCo(nn.Module): 15 | def __init__(self, opt): 16 | super(MoCo, self).__init__() 17 | 18 | self.queue_size = opt.queue_size 19 | self.momentum = opt.momentum 20 | self.temperature = opt.temperature 21 | self.label_smoothing = opt.label_smoothing 22 | self.norm_doc = opt.norm_doc 23 | self.norm_query = opt.norm_query 24 | self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k # apply the encoder on keys in train mode 25 | 26 | retriever, tokenizer = self._load_retriever( 27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 28 | ) 29 | 30 | self.tokenizer = tokenizer 31 | self.encoder_q = retriever 32 | self.encoder_k = copy.deepcopy(retriever) 33 | 34 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 35 | param_k.data.copy_(param_q.data) 36 | param_k.requires_grad = False 37 | 38 | # create the queue 39 | self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size)) 40 | self.queue = nn.functional.normalize(self.queue, dim=0) 41 | 42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 43 | 44 | def _load_retriever(self, model_id, pooling, random_init): 45 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 46 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 47 | 48 | if "xlm" in model_id: 49 | model_class = contriever.XLMRetriever 50 | else: 51 | model_class = contriever.Contriever 52 | 53 | if random_init: 54 | retriever = model_class(cfg) 55 | else: 56 | retriever = utils.load_hf(model_class, model_id) 57 | 58 | if "bert-" in model_id: 59 | if tokenizer.bos_token_id is None: 60 | tokenizer.bos_token = "[CLS]" 61 | if tokenizer.eos_token_id is None: 62 | tokenizer.eos_token = "[SEP]" 63 | 64 | 
retriever.config.pooling = pooling 65 | 66 | return retriever, tokenizer 67 | 68 | def get_encoder(self, return_encoder_k=False): 69 | if return_encoder_k: 70 | return self.encoder_k 71 | else: 72 | return self.encoder_q 73 | 74 | def _momentum_update_key_encoder(self): 75 | """ 76 | Update of the key encoder 77 | """ 78 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 79 | param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum) 80 | 81 | @torch.no_grad() 82 | def _dequeue_and_enqueue(self, keys): 83 | # gather keys before updating queue 84 | keys = dist_utils.gather_nograd(keys.contiguous()) 85 | 86 | batch_size = keys.shape[0] 87 | 88 | ptr = int(self.queue_ptr) 89 | assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}" # for simplicity 90 | 91 | # replace the keys at ptr (dequeue and enqueue) 92 | self.queue[:, ptr : ptr + batch_size] = keys.T 93 | ptr = (ptr + batch_size) % self.queue_size # move pointer 94 | 95 | self.queue_ptr[0] = ptr 96 | 97 | def _compute_logits(self, q, k): 98 | l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) 99 | l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) 100 | 101 | logits = torch.cat([l_pos, l_neg], dim=1) 102 | return logits 103 | 104 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs): 105 | bsz = q_tokens.size(0) 106 | 107 | q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 108 | 109 | # compute key features 110 | with torch.no_grad(): # no gradient to keys 111 | self._momentum_update_key_encoder() # update the key encoder 112 | 113 | if not self.encoder_k.training and not self.moco_train_mode_encoder_k: 114 | self.encoder_k.eval() 115 | 116 | k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 117 | 118 | logits = self._compute_logits(q, k) / self.temperature 119 | 120 | # labels: positive key indicators 121 | labels = torch.zeros(bsz, dtype=torch.long).cuda() 122 | 123 | loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing) 124 | 125 | self._dequeue_and_enqueue(k) 126 | 127 | # log stats 128 | if len(stats_prefix) > 0: 129 | stats_prefix = stats_prefix + "/" 130 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 131 | 132 | predicted_idx = torch.argmax(logits, dim=-1) 133 | accuracy = 100 * (predicted_idx == labels).float().mean() 134 | stdq = torch.std(q, dim=0).mean().item() 135 | stdk = torch.std(k, dim=0).mean().item() 136 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 137 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 138 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 139 | 140 | return loss, iter_stats 141 | -------------------------------------------------------------------------------- /src/contriever.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import os 4 | import torch 5 | import transformers 6 | from transformers import BertModel, XLMRobertaModel 7 | 8 | from src import utils 9 | 10 | 11 | class Contriever(BertModel): 12 | def __init__(self, config, pooling="average", **kwargs): 13 | super().__init__(config, add_pooling_layer=False) 14 | if not hasattr(config, "pooling"): 15 | self.config.pooling = pooling 16 | 17 | def forward( 18 | self, 19 | input_ids=None, 20 | attention_mask=None, 21 | token_type_ids=None, 22 | position_ids=None, 23 | head_mask=None, 24 | inputs_embeds=None, 25 | encoder_hidden_states=None, 26 | encoder_attention_mask=None, 27 | output_attentions=None, 28 | output_hidden_states=None, 29 | normalize=False, 30 | ): 31 | 32 | model_output = super().forward( 33 | input_ids=input_ids, 34 | attention_mask=attention_mask, 35 | token_type_ids=token_type_ids, 36 | position_ids=position_ids, 37 | head_mask=head_mask, 38 | inputs_embeds=inputs_embeds, 39 | encoder_hidden_states=encoder_hidden_states, 40 | encoder_attention_mask=encoder_attention_mask, 41 | output_attentions=output_attentions, 42 | output_hidden_states=output_hidden_states, 43 | ) 44 | 45 | last_hidden = model_output["last_hidden_state"] 46 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 47 | 48 | if self.config.pooling == "average": 49 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 50 | elif self.config.pooling == "cls": 51 | emb = last_hidden[:, 0] 52 | 53 | if normalize: 54 | emb = torch.nn.functional.normalize(emb, dim=-1) 55 | return emb 56 | 57 | 58 | class XLMRetriever(XLMRobertaModel): 59 | def __init__(self, config, pooling="average", **kwargs): 60 | super().__init__(config, add_pooling_layer=False) 61 | if not hasattr(config, "pooling"): 62 | self.config.pooling = pooling 63 | 64 | def forward( 65 | self, 66 | input_ids=None, 67 | attention_mask=None, 68 | token_type_ids=None, 69 | position_ids=None, 70 | head_mask=None, 71 | inputs_embeds=None, 72 | encoder_hidden_states=None, 73 | encoder_attention_mask=None, 74 | output_attentions=None, 75 | output_hidden_states=None, 76 | normalize=False, 77 | ): 78 | 79 | model_output = super().forward( 80 | input_ids=input_ids, 81 | attention_mask=attention_mask, 82 | token_type_ids=token_type_ids, 83 | position_ids=position_ids, 84 | head_mask=head_mask, 85 | inputs_embeds=inputs_embeds, 86 | encoder_hidden_states=encoder_hidden_states, 87 | encoder_attention_mask=encoder_attention_mask, 88 | output_attentions=output_attentions, 89 | output_hidden_states=output_hidden_states, 90 | ) 91 | 92 | last_hidden = model_output["last_hidden_state"] 93 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 94 | if self.config.pooling == "average": 95 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 96 | elif self.config.pooling == "cls": 97 | emb = last_hidden[:, 0] 98 | if normalize: 99 | emb = torch.nn.functional.normalize(emb, dim=-1) 100 | return emb 101 | 102 | 103 | def load_retriever(model_path, pooling="average", random_init=False): 104 | # try: check if model exists locally 105 | path = os.path.join(model_path, "checkpoint.pth") 106 | if os.path.exists(path): 107 | pretrained_dict = torch.load(path, map_location="cpu") 108 | opt = pretrained_dict["opt"] 109 | if hasattr(opt, "retriever_model_id"): 110 | retriever_model_id = opt.retriever_model_id 111 | else: 112 | # retriever_model_id = "bert-base-uncased" 113 | retriever_model_id = "bert-base-multilingual-cased" 114 | 
tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id) 115 | cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id) 116 | if "xlm" in retriever_model_id: 117 | model_class = XLMRetriever 118 | else: 119 | model_class = Contriever 120 | retriever = model_class(cfg) 121 | pretrained_dict = pretrained_dict["model"] 122 | 123 | if any("encoder_q." in key for key in pretrained_dict.keys()): # test if model is defined with moco class 124 | pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k} 125 | elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class 126 | pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k} 127 | retriever.load_state_dict(pretrained_dict, strict=False) 128 | else: 129 | retriever_model_id = model_path 130 | if "xlm" in retriever_model_id: 131 | model_class = XLMRetriever 132 | else: 133 | model_class = Contriever 134 | cfg = utils.load_hf(transformers.AutoConfig, model_path) 135 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path) 136 | retriever = utils.load_hf(model_class, model_path) 137 | 138 | return retriever, tokenizer, retriever_model_id 139 | -------------------------------------------------------------------------------- /src/finetuning_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import random 5 | import json 6 | import sys 7 | import numpy as np 8 | from src import normalize_text 9 | 10 | 11 | class Dataset(torch.utils.data.Dataset): 12 | def __init__( 13 | self, 14 | datapaths, 15 | negative_ctxs=1, 16 | negative_hard_ratio=0.0, 17 | negative_hard_min_idx=0, 18 | training=False, 19 | global_rank=-1, 20 | world_size=-1, 21 | maxload=None, 22 | normalize=False, 23 | ): 24 | self.negative_ctxs = negative_ctxs 25 | self.negative_hard_ratio = negative_hard_ratio 26 | self.negative_hard_min_idx = negative_hard_min_idx 27 | self.training = training 28 | self.normalize_fn = normalize_text.normalize if normalize_text else lambda x: x 29 | self._load_data(datapaths, global_rank, world_size, maxload) 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | def __getitem__(self, index): 35 | example = self.data[index] 36 | question = example["question"] 37 | if self.training: 38 | gold = random.choice(example["positive_ctxs"]) 39 | 40 | n_hard_negatives, n_random_negatives = self.sample_n_hard_negatives(example) 41 | negatives = [] 42 | if n_random_negatives > 0: 43 | random_negatives = random.sample(example["negative_ctxs"], n_random_negatives) 44 | negatives += random_negatives 45 | if n_hard_negatives > 0: 46 | hard_negatives = random.sample( 47 | example["hard_negative_ctxs"][self.negative_hard_min_idx :], n_hard_negatives 48 | ) 49 | negatives += hard_negatives 50 | else: 51 | gold = example["positive_ctxs"][0] 52 | nidx = 0 53 | if "negative_ctxs" in example: 54 | negatives = [example["negative_ctxs"][nidx]] 55 | else: 56 | negatives = [] 57 | 58 | gold = gold["title"] + " " + gold["text"] if "title" in gold and len(gold["title"]) > 0 else gold["text"] 59 | 60 | negatives = [ 61 | n["title"] + " " + n["text"] if ("title" in n and len(n["title"]) > 0) else n["text"] for n in negatives 62 | ] 63 | 64 | example = { 65 | "query": self.normalize_fn(question), 66 | "gold": self.normalize_fn(gold), 67 | "negatives": 
[self.normalize_fn(n) for n in negatives], 68 | } 69 | return example 70 | 71 | def _load_data(self, datapaths, global_rank, world_size, maxload): 72 | counter = 0 73 | self.data = [] 74 | for path in datapaths: 75 | path = str(path) 76 | if path.endswith(".jsonl"): 77 | file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload) 78 | elif path.endswith(".json"): 79 | file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload) 80 | self.data.extend(file_data) 81 | if maxload is not None and maxload > 0 and counter >= maxload: 82 | break 83 | 84 | def _load_data_json(self, path, global_rank, world_size, counter, maxload=None): 85 | examples = [] 86 | with open(path, "r") as fin: 87 | data = json.load(fin) 88 | for example in data: 89 | counter += 1 90 | if global_rank > -1 and not counter % world_size == global_rank: 91 | continue 92 | examples.append(example) 93 | if maxload is not None and maxload > 0 and counter == maxload: 94 | break 95 | 96 | return examples, counter 97 | 98 | def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None): 99 | examples = [] 100 | with open(path, "r") as fin: 101 | for line in fin: 102 | counter += 1 103 | if global_rank > -1 and not counter % world_size == global_rank: 104 | continue 105 | example = json.loads(line) 106 | examples.append(example) 107 | if maxload is not None and maxload > 0 and counter == maxload: 108 | break 109 | 110 | return examples, counter 111 | 112 | def sample_n_hard_negatives(self, ex): 113 | 114 | if "hard_negative_ctxs" in ex: 115 | n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)]) 116 | n_hard_negatives = min(n_hard_negatives, len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :])) 117 | else: 118 | n_hard_negatives = 0 119 | n_random_negatives = self.negative_ctxs - n_hard_negatives 120 | if "negative_ctxs" in ex: 121 | n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"])) 122 | else: 123 | n_random_negatives = 0 124 | return n_hard_negatives, n_random_negatives 125 | 126 | 127 | class Collator(object): 128 | def __init__(self, tokenizer, passage_maxlength=200): 129 | self.tokenizer = tokenizer 130 | self.passage_maxlength = passage_maxlength 131 | 132 | def __call__(self, batch): 133 | queries = [ex["query"] for ex in batch] 134 | golds = [ex["gold"] for ex in batch] 135 | negs = [item for ex in batch for item in ex["negatives"]] 136 | allpassages = golds + negs 137 | 138 | qout = self.tokenizer.batch_encode_plus( 139 | queries, 140 | max_length=self.passage_maxlength, 141 | truncation=True, 142 | padding=True, 143 | add_special_tokens=True, 144 | return_tensors="pt", 145 | ) 146 | kout = self.tokenizer.batch_encode_plus( 147 | allpassages, 148 | max_length=self.passage_maxlength, 149 | truncation=True, 150 | padding=True, 151 | add_special_tokens=True, 152 | return_tensors="pt", 153 | ) 154 | q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool() 155 | k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool() 156 | 157 | g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)] 158 | n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :] 159 | 160 | batch = { 161 | "q_tokens": q_tokens, 162 | "q_mask": q_mask, 163 | "k_tokens": k_tokens, 164 | "k_mask": k_mask, 165 | "g_tokens": g_tokens, 166 | "g_mask": g_mask, 167 | "n_tokens": n_tokens, 168 | "n_mask": n_mask, 169 | } 170 | 171 | return batch 172 | 
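# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). It assumes a
# DPR-style JSONL file at data/pairs.jsonl whose records contain "question",
# "positive_ctxs" and "negative_ctxs"; the tokenizer id is likewise an
# assumption. It only shows how Dataset and Collator above plug into a DataLoader.
if __name__ == "__main__":
    import transformers
    from torch.utils.data import DataLoader

    tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
    train_set = Dataset(datapaths=["data/pairs.jsonl"], negative_ctxs=1, training=True)
    loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=Collator(tokenizer))

    batch = next(iter(loader))
    # q_* are the tokenized queries, g_* the gold passages, n_* the sampled negatives.
    print(batch["q_tokens"].shape, batch["g_tokens"].shape, batch["n_tokens"].shape)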
-------------------------------------------------------------------------------- /sampling_dpo_trajectories.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import argparse 4 | from vllm import LLM, SamplingParams 5 | from tqdm import tqdm 6 | 7 | parser = argparse.ArgumentParser(description="Process JSON data using an LLM model.") 8 | parser.add_argument("--template", type=str, required=True, help="Path to the chat template file.") 9 | parser.add_argument("--llm-model", type=str, required=True, help="Path to the LLM model checkpoint.") 10 | parser.add_argument("--input-json", type=str, required=True, help="Path to the input JSON file.") 11 | parser.add_argument("--output-json", type=str, required=True, help="Path to the output JSON file.") 12 | 13 | args = parser.parse_args() 14 | 15 | with open(args.template, "r") as f: 16 | chat_template = f.read() 17 | 18 | llm = LLM(model=args.llm_model, tensor_parallel_size=1) 19 | sampling_params = SamplingParams(temperature=0.4, n=8, max_tokens=100) 20 | sampling_params_generate = SamplingParams(temperature=0, max_tokens=200) 21 | 22 | def get_prompt_docs(query, docs): 23 | retrieved_content = "Retrieved Content:\n" + "\n".join( 24 | [ 25 | f"{i+1}. Topic: {doc['title']}\nContent: {doc['text']}" 26 | for i, doc in enumerate(docs) 27 | ] 28 | ) 29 | return [ 30 | {'role': 'system', 31 | 'content': ("You are an expert at dynamically generating document identifiers to answer a given query.\n" 32 | "I will provide you with a set of documents, each uniquely identified by a number within square brackets, e.g., [1], [2], etc.\n" 33 | f"Your task is to identify and generate only the identifiers of the documents that contain sufficient information to answer the query.\n" 34 | "Stop generating identifiers as soon as the selected documents collectively provide enough information to answer the query.\n" 35 | "If no documents are required to answer the query, output \"None\".\n" 36 | "Output the identifiers as a comma-separated list, e.g., [1], [2] or \"None\" if no documents are needed.\n" 37 | "Focus solely on providing the identifiers. Do not include any explanations, descriptions, or additional text.")}, 38 | {'role': 'user', 39 | 'content': f"Query: {query}\n\n{retrieved_content}"} 40 | ] 41 | 42 | def get_prompt_answer(prompt, entry): 43 | system_prompt = """You are an intelligent assistant that uses retrieved knowledge to answer user queries accurately and concisely. Follow these rules: 44 | 1. **Task**: 45 | - Use the provided `[Retrieved Content]` to generate responses. 46 | - If the Retrieved Content is None, you should generate answer based on your own knowledge. 47 | - If the information is insufficient or you don't know the answer, state, “I cannot fully answer based on the available information. Please provide more details.” 48 | 2. **Requirements**: 49 | - **Accuracy**: Base your answers on the retrieved content. 50 | - **Conciseness**: Keep answers brief and relevant. 51 | - **Context Awareness**: Ensure your responses align with the user’s query. 52 | 3. **Input Format**: 53 | - Query: `[User Query]` 54 | - Retrieved: `[Retrieved Content]` 55 | 4. **Output Format**: 56 | - A structured, clear response tailored to the query. 57 | Always prioritize clarity and reliability.""" 58 | 59 | item_id = entry['id'] 60 | if 'fever' in item_id: 61 | remember_prompt = 'Please answer the question with “SUPPORTS”, “REFUTES” or “NEI” based on what you know.' 
62 | elif 'nq' in item_id: 63 | remember_prompt = 'Please answer the question with a short phrase.' 64 | elif 'hotpotqa' in item_id: 65 | remember_prompt = 'Please answer the question with a short phrase.' 66 | elif 'eli5' in item_id: 67 | remember_prompt = 'Please answer the question with a paragraph.' 68 | elif 'tc' in item_id: 69 | remember_prompt = 'Please answer the question with a short phrase.' 70 | elif 'asqa' in item_id: 71 | remember_prompt = 'Please answer the question with a short phrase.' 72 | else: 73 | remember_prompt = 'Please answer the question with a short phrase.' 74 | 75 | return [ 76 | {'role': 'system', 77 | 'content': system_prompt}, 78 | {'role': 'user', 79 | 'content': (f"{remember_prompt}\n" 80 | f"Query: {entry['question']}\n\n" 81 | f"{prompt}")} 82 | ] 83 | 84 | def generate_docs(entry): 85 | messages = get_prompt_docs(entry['question'], entry['docs'][:40]) 86 | 87 | outputs = llm.chat(messages, 88 | sampling_params=sampling_params, 89 | use_tqdm=False, 90 | chat_template=chat_template) 91 | 92 | return [output.text for output in outputs[0].outputs] 93 | 94 | def map_function(data_list): 95 | result = [] 96 | for item in data_list: 97 | ids = [int(num) for num in re.findall(r'\[(\d+)\]', item)] 98 | result.append(ids) 99 | return result 100 | 101 | def convert(id_lists, docs): 102 | prompts = [] 103 | for ids in id_lists: 104 | lines = [] 105 | if len(ids) == 0: 106 | retrieved_content = "Retrieved Content: None" 107 | prompts.append(retrieved_content) 108 | continue 109 | for doc_id in ids: 110 | doc_title = docs[doc_id - 1]['title'] 111 | doc_text = docs[doc_id - 1]['text'] 112 | line = f"{doc_id}. Title: {doc_title}.\nContent: {doc_text}" 113 | lines.append(line) 114 | 115 | retrieved_content = "\n".join(lines) 116 | prompt = f"Retrieved Content:\n{retrieved_content}" 117 | prompts.append(prompt) 118 | return prompts 119 | 120 | def generate_answers(entry, prompts): 121 | results = [] 122 | for prompt in prompts: 123 | messages = get_prompt_answer(prompt, entry) 124 | 125 | outputs = llm.chat(messages, 126 | sampling_params=sampling_params_generate, 127 | use_tqdm=False, 128 | chat_template=chat_template) 129 | results.append(outputs[0].outputs[0].text) 130 | return results 131 | 132 | def process_json(input_file, output_file): 133 | with open(input_file, 'r', encoding='utf-8') as infile: 134 | data = json.load(infile) 135 | 136 | with open(output_file, 'a', encoding='utf-8') as outfile: 137 | for entry in tqdm(data): 138 | generated_list = generate_docs(entry) 139 | id_lists = map_function(generated_list) 140 | try: 141 | prompts = convert(id_lists, entry['docs']) 142 | except IndexError: 143 | continue 144 | answers = generate_answers(entry, prompts) 145 | 146 | entry['responses'] = answers 147 | entry['id_lists'] = id_lists 148 | import ipdb; ipdb.set_trace() 149 | json.dump(entry, outfile, ensure_ascii=False) 150 | outfile.write('\n') 151 | 152 | if __name__ == "__main__": 153 | process_json(args.input_json, args.output_json) -------------------------------------------------------------------------------- /src/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import os 5 | 6 | 7 | class Options: 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | self.initialize() 11 | 12 | def initialize(self): 13 | # basic parameters 14 | self.parser.add_argument( 15 | "--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here" 16 | ) 17 | self.parser.add_argument( 18 | "--train_data", 19 | nargs="+", 20 | default=[], 21 | help="Data used for training, passed as a list of directories splitted into tensor files.", 22 | ) 23 | self.parser.add_argument( 24 | "--eval_data", 25 | nargs="+", 26 | default=[], 27 | help="Data used for evaluation during finetuning, this option is not used during contrastive pre-training.", 28 | ) 29 | self.parser.add_argument( 30 | "--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format" 31 | ) 32 | self.parser.add_argument( 33 | "--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored" 34 | ) 35 | self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining") 36 | self.parser.add_argument("--continue_training", action="store_true") 37 | self.parser.add_argument("--num_workers", type=int, default=5) 38 | 39 | self.parser.add_argument("--chunk_length", type=int, default=256) 40 | self.parser.add_argument("--loading_mode", type=str, default="split") 41 | self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing") 42 | self.parser.add_argument( 43 | "--sampling_coefficient", 44 | type=float, 45 | default=0.0, 46 | help="coefficient used for sampling between different datasets during training, \ 47 | by default sampling is uniform over datasets", 48 | ) 49 | self.parser.add_argument("--augmentation", type=str, default="none") 50 | self.parser.add_argument("--prob_augmentation", type=float, default=0.0) 51 | 52 | self.parser.add_argument("--dropout", type=float, default=0.1) 53 | self.parser.add_argument("--rho", type=float, default=0.05) 54 | 55 | self.parser.add_argument("--contrastive_mode", type=str, default="moco") 56 | self.parser.add_argument("--queue_size", type=int, default=65536) 57 | self.parser.add_argument("--temperature", type=float, default=1.0) 58 | self.parser.add_argument("--momentum", type=float, default=0.999) 59 | self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true") 60 | self.parser.add_argument("--eval_normalize_text", action="store_true") 61 | self.parser.add_argument("--norm_query", action="store_true") 62 | self.parser.add_argument("--norm_doc", action="store_true") 63 | self.parser.add_argument("--projection_size", type=int, default=768) 64 | 65 | self.parser.add_argument("--ratio_min", type=float, default=0.1) 66 | self.parser.add_argument("--ratio_max", type=float, default=0.5) 67 | self.parser.add_argument("--score_function", type=str, default="dot") 68 | self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased") 69 | self.parser.add_argument("--pooling", type=str, default="average") 70 | self.parser.add_argument("--random_init", action="store_true", help="init model with random weights") 71 | 72 | # dataset parameters 73 | self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.") 74 | self.parser.add_argument( 75 | "--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU 
for evaluation." 76 | ) 77 | self.parser.add_argument("--total_steps", type=int, default=1000) 78 | self.parser.add_argument("--warmup_steps", type=int, default=-1) 79 | 80 | self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 81 | self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)") 82 | self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization") 83 | # training parameters 84 | self.parser.add_argument("--optim", type=str, default="adamw") 85 | self.parser.add_argument("--scheduler", type=str, default="linear") 86 | self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") 87 | self.parser.add_argument( 88 | "--lr_min_ratio", 89 | type=float, 90 | default=0.0, 91 | help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate", 92 | ) 93 | self.parser.add_argument("--weight_decay", type=float, default=0.01, help="learning rate") 94 | self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1") 95 | self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2") 96 | self.parser.add_argument("--eps", type=float, default=1e-6, help="eps") 97 | self.parser.add_argument( 98 | "--log_freq", type=int, default=100, help="log train stats every steps during training" 99 | ) 100 | self.parser.add_argument( 101 | "--eval_freq", type=int, default=500, help="evaluate model every steps during training" 102 | ) 103 | self.parser.add_argument("--save_freq", type=int, default=50000) 104 | self.parser.add_argument("--maxload", type=int, default=None) 105 | self.parser.add_argument("--label_smoothing", type=float, default=0.0) 106 | 107 | # finetuning options 108 | self.parser.add_argument("--negative_ctxs", type=int, default=1) 109 | self.parser.add_argument("--negative_hard_min_idx", type=int, default=0) 110 | self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0) 111 | 112 | def print_options(self, opt): 113 | message = "" 114 | for k, v in sorted(vars(opt).items()): 115 | comment = "" 116 | default = self.parser.get_default(k) 117 | if v != default: 118 | comment = f"\t[default: %s]" % str(default) 119 | message += f"{str(k):>40}: {str(v):<40}{comment}\n" 120 | print(message, flush=True) 121 | model_dir = os.path.join(opt.output_dir, "models") 122 | if not os.path.exists(model_dir): 123 | os.makedirs(os.path.join(opt.output_dir, "models")) 124 | file_name = os.path.join(opt.output_dir, "opt.txt") 125 | with open(file_name, "wt") as opt_file: 126 | opt_file.write(message) 127 | opt_file.write("\n") 128 | 129 | def parse(self): 130 | opt, _ = self.parser.parse_known_args() 131 | # opt = self.parser.parse_args() 132 | return opt 133 | -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import collections 9 | import logging 10 | import regex 11 | import string 12 | import unicodedata 13 | from functools import partial 14 | from multiprocessing import Pool as ProcessPool 15 | from typing import Tuple, List, Dict 16 | import numpy as np 17 | 18 | """ 19 | Evaluation code from DPR: https://github.com/facebookresearch/DPR 20 | """ 21 | 22 | class SimpleTokenizer(object): 23 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 24 | NON_WS = r'[^\p{Z}\p{C}]' 25 | 26 | def __init__(self): 27 | """ 28 | Args: 29 | annotators: None or empty set (only tokenizes). 30 | """ 31 | self._regexp = regex.compile( 32 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 33 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 34 | ) 35 | 36 | def tokenize(self, text, uncased=False): 37 | matches = [m for m in self._regexp.finditer(text)] 38 | if uncased: 39 | tokens = [m.group().lower() for m in matches] 40 | else: 41 | tokens = [m.group() for m in matches] 42 | return tokens 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits']) 47 | 48 | def calculate_matches(data: List, workers_num: int): 49 | """ 50 | Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of 51 | documents and results. It internally forks multiple sub-processes for evaluation and then merges results 52 | :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) 53 | :param answers: list of answers's list. One list per question 54 | :param closest_docs: document ids of the top results along with their scores 55 | :param workers_num: amount of parallel threads to process data 56 | :param match_type: type of answer matching. Refer to has_answer code for available options 57 | :return: matching information tuple. 58 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of 59 | valid matches across an entire dataset. 
60 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 61 | """ 62 | 63 | logger.info('Matching answers in top docs...') 64 | 65 | tokenizer = SimpleTokenizer() 66 | get_score_partial = partial(check_answer, tokenizer=tokenizer) 67 | 68 | processes = ProcessPool(processes=workers_num) 69 | scores = processes.map(get_score_partial, data) 70 | 71 | logger.info('Per question validation results len=%d', len(scores)) 72 | 73 | n_docs = len(data[0]['ctxs']) 74 | top_k_hits = [0] * n_docs 75 | for question_hits in scores: 76 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 77 | if best_hit is not None: 78 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 79 | 80 | return QAMatchStats(top_k_hits, scores) 81 | 82 | def check_answer(example, tokenizer) -> List[bool]: 83 | """Search through all the top docs to see if they have any of the answers.""" 84 | answers = example['answers'] 85 | ctxs = example['ctxs'] 86 | 87 | hits = [] 88 | 89 | for i, doc in enumerate(ctxs): 90 | text = doc['text'] 91 | 92 | if text is None: # cannot find the document for some reason 93 | logger.warning("no doc in db") 94 | hits.append(False) 95 | continue 96 | 97 | hits.append(has_answer(answers, text, tokenizer)) 98 | 99 | return hits 100 | 101 | def has_answer(answers, text, tokenizer) -> bool: 102 | """Check if a document contains an answer string.""" 103 | text = _normalize(text) 104 | text = tokenizer.tokenize(text, uncased=True) 105 | 106 | for answer in answers: 107 | answer = _normalize(answer) 108 | answer = tokenizer.tokenize(answer, uncased=True) 109 | for i in range(0, len(text) - len(answer) + 1): 110 | if answer == text[i: i + len(answer)]: 111 | return True 112 | return False 113 | 114 | ################################################# 115 | ######## READER EVALUATION ######## 116 | ################################################# 117 | 118 | def _normalize(text): 119 | return unicodedata.normalize('NFD', text) 120 | 121 | #Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ 122 | def normalize_answer(s): 123 | def remove_articles(text): 124 | return regex.sub(r'\b(a|an|the)\b', ' ', text) 125 | 126 | def white_space_fix(text): 127 | return ' '.join(text.split()) 128 | 129 | def remove_punc(text): 130 | exclude = set(string.punctuation) 131 | return ''.join(ch for ch in text if ch not in exclude) 132 | 133 | def lower(text): 134 | return text.lower() 135 | 136 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 137 | 138 | def em(prediction, ground_truth): 139 | return normalize_answer(prediction) == normalize_answer(ground_truth) 140 | 141 | def f1(prediction, ground_truth): 142 | prediction_tokens = normalize_answer(prediction).split() 143 | ground_truth_tokens = normalize_answer(ground_truth).split() 144 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 145 | num_same = sum(common.values()) 146 | if num_same == 0: 147 | return 0 148 | precision = 1.0 * num_same / len(prediction_tokens) 149 | recall = 1.0 * num_same / len(ground_truth_tokens) 150 | f1 = (2 * precision * recall) / (precision + recall) 151 | return f1 152 | 153 | def f1_score(prediction, ground_truths): 154 | return max([f1(prediction, gt) for gt in ground_truths]) 155 | 156 | def exact_match_score(prediction, ground_truths): 157 | return max([em(prediction, gt) for gt in ground_truths]) 158 | 159 | 
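# Worked example (illustrative): normalize_answer lower-cases, strips punctuation
# and articles, and collapses whitespace, so
#   normalize_answer("The Eiffel Tower!") == "eiffel tower"
#   em("The Eiffel Tower!", "eiffel tower") is True
#   exact_match_score("Paris, France", ["Paris"]) is False  (the extra token remains)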
#################################################### 160 | ######## RETRIEVER EVALUATION ######## 161 | #################################################### 162 | 163 | def eval_batch(scores, inversions, avg_topk, idx_topk): 164 | for k, s in enumerate(scores): 165 | s = s.cpu().numpy() 166 | sorted_idx = np.argsort(-s) 167 | score(sorted_idx, inversions, avg_topk, idx_topk) 168 | 169 | def count_inversions(arr): 170 | inv_count = 0 171 | lenarr = len(arr) 172 | for i in range(lenarr): 173 | for j in range(i + 1, lenarr): 174 | if (arr[i] > arr[j]): 175 | inv_count += 1 176 | return inv_count 177 | 178 | def score(x, inversions, avg_topk, idx_topk): 179 | x = np.array(x) 180 | inversions.append(count_inversions(x)) 181 | for k in avg_topk: 182 | # ratio of passages in the predicted top-k that are 183 | # also in the topk given by gold score 184 | avg_pred_topk = (x[:k] None: 161 | for key, (value, weight) in vals.items(): 162 | self.raw_stats[key] += value * weight 163 | self.total_weights[key] += weight 164 | 165 | @property 166 | def stats(self) -> Dict[str, float]: 167 | return {x: self.raw_stats[x] / self.total_weights[x] for x in self.raw_stats.keys()} 168 | 169 | @property 170 | def tuple_stats(self) -> Dict[str, Tuple[float, float]]: 171 | return {x: (self.raw_stats[x] / self.total_weights[x], self.total_weights[x]) for x in self.raw_stats.keys()} 172 | 173 | def reset(self) -> None: 174 | self.raw_stats = defaultdict(float) 175 | self.total_weights = defaultdict(float) 176 | 177 | @property 178 | def average_stats(self) -> Dict[str, float]: 179 | keys = sorted(self.raw_stats.keys()) 180 | if torch.distributed.is_initialized(): 181 | torch.distributed.broadcast_object_list(keys, src=0) 182 | global_dict = {} 183 | for k in keys: 184 | if not k in self.total_weights: 185 | v = 0.0 186 | else: 187 | v = self.raw_stats[k] / self.total_weights[k] 188 | v, _ = dist_utils.weighted_average(v, self.total_weights[k]) 189 | global_dict[k] = v 190 | return global_dict 191 | 192 | 193 | def load_hf(object_class, model_name): 194 | try: 195 | obj = object_class.from_pretrained(model_name, local_files_only=True) 196 | except: 197 | obj = object_class.from_pretrained(model_name, local_files_only=False) 198 | return obj 199 | 200 | 201 | def init_tb_logger(output_dir): 202 | try: 203 | from torch.utils import tensorboard 204 | 205 | if dist_utils.is_main(): 206 | tb_logger = tensorboard.SummaryWriter(output_dir) 207 | else: 208 | tb_logger = None 209 | except: 210 | logger.warning("Tensorboard is not available.") 211 | tb_logger = None 212 | 213 | return tb_logger 214 | -------------------------------------------------------------------------------- /src/beir_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import os 4 | from collections import defaultdict 5 | from typing import List, Dict 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | 10 | import beir.util 11 | from beir.datasets.data_loader import GenericDataLoader 12 | from beir.retrieval.evaluation import EvaluateRetrieval 13 | from beir.retrieval.search.dense import DenseRetrievalExactSearch 14 | 15 | from beir.reranking.models import CrossEncoder 16 | from beir.reranking import Rerank 17 | 18 | import src.dist_utils as dist_utils 19 | from src import normalize_text 20 | 21 | 22 | class DenseEncoderModel: 23 | def __init__( 24 | self, 25 | query_encoder, 26 | doc_encoder=None, 27 | tokenizer=None, 28 | max_length=512, 29 | add_special_tokens=True, 30 | norm_query=False, 31 | norm_doc=False, 32 | lower_case=False, 33 | normalize_text=False, 34 | **kwargs, 35 | ): 36 | self.query_encoder = query_encoder 37 | self.doc_encoder = doc_encoder 38 | self.tokenizer = tokenizer 39 | self.max_length = max_length 40 | self.add_special_tokens = add_special_tokens 41 | self.norm_query = norm_query 42 | self.norm_doc = norm_doc 43 | self.lower_case = lower_case 44 | self.normalize_text = normalize_text 45 | 46 | def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray: 47 | 48 | if dist.is_initialized(): 49 | idx = np.array_split(range(len(queries)), dist.get_world_size())[dist.get_rank()] 50 | else: 51 | idx = range(len(queries)) 52 | 53 | queries = [queries[i] for i in idx] 54 | if self.normalize_text: 55 | queries = [normalize_text.normalize(q) for q in queries] 56 | if self.lower_case: 57 | queries = [q.lower() for q in queries] 58 | 59 | allemb = [] 60 | nbatch = (len(queries) - 1) // batch_size + 1 61 | with torch.no_grad(): 62 | for k in range(nbatch): 63 | start_idx = k * batch_size 64 | end_idx = min((k + 1) * batch_size, len(queries)) 65 | 66 | qencode = self.tokenizer.batch_encode_plus( 67 | queries[start_idx:end_idx], 68 | max_length=self.max_length, 69 | padding=True, 70 | truncation=True, 71 | add_special_tokens=self.add_special_tokens, 72 | return_tensors="pt", 73 | ) 74 | qencode = {key: value.cuda() for key, value in qencode.items()} 75 | emb = self.query_encoder(**qencode, normalize=self.norm_query) 76 | allemb.append(emb.cpu()) 77 | 78 | allemb = torch.cat(allemb, dim=0) 79 | allemb = allemb.cuda() 80 | if dist.is_initialized(): 81 | allemb = dist_utils.varsize_gather_nograd(allemb) 82 | allemb = allemb.cpu().numpy() 83 | return allemb 84 | 85 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs): 86 | 87 | if dist.is_initialized(): 88 | idx = np.array_split(range(len(corpus)), dist.get_world_size())[dist.get_rank()] 89 | else: 90 | idx = range(len(corpus)) 91 | corpus = [corpus[i] for i in idx] 92 | corpus = [c["title"] + " " + c["text"] if len(c["title"]) > 0 else c["text"] for c in corpus] 93 | if self.normalize_text: 94 | corpus = [normalize_text.normalize(c) for c in corpus] 95 | if self.lower_case: 96 | corpus = [c.lower() for c in corpus] 97 | 98 | allemb = [] 99 | nbatch = (len(corpus) - 1) // batch_size + 1 100 | with torch.no_grad(): 101 | for k in range(nbatch): 102 | start_idx = k * batch_size 103 | end_idx = min((k + 1) * batch_size, len(corpus)) 104 | 105 | cencode = self.tokenizer.batch_encode_plus( 106 | corpus[start_idx:end_idx], 107 | max_length=self.max_length, 108 | padding=True, 109 | truncation=True, 110 | add_special_tokens=self.add_special_tokens, 111 | return_tensors="pt", 112 | ) 113 | 
cencode = {key: value.cuda() for key, value in cencode.items()} 114 | emb = self.doc_encoder(**cencode, normalize=self.norm_doc) 115 | allemb.append(emb.cpu()) 116 | 117 | allemb = torch.cat(allemb, dim=0) 118 | allemb = allemb.cuda() 119 | if dist.is_initialized(): 120 | allemb = dist_utils.varsize_gather_nograd(allemb) 121 | allemb = allemb.cpu().numpy() 122 | return allemb 123 | 124 | 125 | def evaluate_model( 126 | query_encoder, 127 | doc_encoder, 128 | tokenizer, 129 | dataset, 130 | batch_size=128, 131 | add_special_tokens=True, 132 | norm_query=False, 133 | norm_doc=False, 134 | is_main=True, 135 | split="test", 136 | score_function="dot", 137 | beir_dir="BEIR/datasets", 138 | save_results_path=None, 139 | lower_case=False, 140 | normalize_text=False, 141 | ): 142 | 143 | metrics = defaultdict(list) # store final results 144 | 145 | if hasattr(query_encoder, "module"): 146 | query_encoder = query_encoder.module 147 | query_encoder.eval() 148 | 149 | if doc_encoder is not None: 150 | if hasattr(doc_encoder, "module"): 151 | doc_encoder = doc_encoder.module 152 | doc_encoder.eval() 153 | else: 154 | doc_encoder = query_encoder 155 | 156 | dmodel = DenseRetrievalExactSearch( 157 | DenseEncoderModel( 158 | query_encoder=query_encoder, 159 | doc_encoder=doc_encoder, 160 | tokenizer=tokenizer, 161 | add_special_tokens=add_special_tokens, 162 | norm_query=norm_query, 163 | norm_doc=norm_doc, 164 | lower_case=lower_case, 165 | normalize_text=normalize_text, 166 | ), 167 | batch_size=batch_size, 168 | ) 169 | retriever = EvaluateRetrieval(dmodel, score_function=score_function) 170 | data_path = os.path.join(beir_dir, dataset) 171 | 172 | if not os.path.isdir(data_path) and is_main: 173 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset) 174 | data_path = beir.util.download_and_unzip(url, beir_dir) 175 | dist_utils.barrier() 176 | 177 | if not dataset == "cqadupstack": 178 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split) 179 | results = retriever.retrieve(corpus, queries) 180 | if is_main: 181 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values) 182 | for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"): 183 | if isinstance(metric, str): 184 | metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric) 185 | for key, value in metric.items(): 186 | metrics[key].append(value) 187 | if save_results_path is not None: 188 | torch.save(results, f"{save_results_path}") 189 | elif dataset == "cqadupstack": # compute macroaverage over datasets 190 | paths = glob.glob(data_path) 191 | for path in paths: 192 | corpus, queries, qrels = GenericDataLoader(data_folder=data_folder).load(split=split) 193 | results = retriever.retrieve(corpus, queries) 194 | if is_main: 195 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values) 196 | for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"): 197 | if isinstance(metric, str): 198 | metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric) 199 | for key, value in metric.items(): 200 | metrics[key].append(value) 201 | for key, value in metrics.items(): 202 | assert ( 203 | len(value) == 12 204 | ), f"cqadupstack includes 12 datasets, only {len(value)} values were compute for the {key} metric" 205 | 206 | metrics = {key: 100 * np.mean(value) for key, value in metrics.items()} 207 | 208 | return metrics 209 | 
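# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): scoring a
# single BEIR dataset with evaluate_model above. The checkpoint id
# ("facebook/contriever") and dataset name ("nfcorpus") are assumptions, a CUDA
# device is required, and a single-process run is assumed (dist_utils is
# expected to no-op when torch.distributed is not initialized).
if __name__ == "__main__":
    import transformers
    from src.contriever import Contriever

    tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/contriever")
    encoder = Contriever.from_pretrained("facebook/contriever").cuda().eval()

    metrics = evaluate_model(
        query_encoder=encoder,
        doc_encoder=encoder,      # shared encoder for queries and passages
        tokenizer=tokenizer,
        dataset="nfcorpus",
        batch_size=64,
        beir_dir="BEIR/datasets",
    )
    print({k: round(v, 2) for k, v in metrics.items()})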
-------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import glob 5 | import torch 6 | import random 7 | import json 8 | import csv 9 | import numpy as np 10 | import numpy.random 11 | import logging 12 | from collections import defaultdict 13 | import torch.distributed as dist 14 | 15 | from src import dist_utils 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def load_data(opt, tokenizer): 21 | datasets = {} 22 | for path in opt.train_data: 23 | data = load_dataset(path, opt.loading_mode) 24 | if data is not None: 25 | datasets[path] = Dataset(data, opt.chunk_length, tokenizer, opt) 26 | dataset = MultiDataset(datasets) 27 | dataset.set_prob(coeff=opt.sampling_coefficient) 28 | return dataset 29 | 30 | 31 | def load_dataset(data_path, loading_mode): 32 | files = glob.glob(os.path.join(data_path, "*.p*")) 33 | files.sort() 34 | tensors = [] 35 | if loading_mode == "split": 36 | files_split = list(np.array_split(files, dist_utils.get_world_size()))[dist_utils.get_rank()] 37 | for filepath in files_split: 38 | try: 39 | tensors.append(torch.load(filepath, map_location="cpu")) 40 | except: 41 | logger.warning(f"Unable to load file {filepath}") 42 | elif loading_mode == "full": 43 | for fin in files: 44 | tensors.append(torch.load(fin, map_location="cpu")) 45 | elif loading_mode == "single": 46 | tensors.append(torch.load(files[0], map_location="cpu")) 47 | if len(tensors) == 0: 48 | return None 49 | tensor = torch.cat(tensors) 50 | return tensor 51 | 52 | 53 | class MultiDataset(torch.utils.data.Dataset): 54 | def __init__(self, datasets): 55 | 56 | self.datasets = datasets 57 | self.prob = [1 / len(self.datasets) for _ in self.datasets] 58 | self.dataset_ids = list(self.datasets.keys()) 59 | 60 | def __len__(self): 61 | return sum([len(dataset) for dataset in self.datasets.values()]) 62 | 63 | def __getitem__(self, index): 64 | dataset_idx = numpy.random.choice(range(len(self.prob)), 1, p=self.prob)[0] 65 | did = self.dataset_ids[dataset_idx] 66 | index = random.randint(0, len(self.datasets[did]) - 1) 67 | sample = self.datasets[did][index] 68 | sample["dataset_id"] = did 69 | return sample 70 | 71 | def generate_offset(self): 72 | for dataset in self.datasets.values(): 73 | dataset.generate_offset() 74 | 75 | def set_prob(self, coeff=0.0): 76 | 77 | prob = np.array([float(len(dataset)) for _, dataset in self.datasets.items()]) 78 | prob /= prob.sum() 79 | prob = np.array([p**coeff for p in prob]) 80 | prob /= prob.sum() 81 | self.prob = prob 82 | 83 | 84 | class Dataset(torch.utils.data.Dataset): 85 | """Monolingual dataset based on a list of paths""" 86 | 87 | def __init__(self, data, chunk_length, tokenizer, opt): 88 | 89 | self.data = data 90 | self.chunk_length = chunk_length 91 | self.tokenizer = tokenizer 92 | self.opt = opt 93 | self.generate_offset() 94 | 95 | def __len__(self): 96 | return (self.data.size(0) - self.offset) // self.chunk_length 97 | 98 | def __getitem__(self, index): 99 | start_idx = self.offset + index * self.chunk_length 100 | end_idx = start_idx + self.chunk_length 101 | tokens = self.data[start_idx:end_idx] 102 | q_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max) 103 | k_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max) 104 | q_tokens = apply_augmentation(q_tokens, self.opt) 105 | q_tokens = 
add_bos_eos(q_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id) 106 | k_tokens = apply_augmentation(k_tokens, self.opt) 107 | k_tokens = add_bos_eos(k_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id) 108 | 109 | return {"q_tokens": q_tokens, "k_tokens": k_tokens} 110 | 111 | def generate_offset(self): 112 | self.offset = random.randint(0, self.chunk_length - 1) 113 | 114 | 115 | class Collator(object): 116 | def __init__(self, opt): 117 | self.opt = opt 118 | 119 | def __call__(self, batch_examples): 120 | 121 | batch = defaultdict(list) 122 | for example in batch_examples: 123 | for k, v in example.items(): 124 | batch[k].append(v) 125 | 126 | q_tokens, q_mask = build_mask(batch["q_tokens"]) 127 | k_tokens, k_mask = build_mask(batch["k_tokens"]) 128 | 129 | batch["q_tokens"] = q_tokens 130 | batch["q_mask"] = q_mask 131 | batch["k_tokens"] = k_tokens 132 | batch["k_mask"] = k_mask 133 | 134 | return batch 135 | 136 | 137 | def randomcrop(x, ratio_min, ratio_max): 138 | 139 | ratio = random.uniform(ratio_min, ratio_max) 140 | length = int(len(x) * ratio) 141 | start = random.randint(0, len(x) - length) 142 | end = start + length 143 | crop = x[start:end].clone() 144 | return crop 145 | 146 | 147 | def build_mask(tensors): 148 | shapes = [x.shape for x in tensors] 149 | maxlength = max([len(x) for x in tensors]) 150 | returnmasks = [] 151 | ids = [] 152 | for k, x in enumerate(tensors): 153 | returnmasks.append(torch.tensor([1] * len(x) + [0] * (maxlength - len(x)))) 154 | ids.append(torch.cat((x, torch.tensor([0] * (maxlength - len(x)))))) 155 | ids = torch.stack(ids, dim=0).long() 156 | returnmasks = torch.stack(returnmasks, dim=0).bool() 157 | return ids, returnmasks 158 | 159 | 160 | def add_token(x, token): 161 | x = torch.cat((torch.tensor([token]), x)) 162 | return x 163 | 164 | 165 | def deleteword(x, p=0.1): 166 | mask = np.random.rand(len(x)) 167 | x = [e for e, m in zip(x, mask) if m > p] 168 | return x 169 | 170 | 171 | def replaceword(x, min_random, max_random, p=0.1): 172 | mask = np.random.rand(len(x)) 173 | x = [e if m > p else random.randint(min_random, max_random) for e, m in zip(x, mask)] 174 | return x 175 | 176 | 177 | def maskword(x, mask_id, p=0.1): 178 | mask = np.random.rand(len(x)) 179 | x = [e if m > p else mask_id for e, m in zip(x, mask)] 180 | return x 181 | 182 | 183 | def shuffleword(x, p=0.1): 184 | count = (np.random.rand(len(x)) < p).sum() 185 | """Shuffles any n number of values in a list""" 186 | indices_to_shuffle = random.sample(range(len(x)), k=count) 187 | to_shuffle = [x[i] for i in indices_to_shuffle] 188 | random.shuffle(to_shuffle) 189 | for index, value in enumerate(to_shuffle): 190 | old_index = indices_to_shuffle[index] 191 | x[old_index] = value 192 | return x 193 | 194 | 195 | def apply_augmentation(x, opt): 196 | if opt.augmentation == "mask": 197 | return torch.tensor(maskword(x, mask_id=opt.mask_id, p=opt.prob_augmentation)) 198 | elif opt.augmentation == "replace": 199 | return torch.tensor( 200 | replaceword(x, min_random=opt.start_id, max_random=opt.vocab_size - 1, p=opt.prob_augmentation) 201 | ) 202 | elif opt.augmentation == "delete": 203 | return torch.tensor(deleteword(x, p=opt.prob_augmentation)) 204 | elif opt.augmentation == "shuffle": 205 | return torch.tensor(shuffleword(x, p=opt.prob_augmentation)) 206 | else: 207 | if not isinstance(x, torch.Tensor): 208 | x = torch.Tensor(x) 209 | return x 210 | 211 | 212 | def add_bos_eos(x, bos_token_id, eos_token_id): 213 | if not isinstance(x, 
torch.Tensor): 214 | x = torch.Tensor(x) 215 | if bos_token_id is None and eos_token_id is not None: 216 | x = torch.cat([x.clone().detach(), torch.tensor([eos_token_id])]) 217 | elif bos_token_id is not None and eos_token_id is None: 218 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach()]) 219 | elif bos_token_id is None and eos_token_id is None: 220 | pass 221 | else: 222 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach(), torch.tensor([eos_token_id])]) 223 | return x 224 | 225 | 226 | # Used for passage retrieval 227 | def load_passages(path): 228 | if not os.path.exists(path): 229 | logger.info(f"{path} does not exist") 230 | return 231 | logger.info(f"Loading passages from: {path}") 232 | passages = [] 233 | with open(path) as fin: 234 | if path.endswith(".jsonl"): 235 | for k, line in enumerate(fin): 236 | ex = json.loads(line) 237 | passages.append(ex) 238 | else: 239 | reader = csv.reader(fin, delimiter="\t") 240 | for k, row in enumerate(reader): 241 | if not row[0] == "id": 242 | ex = {"id": row[0], "title": row[2], "text": row[1]} 243 | passages.append(ex) 244 | return passages 245 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from tqdm import tqdm 4 | from vllm import LLM, SamplingParams 5 | import argparse 6 | import os 7 | import gc 8 | import torch 9 | from vllm.distributed.parallel_state import destroy_model_parallel 10 | 11 | parser = argparse.ArgumentParser(description="Process JSONL data using an LLM model.") 12 | parser.add_argument("--template", type=str, required=True, help="Path to the chat template file.") 13 | parser.add_argument("--llm-model", type=str, required=True, help="Path to the LLM model checkpoint.") 14 | parser.add_argument("--input-jsonl", type=str, required=True, help="Path to the input JSONL file.") 15 | parser.add_argument("--output-json", type=str, required=True, help="Path to the output JSON file.") 16 | parser.add_argument("--remain-output-json", type=str, required=True, help="Path to the remain output JSONL file.") 17 | args = parser.parse_args() 18 | 19 | 20 | 21 | def get_prompt_docs(query, docs): 22 | retrieved_content = "Retrieved Content:\n" + "\n".join( 23 | [ 24 | f"{i+1}. Topic: {doc['title']}\nContent: {doc['text']}" 25 | for i, doc in enumerate(docs) 26 | ] 27 | ) 28 | return [ 29 | {'role': 'system', 30 | 'content': ("You are an expert at dynamically generating document identifiers to answer a given query.\n" 31 | "I will provide you with a set of documents, each uniquely identified by a number within square brackets, e.g., [1], [2], etc.\n" 32 | f"Your task is to identify and generate only the identifiers of the documents that contain sufficient information to answer the query.\n" 33 | "Stop generating identifiers as soon as the selected documents collectively provide enough information to answer the query.\n" 34 | "If no documents are required to answer the query, output \"None\".\n" 35 | "Output the identifiers as a comma-separated list, e.g., [1], [2] or \"None\" if no documents are needed.\n" 36 | "Focus solely on providing the identifiers. 
Do not include any explanations, descriptions, or additional text.")}, 37 | {'role': 'user', 38 | 'content': f"Query: {query}\n{retrieved_content}"} 39 | ] 40 | 41 | 42 | def get_prompt_answer(input_file, prompt, entry): 43 | system_prompt = """You are an intelligent assistant that uses retrieved knowledge to answer user queries accurately and concisely. Follow these rules: 44 | 1. **Task**: 45 | - Use the provided `[Retrieved Content]` to generate responses. 46 | - If the Retrieved Content is None, you should generate answer based on your own knowledge. 47 | - If the information is insufficient or you don't know the answer, state, “I cannot fully answer based on the available information. Please provide more details.” 48 | 2. **Requirements**: 49 | - **Accuracy**: Base your answers on the retrieved content. 50 | - **Conciseness**: Keep answers brief and relevant. 51 | - **Context Awareness**: Ensure your responses align with the user’s query. 52 | 3. **Input Format**: 53 | - Query: `[User Query]` 54 | - Retrieved: `[Retrieved Content]` 55 | 4. **Output Format**: 56 | - A structured, clear response tailored to the query. 57 | Always prioritize clarity and reliability.""" 58 | 59 | if 'fever' in input_file: 60 | remember_prompt = 'Please answer the question with “SUPPORTS”, “REFUTES” or “NEI” based on what you know.' 61 | elif 'nq' in input_file: 62 | remember_prompt = 'Please answer the question with a short phrase.' 63 | elif 'hotpotqa' in input_file: 64 | remember_prompt = 'Please answer the question with a short phrase.' 65 | elif 'eli5' in input_file: 66 | remember_prompt = 'Please answer the question with a paragraph.' 67 | elif 'triviaqa' in input_file: 68 | remember_prompt = 'Please answer the question with a short phrase.' 69 | elif '2wikimqa' in input_file: 70 | remember_prompt = 'Please answer the question with a short phrase.' 71 | elif 'arc' in input_file: 72 | remember_prompt = 'Please answer the following questions and directly output the answer options.' 73 | elif 'asqa' in input_file: 74 | remember_prompt = 'Please answer the question with a short phrase.' 75 | elif 'popqa' in input_file: 76 | remember_prompt = 'Please answer the question with a short phrase.' 77 | else: 78 | remember_prompt = 'Please answer the question with a short phrase.' 79 | 80 | return [ 81 | {'role': 'system', 82 | 'content': system_prompt}, 83 | {'role': 'user', 84 | 'content': (f"{remember_prompt}\n\n" 85 | f"Query: {entry['question']}\n\n" 86 | f"{prompt}")} 87 | ] 88 | 89 | 90 | def generate_docs(llm, entry): 91 | messages = get_prompt_docs(entry['question'], entry['ctxs']) 92 | 93 | outputs = llm.chat(messages, 94 | sampling_params=sampling_params, 95 | use_tqdm=False, 96 | chat_template=chat_template) 97 | 98 | return [output.text for output in outputs[0].outputs] 99 | 100 | 101 | def map_function(data_list): 102 | result = [] 103 | for item in data_list: 104 | ids = [int(num) for num in re.findall(r'\[(\d+)\]', item)] 105 | result.append(ids) 106 | return result 107 | 108 | 109 | def convert(id_lists, docs): 110 | prompts = [] 111 | for ids in id_lists: 112 | lines = [] 113 | if len(ids) == 0: 114 | retrieved_content = "Retrieved Content: None" 115 | prompts.append(retrieved_content) 116 | continue 117 | for doc_id in ids: 118 | doc_title = docs[doc_id - 1]['title'] 119 | doc_text = docs[doc_id - 1]['text'] 120 | line = f"{doc_id}. 
Title: {doc_title}.\nContent: {doc_text}" 121 | lines.append(line) 122 | 123 | retrieved_content = "\n".join(lines) 124 | prompt = f"Retrieved Content:\n{retrieved_content}" 125 | prompts.append(prompt) 126 | return prompts 127 | 128 | 129 | def generate_answers(input_file, llm, entry, prompts): 130 | results = [] 131 | for prompt in prompts: 132 | messages = get_prompt_answer(input_file, prompt, entry) 133 | 134 | outputs = llm.chat(messages, 135 | sampling_params=sampling_params_generate, 136 | use_tqdm=False, 137 | chat_template=chat_template) 138 | results.append(outputs[0].outputs[0].text) 139 | return results 140 | 141 | def process_jsonl(llm, input_file, output_file, remain_output_file): 142 | """ 143 | Main processing pipeline to read input JSONL, process data, and write output JSON and remain JSONL. 144 | """ 145 | result = [] 146 | result_remain = [] 147 | 148 | with open(input_file, 'r', encoding='utf-8') as infile: 149 | datas = [json.loads(line.strip()) for line in infile] 150 | 151 | for entry in tqdm(datas, total=len(datas), desc="Processing"): 152 | if 'asqa' in input_file: 153 | entry['ctxs'] = entry['docs'] 154 | # for 4k window, entry['ctxs'][:20] 155 | # for 8k window, entry['ctxs'][:40] 156 | entry['ctxs'] = entry['ctxs'][:40] 157 | 158 | 159 | ranking_string = generate_docs(llm, entry) 160 | id_list = map_function(ranking_string) 161 | 162 | try: 163 | prompt = convert(id_list, entry['ctxs']) 164 | except IndexError: 165 | result_remain.append(entry) 166 | continue 167 | 168 | answer = generate_answers(input_file, llm, entry, prompt) 169 | entry['response'] = answer 170 | entry['id_list'] = id_list 171 | result.append(entry) 172 | 173 | output_dir = os.path.dirname(output_file) 174 | remain_dir = os.path.dirname(remain_output_file) 175 | if output_dir: 176 | os.makedirs(output_dir, exist_ok=True) 177 | if remain_dir: 178 | os.makedirs(remain_dir, exist_ok=True) 179 | 180 | with open(output_file, 'w', encoding='utf-8') as outfile: 181 | json.dump(result, outfile, indent=4) 182 | 183 | 184 | with open(remain_output_file, 'w', encoding='utf-8') as outfile: 185 | json.dump(result_remain, outfile, indent=4) 186 | 187 | 188 | 189 | if __name__ == "__main__": 190 | 191 | with open(args.template, "r") as f: 192 | chat_template = f.read() 193 | 194 | llm = LLM(model=args.llm_model, tensor_parallel_size=8) 195 | sampling_params = SamplingParams(temperature=0.4, max_tokens=100) 196 | sampling_params_generate = SamplingParams(temperature=0, max_tokens=200) 197 | process_jsonl(llm, args.input_jsonl, args.output_json, args.remain_output_json) 198 | destroy_model_parallel() 199 | del llm 200 | gc.collect() 201 | torch.cuda.empty_cache() 202 | torch.distributed.destroy_process_group() 203 | print("Successfully delete the llm pipeline and free the GPU memory!") 204 | 205 | 206 | -------------------------------------------------------------------------------- /top_inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from tqdm import tqdm 4 | from vllm import LLM, SamplingParams 5 | import argparse 6 | import os 7 | import gc 8 | import torch 9 | from vllm.distributed.parallel_state import destroy_model_parallel 10 | 11 | def get_prompt_docs(query, docs): 12 | retrieved_content = "Retrieved Content:\n" + "\n".join( 13 | [ 14 | f"{i+1}. 
Topic: {doc['title']}\nContent: {doc['text']}" 15 | for i, doc in enumerate(docs) 16 | ] 17 | ) 18 | return [ 19 | {'role': 'system', 20 | 'content': ("You are an expert at dynamically generating document identifiers to answer a given query.\n" 21 | "I will provide you with a set of documents, each uniquely identified by a number within square brackets, e.g., [1], [2], etc.\n" 22 | f"Your task is to identify and generate only the identifiers of the documents that contain sufficient information to answer the query.\n" 23 | "Stop generating identifiers as soon as the selected documents collectively provide enough information to answer the query.\n" 24 | "If no documents are required to answer the query, output \"None\".\n" 25 | "Output the identifiers as a comma-separated list, e.g., [1], [2] or \"None\" if no documents are needed.\n" 26 | "Focus solely on providing the identifiers. Do not include any explanations, descriptions, or additional text.")}, 27 | {'role': 'user', 28 | 'content': f"Query: {query}\n{retrieved_content}"} 29 | ] 30 | 31 | def map_function(data_list): 32 | result = [] 33 | for item in data_list: 34 | ids = [int(num) for num in re.findall(r'\[(\d+)\]', item)] 35 | result.append(ids) 36 | return result 37 | 38 | def convert(id_lists, docs): 39 | prompts = [] 40 | for ids in id_lists: 41 | lines = [] 42 | if len(ids) == 0: 43 | retrieved_content = "Retrieved Content: None" 44 | prompts.append(retrieved_content) 45 | continue 46 | for doc_id in ids: 47 | doc_title = docs[doc_id - 1]['title'] 48 | doc_text = docs[doc_id - 1]['text'] 49 | line = f"{doc_id}. Title: {doc_title}.\nContent: {doc_text}" 50 | lines.append(line) 51 | 52 | retrieved_content = "\n".join(lines) 53 | prompt = f"Retrieved Content:\n{retrieved_content}" 54 | prompts.append(prompt) 55 | return prompts 56 | 57 | def generate_docs(llm, query, docs, sampling_params, chat_template): 58 | messages = get_prompt_docs(query, docs) 59 | outputs = llm.chat(messages, 60 | sampling_params=sampling_params, 61 | use_tqdm=False, 62 | chat_template=chat_template) 63 | return [output.text for output in outputs[0].outputs] 64 | 65 | def generate_final_docs(llm, query, all_docs, max_context_window, sampling_params, chat_template): 66 | """ 67 | Process all documents in batches, reduce to a manageable size (<= max_context_window), and return the final set of docs. 68 | """ 69 | current_docs = [] 70 | for i in range(0, len(all_docs), max_context_window): 71 | batch_docs = all_docs[i:i + max_context_window] 72 | ranking_strings = generate_docs(llm, query, batch_docs, sampling_params, chat_template) 73 | id_list = map_function(ranking_strings) 74 | for ids in id_list: 75 | for doc_id in ids: 76 | if all_docs[doc_id - 1] not in current_docs: 77 | current_docs.append(all_docs[doc_id - 1]) 78 | 79 | while len(current_docs) > max_context_window: 80 | reduced_docs = [] 81 | for i in range(0, len(current_docs), max_context_window): 82 | batch_docs = current_docs[i:i + max_context_window] 83 | ranking_strings = generate_docs(llm, query, batch_docs, sampling_params, chat_template) 84 | id_list = map_function(ranking_strings) 85 | 86 | for ids in id_list: 87 | for doc_id in ids: 88 | if doc_id - 1 < len(batch_docs): 89 | reduced_docs.append(batch_docs[doc_id - 1]) 90 | 91 | current_docs = list(set(reduced_docs)) 92 | 93 | if len(current_docs) > max_context_window and len(reduced_docs) == 0: 94 | raise ValueError("Failed to reduce document count below the context window. 
Check the LLM output.") 95 | 96 | return current_docs 97 | 98 | def generate_answers(llm, query, final_docs, sampling_params_generate, chat_template): 99 | prompt = convert([[i + 1 for i in range(len(final_docs))]], final_docs)[0] 100 | system_prompt = """You are an intelligent assistant that uses retrieved knowledge to answer user queries accurately and concisely. Follow these rules: 101 | 1. **Task**: 102 | - Use the provided `[Retrieved Content]` to generate responses. 103 | - If the Retrieved Content is None, you should generate answer based on your own knowledge. 104 | - If the information is insufficient or you don't know the answer, state, “I cannot fully answer based on the available information. Please provide more details.” 105 | 2. **Requirements**: 106 | - **Accuracy**: Base your answers on the retrieved content. 107 | - **Conciseness**: Keep answers brief and relevant. 108 | - **Context Awareness**: Ensure your responses align with the user’s query. 109 | 3. **Input Format**: 110 | - Query: `[User Query]` 111 | - Retrieved: `[Retrieved Content]` 112 | 4. **Output Format**: 113 | - A structured, clear response tailored to the query. 114 | Always prioritize clarity and reliability.""" 115 | remember_prompt = 'Please answer the question with a short phrase.' 116 | messages = [ 117 | {'role': 'system', 118 | 'content': system_prompt}, 119 | {'role': 'user', 120 | 'content': (f"{remember_prompt}\n\n" 121 | f"Query: {query}\n\n" 122 | f"{prompt}")} 123 | ] 124 | outputs = llm.chat(messages, 125 | sampling_params=sampling_params_generate, 126 | use_tqdm=False, 127 | chat_template=chat_template) 128 | return outputs[0].outputs[0].text 129 | 130 | def process_jsonl(llm, input_file, output_file, remain_output_file, max_context_window): 131 | result = [] 132 | result_remain = [] 133 | with open(input_file, 'r', encoding='utf-8') as infile: 134 | datas = [json.loads(line.strip()) for line in infile] 135 | 136 | for entry in tqdm(datas, total=len(datas), desc="Processing"): 137 | try: 138 | entry['ctxs'] = entry['ctxs'][:args.topn] 139 | final_docs = generate_final_docs(llm, entry['question'], entry['ctxs'], max_context_window, sampling_params, chat_template) 140 | answer = generate_answers(llm, entry['question'], final_docs, sampling_params_generate, chat_template) 141 | entry['response'] = answer 142 | entry['final_docs'] = final_docs 143 | result.append(entry) 144 | except Exception as e: 145 | entry['error'] = str(e) 146 | result_remain.append(entry) 147 | 148 | with open(output_file, 'w', encoding='utf-8') as outfile: 149 | json.dump(result, outfile, indent=4) 150 | 151 | with open(remain_output_file, 'w', encoding='utf-8') as outfile: 152 | json.dump(result_remain, outfile, indent=4) 153 | 154 | if __name__ == "__main__": 155 | parser = argparse.ArgumentParser(description="Process JSONL data using an LLM model.") 156 | parser.add_argument("--template", type=str, required=True, help="Path to the chat template file.") 157 | parser.add_argument("--llm-model", type=str, required=True, help="Path to the LLM model checkpoint.") 158 | parser.add_argument("--input-jsonl", type=str, required=True, help="Path to the input JSONL file.") 159 | parser.add_argument("--output-json", type=str, required=True, help="Path to the output JSON file.") 160 | parser.add_argument("--remain-output-json", type=str, required=True, help="Path to the remain output JSONL file.") 161 | parser.add_argument("--max-context-window", type=int, default=40, help="Maximum number of documents per context window.") 162 | 
parser.add_argument("--topn", type=int, default=300, help="Maximum number of documents per context window.") 163 | args = parser.parse_args() 164 | 165 | with open(args.template, "r") as f: 166 | chat_template = f.read() 167 | 168 | llm = LLM(model=args.llm_model, tensor_parallel_size=8) 169 | sampling_params = SamplingParams(temperature=0.4, max_tokens=100) 170 | sampling_params_generate = SamplingParams(temperature=0, max_tokens=200) 171 | 172 | process_jsonl(llm, args.input_jsonl, args.output_json, args.remain_output_json, args.max_context_window) 173 | 174 | destroy_model_parallel() 175 | del llm 176 | gc.collect() 177 | torch.cuda.empty_cache() 178 | torch.distributed.destroy_process_group() 179 | print("Successfully deleted the LLM pipeline and freed GPU memory!") 180 | -------------------------------------------------------------------------------- /retriever.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import argparse 9 | import json 10 | import pickle 11 | import time 12 | import glob 13 | from pathlib import Path 14 | from tqdm import tqdm 15 | import numpy as np 16 | import torch 17 | import transformers 18 | 19 | import src.index 20 | import src.contriever 21 | import src.utils 22 | import src.slurm 23 | import src.data 24 | from src.evaluation import calculate_matches 25 | import src.normalize_text 26 | 27 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 28 | 29 | 30 | class Retriever: 31 | def __init__(self, args, model=None, tokenizer=None) : 32 | self.args = args 33 | self.model = model 34 | self.tokenizer = tokenizer 35 | 36 | def embed_queries(self, args, queries): 37 | embeddings, batch_question = [], [] 38 | with torch.no_grad(): 39 | for k, q in enumerate(queries): 40 | if args.lowercase: 41 | q = q.lower() 42 | if args.normalize_text: 43 | q = src.normalize_text.normalize(q) 44 | batch_question.append(q) 45 | 46 | if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1: 47 | 48 | encoded_batch = self.tokenizer.batch_encode_plus( 49 | batch_question, 50 | return_tensors="pt", 51 | max_length=args.question_maxlength, 52 | padding=True, 53 | truncation=True, 54 | ) 55 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()} 56 | output = self.model(**encoded_batch) 57 | embeddings.append(output.cpu()) 58 | 59 | batch_question = [] 60 | 61 | embeddings = torch.cat(embeddings, dim=0) 62 | print(f"Questions embeddings shape: {embeddings.size()}") 63 | 64 | return embeddings.numpy() 65 | 66 | 67 | def embed_queries_demo(self, queries): 68 | embeddings, batch_question = [], [] 69 | with torch.no_grad(): 70 | for k, q in enumerate(queries): 71 | batch_question.append(q) 72 | 73 | if len(batch_question) == 16 or k == len(queries) - 1: 74 | 75 | encoded_batch = self.tokenizer.batch_encode_plus( 76 | batch_question, 77 | return_tensors="pt", 78 | max_length=200, 79 | padding=True, 80 | truncation=True, 81 | ) 82 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()} 83 | output = self.model(**encoded_batch) 84 | embeddings.append(output.cpu()) 85 | 86 | batch_question = [] 87 | 88 | embeddings = torch.cat(embeddings, dim=0) 89 | print(f"Questions embeddings shape: {embeddings.size()}") 90 | 91 | return embeddings.numpy() 92 | 93 | def index_encoded_data(self, index, 
embedding_files, indexing_batch_size): 94 | allids = [] 95 | allembeddings = np.array([]) 96 | for i, file_path in enumerate(embedding_files): 97 | print(f"Loading file {file_path}") 98 | with open(file_path, "rb") as fin: 99 | ids, embeddings = pickle.load(fin) 100 | 101 | allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings 102 | allids.extend(ids) 103 | while allembeddings.shape[0] > indexing_batch_size: 104 | allembeddings, allids = self.add_embeddings(index, allembeddings, allids, indexing_batch_size) 105 | 106 | while allembeddings.shape[0] > 0: 107 | allembeddings, allids = self.add_embeddings(index, allembeddings, allids, indexing_batch_size) 108 | 109 | print("Data indexing completed.") 110 | 111 | 112 | def add_embeddings(self, index, embeddings, ids, indexing_batch_size): 113 | end_idx = min(indexing_batch_size, embeddings.shape[0]) 114 | ids_toadd = ids[:end_idx] 115 | embeddings_toadd = embeddings[:end_idx] 116 | ids = ids[end_idx:] 117 | embeddings = embeddings[end_idx:] 118 | index.index_data(ids_toadd, embeddings_toadd) 119 | return embeddings, ids 120 | 121 | 122 | def add_passages(self, passages, top_passages_and_scores): 123 | # add passages to original data 124 | docs = [passages[doc_id] for doc_id in top_passages_and_scores[0][0]] 125 | return docs 126 | 127 | def setup_retriever(self): 128 | print(f"Loading model from: {self.args.model_name_or_path}") 129 | self.model, self.tokenizer, _ = src.contriever.load_retriever(self.args.model_name_or_path) 130 | self.model.eval() 131 | self.model = self.model.cuda() 132 | if not self.args.no_fp16: 133 | self.model = self.model.half() 134 | 135 | self.index = src.index.Indexer(self.args.projection_size, self.args.n_subquantizers, self.args.n_bits) 136 | 137 | # index all passages 138 | input_paths = glob.glob(self.args.passages_embeddings) 139 | input_paths = sorted(input_paths) 140 | embeddings_dir = os.path.dirname(input_paths[0]) 141 | index_path = os.path.join(embeddings_dir, "index.faiss") 142 | if self.args.save_or_load_index and os.path.exists(index_path): 143 | self.index.deserialize_from(embeddings_dir) 144 | else: 145 | print(f"Indexing passages from files {input_paths}") 146 | start_time_indexing = time.time() 147 | self.index_encoded_data(self.index, input_paths, self.args.indexing_batch_size) 148 | print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.") 149 | if self.args.save_or_load_index: 150 | self.index.serialize(embeddings_dir) 151 | 152 | # load passages 153 | print("loading passages") 154 | self.passages = src.data.load_passages(self.args.passages) 155 | self.passage_id_map = {x["id"]: x for x in self.passages} 156 | print("passages have been loaded") 157 | 158 | def search_document(self, query, top_n=10): 159 | questions_embedding = self.embed_queries(self.args, [query]) 160 | 161 | # get top k results 162 | start_time_retrieval = time.time() 163 | top_ids_and_scores = self.index.search_knn(questions_embedding, self.args.n_docs) 164 | print(f"Search time: {time.time()-start_time_retrieval:.1f} s.") 165 | 166 | return self.add_passages(self.passage_id_map, top_ids_and_scores)[:top_n] 167 | 168 | def search_document_demo(self, query, n_docs=10): 169 | questions_embedding = self.embed_queries_demo([query]) 170 | 171 | # get top k results 172 | start_time_retrieval = time.time() 173 | top_ids_and_scores = self.index.search_knn(questions_embedding, n_docs) 174 | print(f"Search time: {time.time()-start_time_retrieval:.1f} s.") 175 | 176 | return 
self.add_passages(self.passage_id_map, top_ids_and_scores)[:n_docs] 177 | 178 | def setup_retriever_demo(self, model_name_or_path, passages, passages_embeddings, n_docs=5, save_or_load_index=False): 179 | print(f"Loading model from: {model_name_or_path}") 180 | self.model, self.tokenizer, _ = src.contriever.load_retriever(model_name_or_path) 181 | self.model.eval() 182 | self.model = self.model.cuda() 183 | 184 | self.index = src.index.Indexer(768, 0, 8) 185 | 186 | # index all passages 187 | input_paths = glob.glob(passages_embeddings) 188 | input_paths = sorted(input_paths) 189 | embeddings_dir = os.path.dirname(input_paths[0]) 190 | index_path = os.path.join(embeddings_dir, "index.faiss") 191 | if save_or_load_index and os.path.exists(index_path): 192 | self.index.deserialize_from(embeddings_dir) 193 | else: 194 | print(f"Indexing passages from files {input_paths}") 195 | start_time_indexing = time.time() 196 | self.index_encoded_data(self.index, input_paths, 1000000) 197 | print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.") 198 | 199 | # load passages 200 | print("loading passages") 201 | self.passages = src.data.load_passages(passages) 202 | self.passage_id_map = {x["id"]: x for x in self.passages} 203 | print("passages have been loaded") 204 | 205 | def add_hasanswer(data, hasanswer): 206 | # add hasanswer to data 207 | for i, ex in enumerate(data): 208 | for k, d in enumerate(ex["ctxs"]): 209 | d["hasanswer"] = hasanswer[i][k] 210 | 211 | 212 | def load_data(data_path): 213 | if data_path.endswith(".json"): 214 | with open(data_path, "r") as fin: 215 | data = json.load(fin) 216 | elif data_path.endswith(".jsonl"): 217 | data = [] 218 | with open(data_path, "r") as fin: 219 | for k, example in enumerate(fin): 220 | example = json.loads(example) 221 | data.append(example) 222 | return data 223 | 224 | 225 | def main(args): 226 | data = load_data(args.query) 227 | 228 | retriever = Retriever(args) 229 | retriever.setup_retriever() 230 | 231 | # 3. 
Run retrieval for each question-answer item
232 |     with open(args.output_dir, "a", encoding="utf-8") as fout:
233 |         for item in tqdm(data):
234 |             question = item["retrieved_question"]
235 |             docs = retriever.search_document(question, args.n_docs)
236 |             item["ctxs"] = docs
237 | 
238 |             fout.write(json.dumps(item, ensure_ascii=False) + "\n")
239 | 
240 | 
241 | if __name__ == "__main__":
242 |     parser = argparse.ArgumentParser()
243 | 
244 |     parser.add_argument(
245 |         "--query",
246 |         type=str,
247 |         default=None,
248 |         help=".json file containing question and answers, similar format to reader data",
249 |     )
250 |     parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
251 |     parser.add_argument("--passages_embeddings", type=str, default=None, help="Glob path to encoded passages")
252 |     parser.add_argument(
253 |         "--output_dir", type=str, default=None, help="Results are written to output_dir with a data suffix"
254 |     )
255 |     parser.add_argument("--n_docs", type=int, default=100, help="Number of documents to retrieve per question")
256 |     parser.add_argument(
257 |         "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
258 |     )
259 |     parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
260 |     parser.add_argument(
261 |         "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
262 |     )
263 |     parser.add_argument(
264 |         "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
265 |     )
266 |     parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
267 |     parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
268 |     parser.add_argument(
269 |         "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
270 |     )
271 |     parser.add_argument("--projection_size", type=int, default=768)
272 |     parser.add_argument(
273 |         "--n_subquantizers",
274 |         type=int,
275 |         default=0,
276 |         help="Number of subquantizers used for vector quantization, if 0 flat index is used",
277 |     )
278 |     parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
279 |     parser.add_argument("--lang", nargs="+")
280 |     parser.add_argument("--dataset", type=str, default="none")
281 |     parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
282 |     parser.add_argument("--normalize_text", action="store_true", help="normalize text")
283 | 
284 |     args = parser.parse_args()
285 |     src.slurm.init_distributed_mode(args)
286 |     main(args)
-------------------------------------------------------------------------------- /readme.md: --------------------------------------------------------------------------------
  1 | # DynamicRAG: Leveraging Outputs of Large Language Model as Feedback for Dynamic Reranking in Retrieval-Augmented Generation
  2 | 
  3 |
  4 | [training_data](https://huggingface.co/datasets/gasolsun/DynamicRAG_Training_Data_150k) · [eval_data](https://huggingface.co/datasets/gasolsun/DynamicRAG-Eval-Data) · [model](https://huggingface.co/gasolsun/DynamicRAG-7B) · [model](https://huggingface.co/gasolsun/DynamicRAG-8B)
  5 |
 26 | 
 27 | **DynamicRAG** is an innovative framework for Retrieval-Augmented Generation (RAG) that dynamically adjusts both the **order** and **number** of retrieved documents per query. A reinforcement learning (RL) agent serves as the reranker, optimizing document retrieval based on feedback from a **Large Language Model (LLM)**. The training process is divided into two main stages:
 28 | 
 29 | 1. **Supervised Fine-Tuning (SFT) via Behavior Cloning**:
 30 |    - Trains the reranker with expert trajectories.
 31 |    - Simplifies the action space and establishes a baseline.
 32 | 2. **Reinforcement Learning (RL) with LLM Feedback**:
 33 |    - Uses interactive feedback from the generator.
 34 |    - Explores improved trajectories and further optimizes the reranker.
 35 | 
 36 | ## How to cite
 37 | If you extend or use this work, please cite the [paper](https://arxiv.org/abs/2505.07233) where it was introduced:
 38 | 
 39 | ```
 40 | @misc{sun2025dynamicragleveragingoutputslarge,
 41 |       title={DynamicRAG: Leveraging Outputs of Large Language Model as Feedback for Dynamic Reranking in Retrieval-Augmented Generation},
 42 |       author={Jiashuo Sun and Xianrui Zhong and Sizhe Zhou and Jiawei Han},
 43 |       year={2025},
 44 |       eprint={2505.07233},
 45 |       archivePrefix={arXiv},
 46 |       primaryClass={cs.CL},
 47 |       url={https://arxiv.org/abs/2505.07233},
 48 | }
 49 | ```
 50 | 
 51 | ## 🔥 Update
 52 | * [2025-09-18]: 🚀 Our paper is accepted by NeurIPS 2025! See you in San Diego! 🥳🥳🥳
 53 | * [2025-07-13]: 🚀 We release the training data of DynamicRAG: [DynamicRAG_Training_Data_150k](https://huggingface.co/datasets/gasolsun/DynamicRAG_Training_Data_150k)
 54 | * [2025-05-13]: 🚀 We release the paper: [https://arxiv.org/abs/2505.07233](https://arxiv.org/abs/2505.07233)
 55 | * [2025-05-07]: 🚀 We release [DynamicRAG-7B](https://huggingface.co/gasolsun/DynamicRAG-7B), [DynamicRAG-8B](https://huggingface.co/gasolsun/DynamicRAG-8B), and the [evaluation data](https://huggingface.co/datasets/gasolsun/DynamicRAG-Eval-Data).
 56 | * [2025-05-05]: 🚀 We release the code for training and evaluation.
 57 | 
 58 | 
 59 | 
 60 | ## Table of Contents
 61 | 
 62 | - [DynamicRAG Overview](#dynamicrag-overview)
 63 | - [Project Visualizations](#project-visualizations)
 64 | - [📌 Data Processing Pipeline](#-data-processing-pipeline)
 65 | - [🎯 Supervised Fine-Tuning (SFT) Training](#-supervised-fine-tuning-sft-training)
 66 | - [🤖 Interactive Data Collection](#-interactive-data-collection)
 67 | - [📈 Direct Preference Optimization (DPO) Training](#-direct-preference-optimization-dpo-training)
 68 | - [🔍 Inference and Evaluation](#-inference-and-evaluation)
 69 | - [📄 Licensing and Claims](#-licensing-and-claims)
 70 | 
 71 | ---
 72 | 
 73 | ## DynamicRAG Overview
 74 | 
 75 | DynamicRAG adjusts the retrieval process on-the-fly by:
 76 | - Dynamically reordering and selecting the number of documents per query.
 77 | - Leveraging a reranker trained with RL and LLM feedback to improve retrieval quality.
 78 | 
 79 | ---
 80 | 
 81 | ## 💡 Preliminaries
 82 | Install the environment with `pip install -r requirements.txt`, then run:
 83 | ```bash
 84 | apt-get update
 85 | apt-get install libtiff5
 86 | ```
 87 | You also need to configure the retrieval corpus, e.g. the official 2018 English Wikipedia passages and embeddings. We use exactly the same configuration as [Self-RAG](https://github.com/AkariAsai/self-rag); see their Retriever Setup instructions. A minimal download sketch for the expected data layout is shown below.
 88 |
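The sketch below shows one way to fetch the corpus into the paths that `run_retriever.sh` expects (`data/psgs_w100.tsv` and `data/wikipedia_embeddings/*`). The download URLs are the publicly documented DPR / Contriever mirrors used by the Self-RAG setup; treat them as assumptions and verify them against the Self-RAG Retriever Setup before running.

```bash
# Minimal corpus setup sketch (assumed mirrors; verify against Self-RAG's Retriever Setup).
# Target paths match run_retriever.sh in this repo.
mkdir -p data && cd data

# 2018 English Wikipedia passage split (psgs_w100.tsv), released with DPR
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gunzip psgs_w100.tsv.gz

# Pre-computed passage embeddings released with Contriever (contriever-msmarco)
wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar
tar -xf wikipedia_embeddings.tar   # extracts wikipedia_embeddings/*
```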
 89 | 
 90 | ## 📌 Data Processing Pipeline
 91 | Example: Training LLaMA3-8B with Top-40 Documents
 92 | 
 93 | ### **1. Prepare BC Data Pipeline**
 94 | #### **Step 1: Retrieve Top-40 Documents**
 95 | Run the retrieval script:
 96 | ```bash
 97 | #!/bin/bash
 98 | 
 99 | NUM_GPUS=8
100 | INPUT_FILE="data/rag_training_data.json"
101 | SPLIT_DIR="data/splits"
102 | 
103 | python split_data.py --input_file $INPUT_FILE --output_dir $SPLIT_DIR --num_splits $NUM_GPUS
104 | 
105 | for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
106 |     SPLIT_FILE="${SPLIT_DIR}/split_${GPU_ID}.json"
107 |     OUTPUT_FILE="output/retrieval_split_${GPU_ID}.json"
108 |     log_file="logs/retriever_split_${GPU_ID}.log"
109 |     CUDA_VISIBLE_DEVICES=$GPU_ID python retriever.py \
110 |         --model_name_or_path models/retriever \
111 |         --passages data/psgs_w100.tsv \
112 |         --passages_embeddings "data/wikipedia_embeddings/*" \
113 |         --query $SPLIT_FILE \
114 |         --output_dir $OUTPUT_FILE \
115 |         --n_docs 50 \
116 |         1>"$log_file" 2>&1 &
117 | 
118 |     echo "Started process on GPU $GPU_ID with input $SPLIT_FILE"
119 | done
120 | 
121 | wait
122 | echo "All processes completed."
123 | 
124 | ```
125 | 
126 | #### **Step 2: Aggregate Retrieved Data**
127 | ```bash
128 | python aggregate.py
129 | ```
130 | 
131 | #### **Step 3: Rerank Documents**
132 | ```bash
133 | python reranker.py --model_name_or_path models/reranker/monot5 \
134 |     --input_file output/retrieval_data.jsonl \
135 |     --output_file output/retrieval_data_rerank.jsonl \
136 |     --device cuda
137 | ```
138 | Outputs: `retrieval_data_rerank.jsonl`
139 | 
140 | > 💡 If the command above runs slowly, consider splitting the input across multiple GPUs (as done for the retriever) and then combining the results.
141 | 
142 | 
143 | #### **Step 4: Compute True/False in Reranking**
144 | ```bash
145 | python process_training_data.py
146 | ```
147 | Outputs:
148 | - `retrieval_data_rerank_sequence.json` (for Reranker BC training)
149 | - `retrieval_data_rerank_normal.json` (for SFT & DPO training)
150 | 
151 | #### **Step 5: Convert Reranker Data for Training**
152 | ```bash
153 | python reranker_sequence.py
154 | ```
155 | Output: `reranker_bc_data.json` (formatted for **LLaMA-Factory**)
156 | 
157 | #### **Step 6: Split SFT & DPO Data**
158 | ```bash
159 | python split_for_sft_dpo.py
160 | ```
161 | 
162 | #### **Step 7: Construct Generator SFT Data**
163 | ```bash
164 | python construct_generator_sft.py
165 | ```
166 | 
167 | ---
168 | 
169 | ## 🎯 Supervised Fine-Tuning (SFT) Training
170 | We use **LLaMA-Factory** as the training framework. Install it from [here](https://github.com/hiyouga/LLaMA-Factory).
171 | 
172 | ### **1. Configure `dataset_info.json`**
173 | Modify `LLaMA-Factory/data/dataset_info.json`:
174 | ```json
175 | {
176 |     "generator_sft": {
177 |         "file_name": "generator_sft_training.json",
178 |         "columns": {"prompt": "instruction", "query": "input", "response": "output", "system": "system"}
179 |     },
180 | 
181 |     "reranker_bc": {
182 |         "file_name": "reranker_bc_training.json",
183 |         "columns": {"prompt": "instruction", "query": "input", "response": "output", "system": "system"}
184 |     },
185 | 
186 |     "alpaca_data": {
187 |         "file_name": "alpaca_data_cleaned_system.json",
188 |         "columns": {"prompt": "instruction", "query": "input", "response": "output", "system": "system"}
189 |     }
190 | }
191 | ```
192 |
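For reference, the SFT files above follow the alpaca-style layout implied by the column mapping (`system`, `instruction`, `input`, `output`). The record below is purely illustrative (hypothetical values written in the style of this repo's generator prompts), not taken from the released training data:

```json
[
  {
    "system": "You are an intelligent assistant that uses retrieved knowledge to answer user queries accurately and concisely.",
    "instruction": "Please answer the question with a short phrase.",
    "input": "Query: who wrote the novel Moby-Dick\n\nRetrieved Content:\n1. Title: Moby-Dick.\nContent: Moby-Dick; or, The Whale is an 1851 novel by American writer Herman Melville ...",
    "output": "Herman Melville"
  }
]
```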
193 | ### **2. Train the Model**
194 | Modify `llama8b.yaml` and run:
195 | ```bash
196 | llamafactory-cli train examples/train_full/llama8b.yaml
197 | ```
198 | > 🛠️ Requires at least **8 A100-80G GPUs**.
199 | 
200 | ---
201 | 
202 | ## 🤖 Interactive Data Collection
203 | We use **vLLM** for faster sampling.
204 | 
205 | ### **1. Sample Interaction Trajectories**
206 | ```bash
207 | python sampling_dpo_trajectories.py \
208 |     --template template/llama3.jinja \
209 |     --llm-model DynamicRAG_llama3_8b \
210 |     --input-jsonl training_data/training_data_dpo.jsonl \
211 |     --output-json results/training_data_dpo_sampling.json
212 | 
213 | ```
214 | 
215 | ### **2. Collect Rewards for Trajectories**
216 | ```bash
217 | python reward_trajectories.py \
218 |     --input_file results/training_data_dpo_sampling.json \
219 |     --output_file training_data/llama3_8b_output_dpo.jsonl
220 | ```
221 | 
222 | ### **3. Construct DPO Training Data**
223 | ```bash
224 | python construct_dpo.py
225 | ```
226 | 
227 | ---
228 | 
229 | ## 📈 Direct Preference Optimization (DPO) Training
230 | 
231 | ### **1. Configure `dataset_info.json`**
232 | ```json
233 | {
234 |     "llama3_generator_dpo": {
235 |         "file_name": "llama3_8b_generator_dpo.json",
236 |         "ranking": true,
237 |         "columns": {"prompt": "instruction", "query": "input", "chosen": "chosen", "rejected": "rejected"}
238 |     },
239 | 
240 |     "llama3_reranker_dpo": {
241 |         "file_name": "llama3_8b_reranker_dpo.json",
242 |         "ranking": true,
243 |         "columns": {"prompt": "instruction", "query": "input", "chosen": "chosen", "rejected": "rejected"}
244 |     }
245 | }
246 | ```
247 | 
248 | ### **2. Train the Model**
249 | ```bash
250 | llamafactory-cli train examples/train_full/llama8b_dpo.yaml
251 | ```
252 | > 🛠️ Requires at least **8 A100-80G GPUs**.
253 | 
254 | ---
255 | 
256 | ## 🔍 Inference and Evaluation
257 | We use **vLLM** for efficient inference.
258 | 
259 | ### **1. Run Inference**
260 | ```bash
261 | #!/bin/bash
262 | 
263 | 
264 | LOG_DIR="eval_logs"
265 | mkdir -p $LOG_DIR
266 | 
267 | run_inference() {
268 |     local input_file=$1
269 |     local output_file=$2
270 |     local remain_output_file=$3
271 | 
272 |     echo "Running inference for $input_file..."
273 |     python inference.py \
274 |         --template template/llama3.jinja \
275 |         --llm-model DynamicRAG_llama3_8b \
276 |         --input-jsonl $input_file \
277 |         --output-json $output_file \
278 |         --remain-output-json $remain_output_file \
279 |         >> $LOG_DIR/$(basename $output_file .json)_log.txt 2>&1
280 | 
281 |     sleep 5
282 | }
283 | 
284 | 
285 | run_inference "eval_data/triviaqa.jsonl" \
286 |     "results/llama3_8b_triviaqa.json" \
287 |     "results/llama3_8b_triviaqa_remain.json"
288 | 
289 | run_inference "eval_data/nq.jsonl" \
290 |     "results/llama3_8b_nq.json" \
291 |     "results/llama3_8b_nq_remain.json"
292 | 
293 | run_inference "eval_data/hotpotqa.jsonl" \
294 |     "results/llama3_8b_hotpotqa.json" \
295 |     "results/llama3_8b_hotpotqa_remain.json"
296 | 
297 | run_inference "eval_data/2wikimqa.jsonl" \
298 |     "results/llama3_8b_2wikimqa.json" \
299 |     "results/llama3_8b_2wikimqa_remain.json"
300 | 
301 | run_inference "eval_data/fever.jsonl" \
302 |     "results/llama3_8b_fever.json" \
303 |     "results/llama3_8b_fever_remain.json"
304 | 
305 | run_inference "eval_data/eli5.jsonl" \
306 |     "results/llama3_8b_eli5.json" \
307 |     "results/llama3_8b_eli5_remain.json"
308 | 
309 | run_inference "eval_data/asqa_eval_gtr_top100.jsonl" \
310 |     "results/llama3_8b_asqa.json" \
311 |     "results/llama3_8b_asqa_remain.json"
312 | 
313 | echo "All tasks completed. Logs are available in $LOG_DIR."
314 | 
315 | ```
316 | Evaluates **7 different benchmarks**.
317 | 
318 | ### **2. 
Evaluate Performance** 319 | ```bash 320 | # install nltk, rouge_score, spacy 321 | # python -m spacy download en_core_web_sm 322 | 323 | # for example, when we evaluate nq 324 | python evaluate.py \ 325 | --results_file results/llama3_8b_nq.json \ 326 | --metric match 327 | ``` 328 | 329 | ### **3. Run DynamicRAG on 500+ Documents** 330 | ```bash 331 | #!/bin/bash 332 | 333 | TEMPLATE="template/llama3.jinja" 334 | LLM_MODEL="DynamicRAG_llama3_8b" 335 | INPUT_JSONL="eval_data/nq_top500.jsonl" 336 | MAX_CONTEXT_WINDOW=40 337 | 338 | TOPN_VALUES=(50 100 150 200 300 500) 339 | 340 | for TOPN in "${TOPN_VALUES[@]}"; do 341 | LOG_FILE="top_logs/llama3_8b_nq_top_${TOPN}.log" 342 | 343 | python top_inference.py \ 344 | --template "$TEMPLATE" \ 345 | --llm-model "$LLM_MODEL" \ 346 | --input-jsonl "$INPUT_JSONL" \ 347 | --output-json "results/llama3_8b_top_${TOPN}_nq.json" \ 348 | --remain-output-json "results/llama3_8b_top_${TOPN}_nq_remain.json" \ 349 | --max-context-window "$MAX_CONTEXT_WINDOW" \ 350 | --topn "$TOPN" >> "$LOG_FILE" 2>&1 351 | 352 | sleep 3 353 | done 354 | 355 | ``` 356 | 357 | --- 358 | 359 | 360 | 361 | ## Project Visualizations 362 | 363 | Explore the key components and performance of DynamicRAG through the following images: 364 | 365 | - **Introduction of DynamicRAG:** 366 | 367 |
368 | 
369 | *Figure: DynamicRAG Intro*
370 | 
371 |
372 | 373 | 374 | - **Pipeline of DynamicRAG:** 375 | 376 |
377 | 
378 | *Figure: DynamicRAG Pipeline*
379 | 
380 |
381 | 382 | - **Generator Experiment:** 383 | 384 |
385 | 
386 | *Figure: Generator Experiment*
387 | 
388 |
389 | 390 | 391 | - **Reranker Experiment:** 392 | 393 |
394 | 
395 | *Figure: Reranker Experiment*
396 | 
397 |
398 | 399 | 400 | 401 | - **Efficiency of DynamicRAG:** 402 | 403 |
404 | 
405 | *Figure: Efficiency*
406 | 
407 |
408 | 409 | - **Case Study:** 410 | 411 |
412 | 
413 | *Figure: Case Study 1*
414 | 
415 |
416 |
417 | 
418 | *Figure: Case Study 2*
419 | 
420 |
421 | 
422 | 
423 | ---
424 | 
425 | 
426 | ## 📄 Licensing and Claims
427 | This project is licensed under the Apache 2.0 License. The authors assume no legal responsibility for any output generated by the models and will not be held liable for any damages resulting from the use of the provided resources and outputs.
428 | 
429 | --------------------------------------------------------------------------------