├── README.md
├── data_generation.py
├── images
│   └── main_results.png
├── inference.py
├── my_config.yaml
├── pipeline
│   └── reasonrag_pipeline.py
├── preference_data_generatoin.py
└── training_config
    └── qwen_dpo.yaml

/README.md: --------------------------------------------------------------------------------

# ReasonRAG: Enhancing Agentic RAG with Process-Supervised Reinforcement Learning

If you like our project, please give us a star ⭐ on GitHub for the latest updates.
# 💡 Overview

Recent advancements in outcome-supervised Reinforcement Learning (RL), exemplified by OpenAI's o1 and DeepSeek's R1, have demonstrated remarkable improvements in large language model (LLM) reasoning capabilities. Integrating outcome-supervised RL with search engines presents another promising avenue for boosting LLM reasoning. However, outcome-supervised RL often grapples with challenges such as sparse rewards, training instability, and inefficient exploration.

To address these limitations, process-supervised RL emerges as a compelling solution for enhancing Agentic RAG, offering the advantage of fine-grained rewards. We introduce **ReasonRAG**, a process-supervised method designed to refine Agentic RAG's strategic preferences.

Our approach consists of three key steps:

1. We leverage Monte Carlo Tree Search (MCTS) to generate process-supervised rollouts, yielding rich data on process-level strategic preferences.
2. We then employ Direct Preference Optimization (DPO) to optimize these strategic preferences within the Agentic RAG framework.
3. Finally, we construct an Agentic RAG pipeline that enables the LLM to autonomously generate queries, extract evidence, and formulate answers.

We provide the dataset we constructed and links to our trained models below:

* **RAG_ProGuide Dataset:** [https://huggingface.co/datasets/reasonrag/RAG_ProGuide](https://huggingface.co/datasets/reasonrag/RAG_ProGuide)
* **Trained Models:** [Qwen2.5-7B-Instruct-ReasonRAG](https://huggingface.co/reasonrag/Qwen2.5-7B-Instruct-ReasonRAG)
* **Trained LoRA Models:** [Qwen2.5-7B-Instruct-RAG-Lora](https://huggingface.co/reasonrag/Qwen2.5-7B-Instruct-RAG-Lora)

ReasonRAG achieves superior performance on five benchmark datasets using only 5k training instances, significantly fewer than the 90k training instances required by Search-R1.

![Main Results](images/main_results.png)

# ✨ Method
We employ process-supervised RL to enhance Agentic RAG capabilities in three stages:
1. Process-supervised reward data generation
2. Policy preference optimization
3. Agentic RAG inference

## Data
We randomly sample data from PopQA, HotpotQA, and 2WikiMultihopQA to generate process-supervised preference data. We then use GPT-4o as the policy model to generate rollout data. The resulting process-supervised dataset, RAG-ProGuide, is available at [https://huggingface.co/datasets/reasonrag/RAG_ProGuide](https://huggingface.co/datasets/reasonrag/RAG_ProGuide).
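If you only want to inspect or reuse the released preference pairs rather than regenerate them, they can be pulled straight from the Hugging Face Hub. A minimal sketch, assuming the `datasets` library is installed and that the hosted dataset exposes a `train` split with the `prompt`/`chosen`/`rejected` fields produced by the preference-data script (the split and field names are assumptions, not confirmed by this README):

```python
# Minimal sketch: peek at the released RAG-ProGuide preference pairs.
# Assumptions: `pip install datasets`; a "train" split; prompt/chosen/rejected columns.
from datasets import load_dataset

ds = load_dataset("reasonrag/RAG_ProGuide", split="train")
example = ds[0]
print(example["prompt"][:300])    # system + user prompt shown to the policy model
print(example["chosen"][:300])    # preferred intermediate action (query/evidence/answer)
print(example["rejected"][:300])  # dispreferred alternative at the same tree node
```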
# 🏃 Quick Start
## Environment Settings
Set up the [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG) environment:
```bash
conda create --name reasonrag python=3.10.16
conda activate reasonrag
pip install flashrag-dev --pre
pip install "flashrag-dev[full]"
pip install "vllm>=0.4.1"
pip install deepspeed
```

## Data Preparation

Download the Wikipedia dump to use as the retrieval corpus, then build the dense index:

```bash
# Download the Wikipedia dump
wget https://archive.org/download/enwiki-20181220/enwiki-20181220-pages-articles.xml.bz2

# Build index
python -m flashrag.retriever.index_builder \
    --retrieval_method bge \
    --model_path BAAI/bge-base-en-v1.5 \
    --corpus_path indexes/wiki18.jsonl \
    --save_dir indexes/ \
    --use_fp16 \
    --max_length 512 \
    --batch_size 256 \
    --pooling_method mean \
    --faiss_type Flat
```

Download the QA datasets from Hugging Face: [RUC-NLPIR/FlashRAG_datasets](https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets)

## Data Generation
> Note: This step generates policy preference data. You can use the released RAG-ProGuide dataset directly (linked above), run this code to generate your own, or adapt it as needed.
```bash
python data_generation.py --dataset_name popqa --model gpt-4o
python data_generation.py --dataset_name hotpotqa --model gpt-4o
python data_generation.py --dataset_name 2wikimultihopqa --model gpt-4o
python preference_data_generation.py
```

## Training
```bash
# Install LLaMA Factory
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]"

# Set the dataset path before preference optimization
llamafactory-cli train training_config/qwen_dpo.yaml
```

## Inference
```bash
python inference.py --dataset_name hotpotqa --model $MODEL_NAME
```

-------------------------------------------------------------------------------- /data_generation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | from multiprocessing import Process 5 | import copy 6 | import os 7 | from tqdm import tqdm 8 | import yaml 9 | from flashrag.config import Config 10 | from flashrag.utils import get_dataset 11 | from pipeline.reasonrag_pipeline import ReasonRAGPipeline 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--dataset_name", type=str) 15 | parser.add_argument("--model", type=str) 16 | args = parser.parse_args() 17 | 18 | 19 | root_dir = 'output' 20 | 21 | def load_config_from_yaml(yaml_file): 22 | try: 23 | with open(yaml_file, "r") as file: 24 | return yaml.safe_load(file) 25 | except Exception as e: 26 | print(f"Error loading YAML file: {e}") 27 | return {} 28 | 29 | default_config = load_config_from_yaml("my_config.yaml") 30 | 31 | config_dict = { 32 | "data_dir": "dataset/", 33 | "dataset_name": args.dataset_name, 34 | "index_path": "indexes/bge_Flat_wiki_extend.index", 35 | "corpus_path": "indexes/wiki18_100w_extend.jsonl", 36 | "model2path": { 37 | "bge": "BAAI/bge-base-en-v1.5", 38 | "qwen2.5": "Qwen/Qwen2.5-7B-Instruct", 39 | }, 40 | "generator_model": "gpt-4o-2024-05-13", 41 | 
"retrieval_method": "bge", 42 | "framework": "openai", 43 | "gpu_id": "1", 44 | "faiss_gpu": True, 45 | "metrics": ["em", "f1", "acc", "recall", "precision"], 46 | "retrieval_topk": 3, 47 | "save_intermediate_data": True, 48 | "save_note": args.model + "_MCTS", 49 | } 50 | 51 | answer_format = "answer" 52 | max_iter = 8 53 | 54 | config_dict = {**default_config, **config_dict} 55 | config = Config(config_dict=config_dict) 56 | 57 | dataset_path = config["dataset_path"] 58 | split_path = os.path.join(dataset_path, "train.jsonl") 59 | data_split = "train" 60 | all_split = get_dataset(config) 61 | test_data = all_split[data_split] 62 | 63 | def save_data(save_dir, data, file_name="intermediate_data.json"): 64 | data = [item.to_dict() for item in data] 65 | save_path = os.path.join(save_dir, file_name) 66 | with open(save_path, "w", encoding="utf-8") as f: 67 | json.dump(data, f, indent=4) 68 | 69 | 70 | def parallel_process_dataset(config, test_data, num_processes=4): 71 | total_items = len(test_data) 72 | if total_items == 0: 73 | print("No data to process.") 74 | return 75 | 76 | chunk_size = total_items // num_processes 77 | remainder = total_items % num_processes 78 | 79 | chunks = [] 80 | start = 0 81 | for i in range(num_processes): 82 | end = start + chunk_size 83 | if i < remainder: 84 | end += 1 85 | if start >= total_items: 86 | break 87 | chunks.append(test_data[start:end]) 88 | start = end 89 | 90 | processes = [] 91 | for chunk_idx, chunk in enumerate(chunks): 92 | print(f"Chunk {chunk_idx} ready.") 93 | p = Process(target=process_chunk, args=(copy.copy(config), chunk, chunk_idx)) 94 | processes.append(p) 95 | p.start() 96 | 97 | for p in processes: 98 | p.join() 99 | 100 | 101 | def process_chunk(config, chunk, chunk_idx): 102 | save_dir = os.path.join(config["save_dir"], f"chunk_{chunk_idx}") 103 | os.makedirs(save_dir, exist_ok=True) 104 | config["save_dir"] = save_dir 105 | 106 | pipeline = ReasonRAGPipeline(config, prompt_template=None, max_iter=7, max_children=2, max_rollouts=64) 107 | 108 | i = 0 109 | start_time = time.time() 110 | 111 | for item in tqdm(chunk, desc=f"Chunk {chunk_idx}"): 112 | try: 113 | pipeline.search(item) 114 | except Exception as e: 115 | print(f"Chunk {chunk_idx} Error at item {i}: {e}") 116 | continue 117 | finally: 118 | i += 1 119 | 120 | save_data(save_dir, chunk, file_name=f"final_{chunk_idx}.json") 121 | print(f"Chunk {chunk_idx} processed {len(chunk)} items in {time.time() - start_time:.2f}s") 122 | 123 | 124 | parallel_process_dataset(config, test_data, num_processes=10) 125 | -------------------------------------------------------------------------------- /images/main_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlzhang2020/ReasonRAG/225e249a26fa0d1d2fd3c9bf60855d725b7cd435/images/main_results.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import yaml 4 | from flashrag.config import Config 5 | from flashrag.utils import get_dataset 6 | from pipeline.reasonrag_pipeline import ReasonRAGPipeline 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--dataset_name", type=str) 10 | parser.add_argument("--model", type=str) 11 | parser.add_argument("--max_iter", default=8, type=int) 12 | parser.add_argument("--retrieval_top_k", default=3, type=int) 13 | args = parser.parse_args() 14 
| 15 | root_dir = 'output' 16 | 17 | def load_config_from_yaml(yaml_file): 18 | try: 19 | with open(yaml_file, "r") as file: 20 | return yaml.safe_load(file) 21 | except Exception as e: 22 | print(f"Error loading YAML file: {e}") 23 | return {} 24 | 25 | default_config = load_config_from_yaml("my_config.yaml") 26 | 27 | config_dict = { 28 | "data_dir": "dataset/", 29 | "dataset_name": args.dataset_name, 30 | "index_path": "indexes/bge_Flat.index", 31 | "retrieval_method": "bge", 32 | "corpus_path": "indexes/wiki18_100w.jsonl", 33 | "model2path": { 34 | "bge": "BAAI/bge-base-en-v1.5", 35 | "e5": "intfloat/e5-base-v2", 36 | "qwen2.5": "Qwen/Qwen2.5-7B", 37 | "qwen2.5-instruct": "Qwen/Qwen2.5-7B-Instruct", 38 | }, 39 | "generator_model": args.model, 40 | "generator_batch_size": 1, 41 | "framework": "vllm", 42 | "gpu_id": "0, 1, 2, 3", 43 | "faiss_gpu": True, 44 | "retrieval_batch_size": 256, 45 | "gpu_memory_utilization": 0.5, 46 | "metrics": ["em", "f1", "acc", "recall", "precision"], 47 | "retrieval_topk": args.retrieval_top_k, 48 | "save_intermediate_data": True, 49 | "save_note": args.model + f"_iter{args.max_iter}", 50 | } 51 | 52 | answer_format = "answer" 53 | max_iter = 10 54 | 55 | config_dict = {**default_config, **config_dict} 56 | config = Config(config_dict=config_dict) 57 | 58 | dataset_path = config["dataset_path"] 59 | split_path = os.path.join(dataset_path, "test.jsonl") 60 | data_split = "test" 61 | if not os.path.exists(split_path): 62 | if os.path.exists(os.path.join(dataset_path, "dev.jsonl")): 63 | data_split = "dev" 64 | elif os.path.exists(os.path.join(dataset_path, "val.jsonl")): 65 | data_split = "val" 66 | else: 67 | data_split = "None" 68 | 69 | all_split = get_dataset(config) 70 | test_data = all_split[data_split] 71 | 72 | pipeline = ReasonRAGPipeline(config, prompt_template=None, answer_format=answer_format, max_iter=args.max_iter, max_children=2, max_rollouts=64) 73 | output_dataset = pipeline.run(test_data, batch_size=1000, do_eval=True) 74 | -------------------------------------------------------------------------------- /my_config.yaml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------Global Paths------------------------------------------------# 2 | # Paths to models 3 | model2path: 4 | e5: "intfloat/e5-base-v2" 5 | bge: "BAAI/bge-base-en-v1.5" 6 | contriever: "facebook/contriever" 7 | llama2-7B-chat: "meta-llama/Llama-2-7b-chat-hf" 8 | llama2-7B: "meta-llama/Llama-2-7b-hf" 9 | llama2-13B: "meta-llama/Llama-2-13b-hf" 10 | llama2-13B-chat: "meta-llama/Llama-2-13b-chat-hf" 11 | llama3-8B-instruct: "meta-llama/Meta-Llama-3-8B-Instruct" 12 | qwen2.5-7b-instruct: "Qwen/Qwen2.5-7B-Instruct" 13 | qwen2.5-3b-instruct: "Qwen/Qwen2.5-3B-Instruct" 14 | 15 | 16 | # Pooling methods for each embedding model 17 | model2pooling: 18 | e5: "mean" 19 | bge: "cls" 20 | contriever: "mean" 21 | jina: 'mean' 22 | dpr: cls 23 | 24 | # Indexes path for retrieval models 25 | method2index: 26 | e5: ~ 27 | bm25: ~ 28 | bge: ~ 29 | contriever: ~ 30 | 31 | # ------------------------------------------------Environment Settings------------------------------------------------# 32 | # Directory paths for data and outputs 33 | data_dir: "dataset/" 34 | save_dir: "output/" 35 | 36 | gpu_id: "0" 37 | dataset_name: "nq" # name of the dataset in data_dir 38 | split: [ "test", "train", "dev", "val"] # dataset split to load (e.g. 
train,dev,test) 39 | 40 | # Sampling configurations for testing 41 | test_sample_num: ~ # number of samples to test (only work in dev/test split), if None, test all samples 42 | random_sample: False # whether to randomly sample the test samples 43 | 44 | # Seed for reproducibility 45 | seed: 2024 46 | 47 | # Whether save intermediate data 48 | save_intermediate_data: True 49 | save_note: 'experiment' 50 | 51 | # -------------------------------------------------Retrieval Settings------------------------------------------------# 52 | # If set the name, the model path will be find in global paths 53 | retrieval_method: "bge" # name or path of the retrieval model. 54 | index_path: "indexes/bge_Flat.index" # set automatically if not provided. 55 | faiss_gpu: True # whether use gpu to hold index 56 | corpus_path: "indexes/wiki18_100w.jsonl" # path to corpus in '.jsonl' format that store the documents 57 | 58 | instruction: ~ # instruction for retrieval model 59 | retrieval_topk: 5 # number of retrieved documents 60 | retrieval_batch_size: 256 # batch size for retrieval 61 | retrieval_use_fp16: True # whether to use fp16 for retrieval model 62 | retrieval_query_max_length: 128 # max length of the query 63 | save_retrieval_cache: True # whether to save the retrieval cache 64 | use_retrieval_cache: False # whether to use the retrieval cache 65 | retrieval_cache_path: ~ # path to the retrieval cache 66 | retrieval_pooling_method: ~ # set automatically if not provided 67 | 68 | use_reranker: False # whether to use reranker 69 | rerank_model_name: ~ # same as retrieval_method 70 | rerank_model_path: ~ # path to reranker model, path will be automatically find in `retriever_model2path` 71 | rerank_pooling_method: ~ 72 | rerank_topk: 5 # number of remain documents after reranking 73 | rerank_max_length: 512 74 | rerank_batch_size: 256 # batch size for reranker 75 | rerank_use_fp16: True 76 | 77 | ## -------------------------------------------------Generator Settings------------------------------------------------# 78 | framework: vllm # inference frame work of LLM, supporting: 'hf','vllm','fschat' 79 | generator_model: "Qwen/Qwen2.5-7B-Instruct" # name or path of the generator model 80 | model_path: "Qwen/Qwen2.5-7B-Instruct" 81 | generator_max_input_len: 8192 # max length of the input 82 | generator_batch_size: 2 # batch size for generation, invalid for vllm 83 | generation_params: 84 | do_sample: False 85 | max_tokens: 256 86 | use_fid: False # whether to use FID, only valid in encoder-decoder model 87 | 88 | # -----------------------------------------OpenAI Generator Settings------------------------------------------------# 89 | #framework: openai # inference frame work of LLM, supporting: 'hf','vllm','fschat' 90 | #generator_model: "" # name or path of the generator model 91 | #generator_max_input_len: 16000 # max length of the input 92 | #generator_batch_size: 1 # batch size for generation, invalid for vllm 93 | #generation_params: 94 | # do_sample: False 95 | # max_tokens: 16000 96 | # temperature: 1 97 | #use_fid: False # whether to use FID, only valid in encoder-decoder model 98 | 99 | # -------------------------------------------------Refiner Settings------------------------------------------------# 100 | # If set, the refiner will be used to refine the retrieval documents. 
101 | refiner_name: ~ 102 | refiner_model_path: ~ 103 | 104 | # Used for extractive method (e.g embedding models) 105 | refiner_topk: 5 # number of remain sentence after refiner 106 | refiner_pooling_method: 'mean' # pooling method of refiner model 107 | refiner_encode_max_length: 256 108 | # Used for abstractive method (e.g. generation models like bart-large-cnn) 109 | refiner_max_input_length: 1024 110 | refiner_max_output_length: 512 111 | 112 | # Specify settings for llmlingua 113 | llmlingua_config: 114 | rate: 0.55 115 | condition_in_question: 'after_condition' 116 | reorder_context: 'sort' 117 | dynamic_context_compression_ratio: 0.3 118 | condition_compare: True 119 | context_budget: "+100" 120 | rank_method: 'longllmlingua' 121 | sc_config: 122 | 'reduce_ratio': 0.5 123 | 124 | # -------------------------------------------------Evaluation Settings------------------------------------------------# 125 | # Metrics to evaluate the result 126 | metrics: ['em','f1','acc','precision','recall'] 127 | # Specify setting for metric, will be called within certain metrics 128 | metric_setting: 129 | retrieval_recall_topk: 5 130 | tokenizer_name: Qwen/Qwen2.5-7B-Instruct 131 | save_metric_score: True # whether to save the metric score into txt file 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /pipeline/reasonrag_pipeline.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import asyncio 3 | import time 4 | import os 5 | import pickle 6 | import psutil 7 | from copy import deepcopy 8 | import os 9 | from multiprocessing import Pool 10 | import time 11 | from tqdm import tqdm 12 | import multiprocessing 13 | from functools import partial 14 | from concurrent.futures import ProcessPoolExecutor, as_completed 15 | from multiprocessing import Manager 16 | import re 17 | import math 18 | from tqdm import tqdm 19 | import numpy as np 20 | import logging 21 | import copy 22 | import queue 23 | from typing import List 24 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 25 | from flashrag.utils import get_retriever, get_generator, selfask_pred_parse, ircot_pred_parse 26 | from flashrag.pipeline import BasicPipeline 27 | from flashrag.dataset import get_batch_dataset, merge_batch_dataset, Dataset 28 | from flashrag.evaluator.metrics import F1_Score 29 | from flashrag.prompt import PromptTemplate 30 | from abc import ABC, ABCMeta, abstractmethod 31 | from typing import Any, Generic, TypeVar, Union 32 | 33 | 34 | def save_data(save_dir, data, file_name="intermediate_data.json"): 35 | """Save the evaluated data, including the raw data and the score of each data 36 | sample on each metric.""" 37 | save_path = os.path.join(save_dir, file_name) 38 | data.save(save_path) 39 | 40 | class MCTSRAGNode: 41 | def __init__(self, index, S: Any, parent: MCTSRAGNode | None = None, parent_id: int = -1, step: int = 0, 42 | next_state: str | None = None, question="RAG question", thoughts=None, golden_answers=None, 43 | node_dict: dict[str, Any] | None = None): 44 | self.id = index 45 | self.S = S 46 | self.parent = parent 47 | self.parent_id = parent_id 48 | self.children = [] 49 | self.children_ids = [] 50 | self.N = 0 51 | self.step = step 52 | self.Q = 0.0 53 | self.Q_list = [] 54 | self.reward_list = [] 55 | self.next_state = next_state 56 | self.question = question 57 | self.thoughts = thoughts 58 | self.golden_answers = golden_answers 59 | 
self.node_dict = node_dict if node_dict is not None else {} 60 | self.max_tokens_reached = False 61 | 62 | def add_child(self, child_node: MCTSRAGNode): 63 | self.children.append(child_node) 64 | self.children_ids.append(child_node.id) 65 | 66 | def update_Q(self, reward: float): 67 | self.N += 1 68 | self.Q_list.append(reward) 69 | self.Q = np.mean(self.Q_list) 70 | self.node_dict['Q'] = self.Q 71 | self.node_dict['step'] = self.step 72 | self.node_dict['N'] = self.N 73 | 74 | def update_reward(self, reward: float): 75 | self.reward_list.append(reward) 76 | self.node_dict["reward"] = np.mean(self.reward_list) 77 | 78 | def is_fully_expanded(self, max_children) -> bool: 79 | return self.next_state is None or len(self.children) == max_children 80 | 81 | 82 | def extract_answer(pred, prefix="So the answer is"): 83 | if prefix in pred: 84 | pred = pred.split(prefix)[1].strip() 85 | answer_matches = re.findall(r'(.*?)', pred) 86 | pred = answer_matches[-1] if answer_matches else pred 87 | pred = re.sub(r'.*?|.*?|answer>|([^<]*)', answer, re.DOTALL) 104 | answer = answer_matches[-1] if answer_matches else pred 105 | answer = re.sub(r'.*?|.*?|answer>|.*', next_content): 136 | next_content = f'{next_content}' 137 | 138 | text = text.replace(next_match.group(1), next_content) 139 | 140 | answer_matches = re.finditer(r'So the answer is (.*?)(?= So the answer is|$)', text, re.DOTALL) 141 | for match in answer_matches: 142 | answer_content = match.group(1).strip() 143 | answer_content = re.sub(r'[\.\"]+$', '', answer_content) 144 | if answer_content.startswith('“') and answer_content.endswith('”'): 145 | answer_content = answer_content[1:-1] 146 | 147 | if not re.search(r'.*', answer_content): 148 | answer_content = f'{answer_content}' 149 | 150 | text = text.replace(match.group(1), answer_content) 151 | 152 | return text 153 | 154 | class ReasonRAGPipeline(BasicPipeline): 155 | BEGIN_REASONING_PROMPT = """You are an assistant for question answering with access to a retrieval tool. Upon receiving a question, your task is to: 156 | * Analyze and Decompose the Question: Break the question into smaller, manageable sub-questions to ensure all aspects are addressed. 157 | * Evaluate Your Knowledge: Assess each sub-question or component: 158 | - Identify parts you can confidently answer based on your existing knowledge. 159 | - Pinpoint parts that require additional information or verification through retrieval tools. 160 | * Conciseness: Ensure both queries and answers are concise, using nouns or short phrases whenever possible. 161 | * Respond Format: 162 | If your knowledge is sufficient to answer the question, conclude with: 163 | "So the answer is {answer_format}" 164 | If retrieval is necessary to provide a complete answer, conclude with: 165 | "So the next query is query" 166 | """ 167 | 168 | DOCUMENT_ANALYSIS_PROMPT = """You are an information retrieval assistant. Your task is to extract relevant evidence from the provided Wikipedia documents based on the latest query. 169 | 170 | Instructions: 171 | 172 | * Identify key terms or concepts in the query. 173 | * Search the documents for evidence that supports the query. 174 | * Response format: 175 | If relevant evidence is found, output: 176 | Based on the query, the relevant evidence is evidence. 177 | If no relevant evidence is found, output: 178 | None. 179 | """ 180 | 181 | REASONING_PROMPT = """You are a question-answering assistant with access to a retrieval tool. Your goal is to provide a concise and accurate reasoning process. 
182 | Instructions: 183 | * Error Reflection: If errors exist in previous thoughts, identify and correct them. Skip this step if no errors are present. 184 | * Information Sufficiency: Evaluate whether the current information is sufficient to fully and accurately answer the question. If additional retrieval is needed, deconstruct the question and generate the next query. Avoid repeating previous queries. If no meaningful new query can be generated, explain why and provide an answer based on the current information. 185 | * Conciseness: Ensure both queries and answers are concise, using nouns or short phrases whenever possible. 186 | * Conclusion: 187 | If generating an answer: 188 | "So the answer is {answer_format}". 189 | If more retrieval is needed: 190 | "So the next query is query". 191 | """ 192 | 193 | ANSWER_GENERATION_PROMPT = """You are a reasoning assistant with retrieval. Give a precise and very concise final answer for the given question, conclude with 'So the answer is {answer_format}'. Keep your final answer brief and to the point, followed without any explanation. 194 | """ 195 | 196 | EVALUATION_PROMPT = """An agent is tasked with answering a question using a retrieval tool. 197 | Critically assess its intermediate reasoning process to determine if it leads to the correct answer. 198 | Identify all flaws, inconsistencies, and mistakes in the thought process. 199 | Every imperfection, no matter how small, must be acknowledged. 200 | Evaluate how effectively the reasoning supports the final answer and the overall accuracy of the response. 201 | Ensure the evaluation is extremely harsh, leaving no leniency. 202 | Even if the answer seems close to correct, do not award full marks to maintain strict grading standards. 203 | Assign a score between [0, 100] based on the severity of flaws and the reasoning’s accuracy in leading to the golden answer. 204 | Respond briefly and conclude with: So the score is [Score]. 205 | """ 206 | 207 | def __init__(self, config, prompt_template=None, answer_format="answer", retriever=None, generator=None, max_iter=8, max_children=3, 208 | max_rollouts: int = 64, c: float = 1.414, default_uct_score: float = float("inf"), beta=0.95, 209 | batch_size=50, max_workers=50): 210 | self.begin_reasoning_prompt = PromptTemplate( 211 | config=config, 212 | system_prompt=f"{self.BEGIN_REASONING_PROMPT.format(answer_format=answer_format)}", 213 | user_prompt="Question: {question}", 214 | enable_chat=True 215 | ) 216 | 217 | self.document_analysis_prompt = PromptTemplate( 218 | config=config, 219 | system_prompt=f"{self.DOCUMENT_ANALYSIS_PROMPT}", 220 | user_prompt="Question: {question}. 
Reference: {reference}", 221 | reference_template="Wikipedia Title: {title}\n{text}\n\n", 222 | enable_chat=True 223 | ) 224 | 225 | self.reasoning_prompt = PromptTemplate( 226 | config=config, 227 | system_prompt=f"{self.REASONING_PROMPT.format(answer_format=answer_format)}", 228 | user_prompt="Question: {question}", 229 | enable_chat=True 230 | ) 231 | 232 | self.evaluation_prompt = PromptTemplate( 233 | config=config, 234 | system_prompt=f"{self.EVALUATION_PROMPT}", 235 | user_prompt="Question: {question}", 236 | enable_chat=True 237 | ) 238 | 239 | self.answer_generation_prompt = PromptTemplate( 240 | config=config, 241 | system_prompt=f"{self.ANSWER_GENERATION_PROMPT.format(answer_format=answer_format)}", 242 | user_prompt="Question: {question}", 243 | enable_chat=True 244 | ) 245 | 246 | super().__init__(config, self.begin_reasoning_prompt) 247 | self.generator = get_generator(config) if generator is None else generator 248 | self.retriever = get_retriever(config) if retriever is None else retriever 249 | self.max_iter = max_iter 250 | self.max_children = max_children 251 | self.max_rollouts = max_rollouts 252 | self.c = c 253 | self.default_uct_score = default_uct_score 254 | self.index = 0 255 | self.beta = beta 256 | 257 | self.batch_size = batch_size 258 | self.max_workers = max_workers or os.cpu_count() 259 | 260 | self.stop_tokens = ["<|im_end|>", "<|endoftext|>", "", "", ""] 261 | 262 | def initialize(self, item) -> MCTSRAGNode: 263 | self.index = 0 264 | thoughts = [] 265 | question = item.question 266 | iter_num = 0 267 | next_state = "begin_reasoning" 268 | node = MCTSRAGNode(self.index, question, None, -1, step=iter_num, 269 | next_state=next_state, question=question, 270 | thoughts=thoughts, golden_answers=item.golden_answers) 271 | return node 272 | 273 | def search(self, item): 274 | root = self.initialize(item) 275 | for _ in range(self.max_rollouts): 276 | leaf = self.select(root) 277 | child = self.expand(leaf) 278 | Q, reward = self.simulate(child) 279 | self.backpropagate(child, Q, reward) 280 | 281 | self.tree2log(root, item) 282 | 283 | return root 284 | 285 | def tree2log(self, root, item): 286 | q = queue.Queue() 287 | q.put(root) 288 | 289 | while not q.empty(): 290 | node = q.get() 291 | self.node2log(node, item) 292 | 293 | for child in node.children: 294 | q.put(child) 295 | 296 | def node2log(self, node, item): 297 | node_dict = node.node_dict 298 | node_dict["parent_id"] = node.parent_id 299 | node_dict["children_ids"] = node.children_ids 300 | 301 | item.update_output( 302 | f"intermediate_node_{node.id}", 303 | node_dict 304 | ) 305 | 306 | def select(self, node: MCTSRAGNode) -> MCTSRAGNode: 307 | while not self.is_terminal(node): 308 | if not node.is_fully_expanded(self.max_children): 309 | return node 310 | node = self._select_child(node) 311 | return node 312 | 313 | def expand(self, node: MCTSRAGNode) -> MCTSRAGNode: 314 | if node.is_fully_expanded(self.max_children): 315 | return node 316 | 317 | child = self.get_next_state(node, True) 318 | return child 319 | 320 | def _select_child(self, node: MCTSRAGNode) -> MCTSRAGNode: 321 | return max(node.children, key=lambda child: self._uct(child)) 322 | 323 | def _uct(self, node: MCTSRAGNode) -> float: 324 | if node.N == 0: 325 | return self.default_uct_score 326 | return node.Q + self.c * math.sqrt(math.log(node.parent.N) / node.N) 327 | 328 | def _best_child(self, node: MCTSRAGNode) -> MCTSRAGNode: 329 | return max(node.children, key=lambda child: child.N) 330 | 331 | def simulate(self, node): 332 | Q = 
self.evaluate_thoughts(node) 333 | reward = self.get_reward(node) if self.is_terminal(node) else None 334 | return Q, reward 335 | 336 | def is_terminal(self, node: MCTSRAGNode) -> bool: 337 | if node.next_state is None or node.max_tokens_reached: 338 | return True 339 | 340 | return False 341 | 342 | def _simulate_policy(self, node: MCTSRAGNode): 343 | action = node.next_state 344 | return action 345 | 346 | def get_next_state(self, parent_node: MCTSRAGNode, use_index=False): 347 | action = parent_node.next_state 348 | if action is None: 349 | return None 350 | 351 | response, node_dict = "", {} 352 | thoughts = copy.copy(parent_node.thoughts) 353 | if action == "begin_reasoning": 354 | response, node_dict = self.handle_begin_reasoning(parent_node.question, thoughts) 355 | elif action == "document_analysis": 356 | response, node_dict = self.handle_document_analysis(parent_node.question, thoughts) 357 | elif action == "reasoning": 358 | response, node_dict = self.handle_reasoning(parent_node.question, thoughts) 359 | elif action == "answer_generation": 360 | response, node_dict = self.handle_answer_generation(parent_node.question, thoughts) 361 | 362 | new_child_node_id = -1 363 | if use_index: 364 | self.index += 1 365 | new_child_node_id = self.index 366 | 367 | node_dict["id"] = new_child_node_id 368 | node_dict["parent_id"] = parent_node.id 369 | next_action_state_for_child = self.next_state(action, response, parent_node.step) 370 | 371 | child_node = MCTSRAGNode( 372 | new_child_node_id, 373 | parent_node.S + " " + response, 374 | parent_node, 375 | parent_node.id, 376 | parent_node.step + 1, 377 | next_action_state_for_child, 378 | parent_node.question, 379 | thoughts, 380 | parent_node.golden_answers, 381 | node_dict 382 | ) 383 | 384 | if self.prompt_template.check_prompt_length(node_dict["input_prompt"]): 385 | child_node.max_tokens_reached = True 386 | 387 | if use_index: 388 | parent_node.add_child(child_node) 389 | 390 | return child_node 391 | 392 | def get_reward(self, node: MCTSRAGNode): 393 | pred = extract_answer(node.node_dict["response"]) 394 | golden_answers = node.golden_answers 395 | evaluator = F1_Score(self.config) 396 | reward = evaluator.token_level_scores(pred, golden_answers) 397 | reward = reward['f1'] * self.beta ** node.step 398 | return reward 399 | 400 | def backpropagate(self, node, Q, reward): 401 | while node: 402 | node.update_Q(Q) 403 | if reward is not None: 404 | node.update_reward(reward) 405 | node = node.parent 406 | 407 | def next_state(self, current_action: str, response: str, iter_num: int): 408 | action_transitions = { 409 | "begin_reasoning": lambda res: None if re.search(r'.*?', 410 | res, re.DOTALL) else "document_analysis", 411 | "reasoning": lambda res: None if re.search(r'.*?', 412 | res, re.DOTALL) else "document_analysis", 413 | "document_analysis": lambda res: "reasoning", 414 | "answer_generation": lambda res: None, 415 | } 416 | 417 | # Retrieve the function for the current action 418 | transition_function = action_transitions.get(current_action) 419 | 420 | if iter_num == self.max_iter - 1: 421 | return "answer_generation" 422 | 423 | if iter_num < self.max_iter-1 and transition_function: 424 | return transition_function(response) 425 | 426 | return None 427 | 428 | def extract_answer(self, response: str) -> str: 429 | match = re.search(r'So the answer is\s*(.*?)(?=\n|$)', response, re.IGNORECASE | re.DOTALL) 430 | if not match: 431 | return "" 432 | 433 | text = match.group(1).strip() 434 | # Remove special tokens 435 | text = 
re.sub(r'', '', text) 436 | return text.strip() 437 | 438 | def extract_query(self, response: str) -> str: 439 | match = re.search(r'So the next query is\s*(.*?)(?=\n|$)', response, re.IGNORECASE | re.DOTALL) 440 | if not match: 441 | return "" 442 | 443 | text = match.group(1).strip() 444 | # Remove special tokens 445 | text = re.sub(r'', '', text) 446 | return text.strip() 447 | 448 | def delete_tokens(self, response: str) -> str: 449 | cleaned_response = re.sub(r'.*?|.*?', '', response, flags=re.DOTALL) 450 | return cleaned_response 451 | 452 | def initialize(self, item) -> MCTSRAGNode: 453 | self.index = 0 454 | thoughts = [] 455 | question = item.question 456 | iter_num = 0 457 | next_state = "begin_reasoning" 458 | node = MCTSRAGNode(self.index, question, None, -1, step=iter_num, 459 | next_state=next_state, question=question, 460 | thoughts=thoughts, golden_answers=item.golden_answers) 461 | return node 462 | 463 | def test_item(self, item): 464 | root = self.initialize(item) 465 | node = root 466 | 467 | while not self.is_terminal(node): 468 | next_node = self.get_next_state(node, use_index=True) 469 | node = next_node 470 | 471 | self.tree2log(root, item) 472 | if "answer" in node.node_dict and node.node_dict["answer"] is not None: 473 | item.update_output("pred", node.node_dict["answer"]) 474 | else: 475 | item.update_output("pred", "none") 476 | return node 477 | 478 | def process_annotation(self, dataset, do_eval=True, pred_process_fun=extract_answer): 479 | for item in tqdm(dataset, desc="Inference: "): 480 | self.search(item) 481 | 482 | save_data(self.config["save_dir"], dataset) 483 | dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun) 484 | return dataset 485 | 486 | def run(self, dataset, do_eval=True, batch_size=16, pred_process_fun=None): 487 | all_dataset_list = [] 488 | for batch_dataset in tqdm(get_batch_dataset(dataset, batch_size=batch_size), desc="Batch dataset: "): 489 | batch_dataset = self.run_batch(batch_dataset) 490 | all_dataset_list.append(batch_dataset) 491 | dataset = merge_batch_dataset(all_dataset_list) 492 | save_data(self.config["save_dir"], dataset) 493 | dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun) 494 | return dataset 495 | 496 | def get_flags(self, responses): 497 | flags = [] 498 | for response in responses: 499 | if "So the next query is" in response: 500 | flag = "query" 501 | elif "So the answer is" in response: 502 | flag = "answer" 503 | elif "" in response: 504 | flag = "evidence" 505 | else: 506 | flag = "None" 507 | flags.append(flag) 508 | return flags 509 | 510 | def get_querys(self, responses): 511 | querys = [] 512 | for response in responses: 513 | querys.append(self.extract_query(response)) 514 | return querys 515 | 516 | def get_answers(self, responses): 517 | answers = [] 518 | for response in responses: 519 | answers.append(self.extract_answer(response)) 520 | return answers 521 | 522 | def all_finished(self, next_flags): 523 | return all(flag in ["answer", "finish"] for flag in next_flags) 524 | 525 | def run_batch(self, dataset): 526 | for item in dataset: 527 | item.update_output('finish_flag', False) 528 | item.update_output('iteration_count', 0) 529 | item.update_output('previous_thoughts', []) 530 | item.update_output('flag', None) 531 | item.update_output('query', None) 532 | item.update_output('answer', None) 533 | 534 | input_prompts = [self.begin_reasoning_prompt.get_string(question=item.question) for item in dataset] 535 | responses = 
self.generator.generate(input_prompts, stop=self.stop_tokens) 536 | 537 | for i, item in enumerate(dataset): 538 | response = responses[i] 539 | item.previous_thoughts.append(response) 540 | item.flag = self.get_flags([response])[0] 541 | item.query = self.get_querys([response])[0] 542 | item.answer = self.get_answers([response])[0] 543 | 544 | node_dict = { 545 | "action_name": "begin_reasoning", 546 | "input_prompt": input_prompts[i], 547 | "response": response, 548 | "query": item.query, 549 | "answer": item.answer, 550 | } 551 | item.update_output(f"intermediate_node_0", node_dict) 552 | if item.flag in ["finish", "answer"]: 553 | item.finish_flag = True 554 | 555 | for step in range(self.max_iter): 556 | exist_items = [item for item in dataset if not item.finish_flag] 557 | if not exist_items: 558 | break 559 | 560 | active_questions = [item.question for item in exist_items] 561 | active_previous_thoughts = [item.previous_thoughts for item in exist_items] 562 | active_flags = [item.flag for item in exist_items] 563 | active_querys = [item.query for item in exist_items] 564 | 565 | question_thoughts_list = [ 566 | q + "\nPrevious Thoughts: " + " ".join(thoughts) 567 | for q, thoughts in zip(active_questions, active_previous_thoughts) 568 | ] 569 | 570 | retrieval_results = self.retriever.batch_search(active_querys) 571 | 572 | input_prompts = [] 573 | for i, item in enumerate(exist_items): 574 | if item.iteration_count >= self.max_iter - 1: 575 | input_prompts.append(self.answer_generation_prompt.get_string(question=question_thoughts_list[i])) 576 | elif "query" in item.flag: 577 | input_prompts.append(self.document_analysis_prompt.get_string( 578 | question=question_thoughts_list[i], retrieval_result=retrieval_results[i] 579 | )) 580 | elif "evidence" in item.flag: 581 | input_prompts.append(self.reasoning_prompt.get_string(question=question_thoughts_list[i])) 582 | else: 583 | input_prompts.append(self.answer_generation_prompt.get_string(question=question_thoughts_list[i])) 584 | 585 | responses = self.generator.generate(input_prompts, stop=self.stop_tokens) 586 | for i, item in enumerate(exist_items): 587 | response = responses[i] 588 | item.previous_thoughts.append(response) 589 | item.iteration_count += 1 590 | item.flag = self.get_flags([response])[0] 591 | item.query = self.get_querys([response])[0] 592 | item.answer = self.get_answers([response])[0] 593 | 594 | node_dict = { 595 | "action_name": "unknown", 596 | "input_prompt": input_prompts[i], 597 | "response": response, 598 | "query": item.query, 599 | "answer": item.answer, 600 | } 601 | if "evidence" in item.flag: 602 | node_dict["action_name"] = "document_analysis" 603 | node_dict["retrieval_result"] = retrieval_results[i] 604 | elif "query" in item.flag: 605 | node_dict["action_name"] = "query_generation" 606 | elif "answer" in item.flag: 607 | node_dict["action_name"] = "answer_generation" 608 | elif "finish" in item.flag: 609 | node_dict["action_name"] = "finish" 610 | 611 | item.update_output(f"intermediate_node_{item.iteration_count}", node_dict) 612 | 613 | if item.flag in ["finish", "answer"] or item.iteration_count >= self.max_iter: 614 | item.finish_flag = True 615 | 616 | for i in range(len(dataset)): 617 | dataset[i].pred = dataset[i].answer 618 | 619 | return dataset 620 | 621 | def handle_begin_reasoning(self, question, thoughts): 622 | begin_reasoning = [self.begin_reasoning_prompt.get_string(question=question)] 623 | begin_response = self.generator.generate(begin_reasoning)[0] 624 | begin_response = 
process_text(begin_response) 625 | thoughts.append(begin_response) 626 | 627 | query_matches = re.findall(r'(.*?)', begin_response) 628 | extracted_query = self.delete_tokens(query_matches[-1]) if query_matches else None 629 | 630 | answer_matches = re.findall(r'(.*?)', begin_response) 631 | extracted_answer = self.delete_tokens(answer_matches[-1]) if answer_matches else None 632 | 633 | node_dict = { 634 | "action_name": "begin_reasoning", 635 | "input_prompt": begin_reasoning[0], 636 | "response": begin_response, 637 | "next_query": extracted_query, 638 | "answer": extracted_answer, 639 | } 640 | 641 | return begin_response, node_dict 642 | 643 | def handle_reasoning(self, question, thoughts): 644 | question_thoughts = question + "\nPrevious Thoughts: " + " ".join(thoughts) 645 | reasoning_prompt = [self.reasoning_prompt.get_string(question=question_thoughts)] 646 | response = self.generator.generate(reasoning_prompt)[0] 647 | response = process_text(response) 648 | thoughts.append(response) 649 | 650 | query_matches = re.findall(r'(.*?)', response) 651 | extracted_query = self.delete_tokens(query_matches[-1]) if query_matches else None 652 | 653 | answer_matches = re.findall(r'(.*?)', response) 654 | extracted_answer = self.delete_tokens(answer_matches[-1]) if answer_matches else None 655 | 656 | node_dict = { 657 | "action_name": "reasoning", 658 | "input_prompt": reasoning_prompt[0], 659 | "response": response, 660 | "next_query": extracted_query, 661 | "answer": extracted_answer, 662 | } 663 | 664 | return response, node_dict 665 | 666 | def handle_document_analysis(self, question, thoughts): 667 | extracted_query = self.delete_tokens(self.extract_query(thoughts[-1])) 668 | id2doc = {} 669 | doc2score = {} 670 | retrieval_result, scores = self.retriever.search(extracted_query, return_score=True) 671 | for doc_item, score in zip(retrieval_result, scores): 672 | id2doc[doc_item["id"]] = doc_item 673 | doc_id = doc_item["id"] 674 | if doc_id in doc2score: 675 | doc2score[doc_id] = max(doc2score[doc_id], score) 676 | else: 677 | doc2score[doc_id] = score 678 | 679 | sorted_doc_score = sorted(doc2score.items(), key=lambda x: x[1], reverse=False) 680 | sorted_doc_id = [t[0] for t in sorted_doc_score] 681 | retrieval_result = [id2doc[id] for id in sorted_doc_id] 682 | 683 | question_thoughts = question + "\nPrevious Thoughts: " + " ".join(thoughts) 684 | grounding_prompt = [self.document_analysis_prompt.get_string( 685 | question=question_thoughts, retrieval_result=retrieval_result 686 | )] 687 | response = self.generator.generate(grounding_prompt)[0] 688 | thoughts.append(response) 689 | 690 | node_dict = { 691 | "action_name": "document_analysis", 692 | "input_prompt": grounding_prompt[0], 693 | "query": extracted_query, 694 | "response": response, 695 | "retrieval_result": retrieval_result, 696 | } 697 | 698 | return response, node_dict 699 | 700 | def handle_answer_generation(self, question, thoughts): 701 | question_thoughts = question + "\nPrevious Thoughts: " + " ".join(thoughts) 702 | answer_generation = [self.answer_generation_prompt.get_string(question=question_thoughts)] 703 | answer = self.generator.generate(answer_generation)[0] 704 | answer = process_text(answer) 705 | thoughts.append(answer) 706 | pred = extract_answer(answer) 707 | node_dict = { 708 | "action_name": "answer_generation", 709 | "input_prompt": answer_generation[0], 710 | "response": answer, 711 | "pred": pred, 712 | } 713 | 714 | return answer, node_dict 715 | 716 | def evaluate_thoughts(self, node): 717 | 
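# LLM-as-judge process reward: the evaluation prompt scores the partial reasoning trace
# against the golden answers on a 0-100 scale; the last number in the response is extracted
# and normalized to [0, 1] to serve as the node value Q.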
question_thoughts = node.question + "\nGolden Answer: " + " or ".join( 718 | node.golden_answers) + "\nAgent Reasoning Process: " + " ".join(node.thoughts) 719 | evaluation_prompt = [self.evaluation_prompt.get_string(question=question_thoughts)] 720 | evaluation_response = self.generator.generate(evaluation_prompt)[0] 721 | Q = extract_last_number(evaluation_response) 722 | Q = float(Q) / 100 723 | return Q 724 | -------------------------------------------------------------------------------- /preference_data_generatoin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import defaultdict 4 | from typing import Dict, List 5 | import numpy as np 6 | import re 7 | import tiktoken 8 | import string 9 | from glob import glob 10 | 11 | BETA = 0.9 12 | ENCODER = tiktoken.get_encoding("cl100k_base") 13 | 14 | def count_tokens(text: str) -> int: 15 | return len(ENCODER.encode(text)) 16 | 17 | def get_action_type(response: str) -> str: 18 | if "" in response: 19 | return "Query Generation" 20 | elif "" in response: 21 | return "Evidence Extraction" 22 | elif "" in response: 23 | return "Answer Generation" 24 | return "Other" 25 | 26 | def get_content(response: str) -> str: 27 | for tag in ["answer", "query", "evidence"]: 28 | match = re.search(rf"<{tag}>(.*?)", response, re.DOTALL) 29 | if match: 30 | return match.group(1).strip() 31 | return "" 32 | 33 | def get_action_type_and_content(response: str) -> tuple[str, str]: 34 | action_type = get_action_type(response) 35 | content = get_content(response) 36 | return action_type, content 37 | 38 | def process_node(data: Dict, node_id: int, results: List, action_type_counts: Dict[str, int], layer: int = 0): 39 | current_node = data[f"intermediate_node_{node_id}"] 40 | children_ids = current_node.get("children_ids", []) 41 | if len(children_ids) < 2: 42 | return 43 | 44 | valid_children = [] 45 | prompt = "" 46 | system_prompt = "" 47 | user_prompt = "" 48 | for child_id in children_ids: 49 | child_node = data.get(f"intermediate_node_{child_id}") 50 | if child_node and "reward" in child_node and child_node["reward"] is not None and "input_prompt" in child_node: 51 | valid_children.append(child_node) 52 | input_prompt = child_node["input_prompt"] 53 | system_prompt = input_prompt[0]["content"] 54 | user_prompt = input_prompt[1]["content"] 55 | prompt = f"System: {system_prompt}\nUser: {user_prompt}" 56 | 57 | for i in range(len(valid_children)): 58 | for j in range(i + 1, len(valid_children)): 59 | node_a, node_b = valid_children[i], valid_children[j] 60 | reward_a, reward_b = float(node_a["reward"]), float(node_b["reward"]) 61 | if abs(reward_a - reward_b) < 0.01: 62 | continue 63 | 64 | chosen, rejected = (node_a, node_b) if reward_a > reward_b else (node_b, node_a) 65 | chosen_action_type, chosen_content = get_action_type_and_content(chosen["response"]) 66 | rejected_action_type, rejected_content = get_action_type_and_content(rejected["response"]) 67 | 68 | if chosen_content == rejected_content or chosen_action_type == "Other": 69 | continue 70 | 71 | action_type_counts[chosen_action_type] += 1 72 | results.append({ 73 | "instruction": system_prompt, 74 | "input": user_prompt, 75 | "prompt": prompt, 76 | "chosen": chosen["response"], 77 | "rejected": rejected["response"], 78 | "chosen_action_type": chosen_action_type, 79 | "chosen_content": chosen_content, 80 | "rejected_action_type": rejected_action_type, 81 | "rejected_content": rejected_content, 82 | "layer": layer 
83 | }) 84 | 85 | for child in valid_children: 86 | process_node(data, child["id"], results, action_type_counts, layer + 1) 87 | 88 | def process_json_files(folder_list: List[str], output_file: str = "dpo_data.json"): 89 | all_results = [] 90 | action_type_counts = defaultdict(int) 91 | 92 | for folder in folder_list: 93 | pattern = os.path.join(folder, "chunk_*", "reward_*.json") 94 | for json_path in glob(pattern, recursive=True): 95 | try: 96 | with open(json_path, "r") as f: 97 | items = json.load(f) 98 | except Exception: 99 | continue 100 | 101 | for item in items: 102 | if "output" not in item or "golden_answers" not in item or item["golden_answers"] is None: 103 | continue 104 | 105 | output_data = item["output"] 106 | root_node = output_data.get("intermediate_node_0") 107 | if not root_node or root_node.get("reward") is None: 108 | continue 109 | 110 | process_node(output_data, 0, all_results, action_type_counts) 111 | 112 | with open(output_file, "w") as f: 113 | json.dump(all_results, f, indent=2) 114 | 115 | def normalize_answer(s: str) -> str: 116 | def white_space_fix(text): 117 | return " ".join(text.split()) 118 | def remove_punc(text): 119 | exclude = set(string.punctuation) 120 | return "".join(ch for ch in text if ch not in exclude) 121 | def lower(text): 122 | return text.lower() 123 | return white_space_fix(remove_punc(lower(s))).strip() 124 | 125 | def compute_f1_single(prediction: str, ground_truth: str) -> float: 126 | pred, gt = normalize_answer(prediction), normalize_answer(ground_truth) 127 | if pred in ["yes", "no", "noanswer"] and pred != gt: 128 | return 0.0 129 | if gt in ["yes", "no", "noanswer"] and pred != gt: 130 | return 0.0 131 | 132 | pred_tokens, gt_tokens = pred.split(), gt.split() 133 | common = sum((Counter(pred_tokens) & Counter(gt_tokens)).values()) 134 | if common == 0: 135 | return 0.0 136 | precision = common / len(pred_tokens) if pred_tokens else 0.0 137 | recall = common / len(gt_tokens) if gt_tokens else 0.0 138 | return (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 139 | 140 | def compute_f1_multiple_gt(prediction: str, ground_truths: List[str]) -> float: 141 | return max(compute_f1_single(prediction, gt) for gt in ground_truths) 142 | 143 | def calculate_mcts_reward_bottom_up(data: Dict, golden_answers: List[str]) -> Dict: 144 | unprocessed_children_count = { 145 | k: len(data[k].get("children_ids", [])) for k in data if "intermediate_node_" in k 146 | } 147 | queue = deque(k for k, count in unprocessed_children_count.items() if count == 0) 148 | rewards = {} 149 | updated_nodes = data.copy() 150 | 151 | while queue: 152 | node_key = queue.popleft() 153 | node = updated_nodes[node_key] 154 | 155 | if not node.get("children_ids"): 156 | reward = None 157 | if "answer" in node and node.get("response") and node["answer"] is not None: 158 | f1 = compute_f1_multiple_gt(node["answer"], golden_answers) 159 | step = node.get("step", 0) 160 | reward = f1 * (BETA ** step) 161 | else: 162 | children_keys = [f"intermediate_node_{cid}" for cid in node["children_ids"]] 163 | valid_children = [ck for ck in children_keys if ck in rewards and rewards[ck] is not None] 164 | reward = None 165 | if valid_children: 166 | total_n = sum(updated_nodes[ck].get("N", 0) for ck in valid_children) 167 | if total_n > 0: 168 | reward = sum(rewards[ck] * updated_nodes[ck].get("N", 0) for ck in valid_children) / total_n 169 | 170 | rewards[node_key] = reward 171 | if reward is not None: 172 | updated_nodes[node_key]["reward"] = 
reward 173 | elif "reward" in updated_nodes[node_key]: 174 | del updated_nodes[node_key]["reward"] 175 | 176 | parent_id = node.get("parent_id") 177 | if parent_id is not None: 178 | parent_key = f"intermediate_node_{parent_id}" 179 | if parent_key in updated_nodes: 180 | unprocessed_children_count[parent_key] -= 1 181 | if unprocessed_children_count[parent_key] == 0: 182 | queue.append(parent_key) 183 | 184 | return updated_nodes 185 | 186 | def process_json_file_and_calculate_reward_save(json_path: str): 187 | try: 188 | with open(json_path, "r") as f: 189 | items = json.load(f) 190 | updated_items = [] 191 | for item in items: 192 | if "output" not in item or item.get("golden_answers") is None: 193 | updated_items.append(item) 194 | continue 195 | 196 | mcts_data = item["output"] 197 | if "intermediate_node_0" not in mcts_data: 198 | updated_items.append(item) 199 | continue 200 | 201 | updated_nodes = calculate_mcts_reward_bottom_up(mcts_data, item["golden_answers"]) 202 | modified_item = item.copy() 203 | modified_item["output"] = updated_nodes 204 | updated_items.append(modified_item) 205 | 206 | chunk_folder = os.path.dirname(json_path) 207 | base_name = os.path.splitext(os.path.basename(json_path))[0] 208 | sequence_number = base_name.split("_")[-1] 209 | reward_output_path = os.path.join(chunk_folder, f"reward_{sequence_number}.json") 210 | with open(reward_output_path, "w") as f: 211 | json.dump(updated_items, f, indent=2) 212 | 213 | except Exception: 214 | pass 215 | 216 | def process_folder_and_calculate_rewards_save(folder_path: str): 217 | pattern = os.path.join(folder_path, "chunk_*", "intermediate_*.json") 218 | for json_path in glob(pattern, recursive=True): 219 | process_json_file_and_calculate_reward_save(json_path) 220 | 221 | def main(): 222 | folder_list = ["output/2wikimultihopqa_2025_02_04_01_49_experiment"] 223 | for folder in folder_list: 224 | process_folder_and_calculate_rewards_save(folder) 225 | process_json_files(folder_list, output_file="training_data/RAG_ProGuide.json") 226 | 227 | if __name__ == "__main__": 228 | main() 229 | -------------------------------------------------------------------------------- /training_config/qwen_dpo.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen2.5-7B-Instruct 3 | trust_remote_code: true 4 | 5 | ### method 6 | stage: dpo 7 | do_train: true 8 | finetuning_type: full 9 | pref_beta: 0.3 10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 11 | 12 | ### dataset 13 | dataset: RAG_ProGuide 14 | template: qwen 15 | cutoff_len: 10000 16 | max_samples: 1000000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 8 19 | 20 | ### output 21 | output_dir: saves/qwen2.5-7b-instruct/full/dpo 22 | logging_steps: 10 23 | save_steps: 500 24 | plot_loss: true 25 | overwrite_output_dir: true 26 | 27 | ### train 28 | per_device_train_batch_size: 1 29 | gradient_accumulation_steps: 2 30 | learning_rate: 1.0e-6 31 | num_train_epochs: 1.0 32 | lr_scheduler_type: cosine 33 | warmup_ratio: 0.2 34 | bf16: true 35 | ddp_timeout: 180000000 36 | 37 | ### eval 38 | val_size: 0.1 39 | per_device_eval_batch_size: 1 40 | eval_strategy: steps 41 | eval_steps: 500 42 | 43 | --------------------------------------------------------------------------------
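A practical note on the `dataset: RAG_ProGuide` entry above: LLaMA-Factory resolves dataset names through its `data/dataset_info.json` registry, which is presumably what the README's "Set the dataset path before preference optimization" step refers to. A minimal sketch of such an entry, assuming the preference file produced by the data-generation step is placed at `LLaMA-Factory/data/RAG_ProGuide.json` (the file name and path are assumptions; the column mapping follows the fields emitted by `preference_data_generatoin.py`):

```json
{
  "RAG_ProGuide": {
    "file_name": "RAG_ProGuide.json",
    "ranking": true,
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "chosen": "chosen",
      "rejected": "rejected"
    }
  }
}
```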