├── README.md
├── data_generation.py
├── images
│   └── main_results.png
├── inference.py
├── my_config.yaml
├── pipeline
│   └── reasonrag_pipeline.py
├── preference_data_generatoin.py
└── training_config
    └── qwen_dpo.yaml
/README.md:
--------------------------------------------------------------------------------
1 | # ReasonRAG: Enhancing Agentic RAG with Process-Supervised Reinforcement Learning
2 |
11 | If you like our project, please give us a star ⭐ on GitHub to follow the latest updates.
12 |
13 | # 💡 Overview
14 |
15 | Recent advancements in outcome-supervised Reinforcement Learning (RL), exemplified by OpenAI's o1 and DeepSeek's R1, have demonstrated remarkable improvements in large language model (LLM) reasoning capabilities. Integrating outcome-supervised RL with search engines presents another promising avenue for boosting LLM reasoning. However, outcome-supervised RL often grapples with challenges such as sparse rewards, training instability, and inefficient exploration.
16 |
17 | To address these limitations, process-supervised RL emerges as a compelling solution for enhancing Agentic RAG, offering the advantage of fine-grained rewards. We introduce **ReasonRAG**, a process-supervised method designed to refine Agentic RAG's strategic preferences.
18 |
19 | Our approach consists of three key steps:
20 |
21 | 1. We leverage Monte Carlo Tree Search (MCTS) to generate process-supervised rollouts, yielding rich data on process-level strategic preferences.
22 | 2. We then employ Direct Preference Optimization (DPO) to effectively optimize these strategic preferences within the Agentic RAG framework.
23 | 3. Finally, we construct an Agentic RAG pipeline that enables the LLM to autonomously generate queries, extract evidence, and formulate answers.
24 | We provide the dataset we constructed and links to our trained models below:
25 |
26 | * **RAG_ProGuide Dataset:** [https://huggingface.co/datasets/reasonrag/RAG_ProGuide](https://huggingface.co/datasets/reasonrag/RAG_ProGuide)
27 | * **Trained Models:** [Qwen2.5-7B-Instruct-ReasonRAG](https://huggingface.co/reasonrag/Qwen2.5-7B-Instruct-ReasonRAG)
28 | * **Trained LoRA Model:** [Qwen2.5-7B-Instruct-RAG-Lora](https://huggingface.co/reasonrag/Qwen2.5-7B-Instruct-RAG-Lora)
29 |
30 | ReasonRAG achieves superior performance on five benchmark datasets using only 5k training instances, significantly fewer than the 90k training instances required by Search-R1.
31 |
32 | 
33 |
34 | # ✨ Method
35 | We employ process-supervised RL to enhance Agentic RAG capabilities in three stages:
36 | 1. Process-supervised reward data generation (a simplified sketch of the rollout loop is shown below).
37 | 2. Policy preference optimization with DPO.
38 | 3. Agentic RAG inference.
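The data generation in stage 1 follows a standard MCTS loop. The sketch below is a simplified view of `ReasonRAGPipeline.search` from `pipeline/reasonrag_pipeline.py` (assuming a constructed pipeline and a dataset `item`), not the full implementation:

```python
# Each rollout descends the tree via UCT, expands one child by letting the LLM
# take the next action (query generation, evidence extraction, or answer
# generation), scores the partial trajectory with an LLM judge (Q) plus an
# F1-based reward at terminal nodes, and backpropagates both signals.
root = pipeline.initialize(item)
for _ in range(pipeline.max_rollouts):
    leaf = pipeline.select(root)            # UCT-guided descent to an expandable node
    child = pipeline.expand(leaf)           # one new LLM action appended to the trajectory
    Q, reward = pipeline.simulate(child)    # judge score and, if terminal, discounted F1 reward
    pipeline.backpropagate(child, Q, reward)
```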
39 |
40 | ## Data
41 | We randomly sample data from PopQA, HotpotQA, and 2WikiMultihopQA to generate process-supervised preference data. We then use GPT-4o as the policy model to generate rollout data. The resulting process-supervised preference dataset, RAG-ProGuide, is available at [https://huggingface.co/datasets/reasonrag/RAG_ProGuide](https://huggingface.co/datasets/reasonrag/RAG_ProGuide).
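If you only need the released preference data, it can be loaded directly from the Hub. A minimal sketch (column names follow the preference data generation script in this repo; check the dataset card for the authoritative schema):

```python
from datasets import load_dataset

# DPO-style pairs with fields such as "prompt", "chosen", and "rejected"
# (assumed from the preference data generator; verify against the dataset card).
ds = load_dataset("reasonrag/RAG_ProGuide", split="train")
print(ds[0].keys())
```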
42 |
43 | # 🏃 Quick Start
44 | ## Environment Settings
45 | Set up the [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG) environment:
46 | ```bash
47 | conda create --name reasonrag python=3.10.16
48 | conda activate reasonrag
49 | pip install flashrag-dev --pre
50 | pip install "flashrag-dev[full]"
51 | pip install "vllm>=0.4.1"
52 | pip install deepspeed
53 | ```
54 |
55 | ## Data Preparation
56 |
57 | Download the Wikipedia dump to serve as the retrieval corpus, then build the dense index:
58 |
59 | ```bash
60 | # Download wikidump
61 | wget https://archive.org/download/enwiki-20181220/enwiki-20181220-pages-articles.xml.bz2
62 |
63 | # Build index
64 | python -m flashrag.retriever.index_builder \
65 | --retrieval_method bge \
66 | --model_path BAAI/bge-base-en-v1.5 \
67 | --corpus_path indexes/wiki18.jsonl \
68 | --save_dir indexes/ \
69 | --use_fp16 \
70 | --max_length 512 \
71 | --batch_size 256 \
72 | --pooling_method mean \
73 | --faiss_type Flat
74 | ```
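Note that the raw dump must be converted into FlashRAG's JSONL corpus format (one document per line) to produce the `indexes/wiki18.jsonl` file referenced above; a representative line looks roughly like this (illustrative only; see the FlashRAG documentation for the exact preprocessing scripts):

```json
{"id": "0", "contents": "\"Anarchism\"\nAnarchism is a political philosophy and movement that questions the legitimacy of authority ..."}
```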
75 |
76 | Download the QA datasets from Hugging Face ([RUC-NLPIR/FlashRAG_datasets](https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets)) and place them under `dataset/`, for example:
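One way to fetch the benchmark splits into the expected `dataset/` directory (adjust the include patterns to the subsets you need):

```bash
huggingface-cli download RUC-NLPIR/FlashRAG_datasets --repo-type dataset \
    --local-dir dataset/ --include "popqa/*" "hotpotqa/*" "2wikimultihopqa/*"
```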
77 |
78 | ## Data Generation
79 | > Note: this step generates the policy preference data. You can use the released RAG-ProGuide dataset directly (linked above), run the commands below to regenerate it, or adapt them as needed.
80 | ```bash
81 | python data_generation.py --dataset_name popqa --model gpt-4o
82 | python data_generation.py --dataset_name hotpotqa --model gpt-4o
83 | python data_generation.py --dataset_name 2wikimultihopqa --model gpt-4o
84 | python preference_data_generation.py
85 | ```
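Each record produced by the preference data generation script is a DPO-style pair comparing two child actions of the same search-tree node; an abbreviated example of the output format (values are placeholders):

```json
{
  "instruction": "<system prompt of the compared action>",
  "input": "<question plus previous thoughts>",
  "chosen": "<higher-reward response>",
  "rejected": "<lower-reward response>",
  "chosen_action_type": "Query Generation",
  "rejected_action_type": "Answer Generation",
  "layer": 1
}
```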
86 |
87 | ## Training
88 | ```bash
89 | # Install LLaMA Factory
90 | git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
91 | cd LLaMA-Factory
92 | pip install -e ".[torch,metrics]"
93 |
94 | # Register the RAG_ProGuide dataset in LLaMA-Factory's dataset_info.json before preference optimization (see the example below)
95 | llamafactory-cli train training_config/qwen_dpo.yaml
96 | ```
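Registering the dataset means adding a `RAG_ProGuide` entry to LLaMA-Factory's `data/dataset_info.json` so the `dataset: RAG_ProGuide` line in `training_config/qwen_dpo.yaml` resolves. A plausible entry for the pairs produced by this repo (file name and column mapping are assumptions to adapt to your setup):

```json
"RAG_ProGuide": {
  "file_name": "RAG_ProGuide.json",
  "ranking": true,
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "chosen": "chosen",
    "rejected": "rejected"
  }
}
```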
97 |
98 | ## Inference
99 | ```bash
100 | python inference.py --dataset_name hotpotqa --model $MODEL_NAME
101 | ```
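For example, to evaluate the released checkpoint on HotpotQA with the default settings:

```bash
MODEL_NAME=reasonrag/Qwen2.5-7B-Instruct-ReasonRAG
python inference.py --dataset_name hotpotqa --model $MODEL_NAME --max_iter 8 --retrieval_top_k 3
```

Predictions, intermediate search trees, and EM/F1/accuracy scores are written under `output/` (the `save_dir` configured in `my_config.yaml`).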
102 |
--------------------------------------------------------------------------------
/data_generation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import time
4 | from multiprocessing import Process
5 | import copy
6 | import os
7 | from tqdm import tqdm
8 | import yaml
9 | from flashrag.config import Config
10 | from flashrag.utils import get_dataset
11 | from pipeline.reasonrag_pipeline import ReasonRAGPipeline
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--dataset_name", type=str)
15 | parser.add_argument("--model", type=str)
16 | args = parser.parse_args()
17 |
18 |
19 | root_dir = 'output'
20 |
21 | def load_config_from_yaml(yaml_file):
22 | try:
23 | with open(yaml_file, "r") as file:
24 | return yaml.safe_load(file)
25 | except Exception as e:
26 | print(f"Error loading YAML file: {e}")
27 | return {}
28 |
29 | default_config = load_config_from_yaml("my_config.yaml")
30 |
31 | config_dict = {
32 | "data_dir": "dataset/",
33 | "dataset_name": args.dataset_name,
34 | "index_path": "indexes/bge_Flat_wiki_extend.index",
35 | "corpus_path": "indexes/wiki18_100w_extend.jsonl",
36 | "model2path": {
37 | "bge": "BAAI/bge-base-en-v1.5",
38 | "qwen2.5": "Qwen/Qwen2.5-7B-Instruct",
39 | },
40 | "generator_model": "gpt-4o-2024-05-13",
41 | "retrieval_method": "bge",
42 | "framework": "openai",
43 | "gpu_id": "1",
44 | "faiss_gpu": True,
45 | "metrics": ["em", "f1", "acc", "recall", "precision"],
46 | "retrieval_topk": 3,
47 | "save_intermediate_data": True,
48 | "save_note": args.model + "_MCTS",
49 | }
50 |
51 | answer_format = "<answer>answer</answer>"  # reconstructed: the prompts ask the model to wrap final answers in <answer> tags
52 | max_iter = 8
53 |
54 | config_dict = {**default_config, **config_dict}
55 | config = Config(config_dict=config_dict)
56 |
57 | dataset_path = config["dataset_path"]
58 | split_path = os.path.join(dataset_path, "train.jsonl")
59 | data_split = "train"
60 | all_split = get_dataset(config)
61 | test_data = all_split[data_split]
62 |
63 | def save_data(save_dir, data, file_name="intermediate_data.json"):
64 | data = [item.to_dict() for item in data]
65 | save_path = os.path.join(save_dir, file_name)
66 | with open(save_path, "w", encoding="utf-8") as f:
67 | json.dump(data, f, indent=4)
68 |
69 |
70 | def parallel_process_dataset(config, test_data, num_processes=4):
71 | total_items = len(test_data)
72 | if total_items == 0:
73 | print("No data to process.")
74 | return
75 |
76 | chunk_size = total_items // num_processes
77 | remainder = total_items % num_processes
78 |
79 | chunks = []
80 | start = 0
81 | for i in range(num_processes):
82 | end = start + chunk_size
83 | if i < remainder:
84 | end += 1
85 | if start >= total_items:
86 | break
87 | chunks.append(test_data[start:end])
88 | start = end
89 |
90 | processes = []
91 | for chunk_idx, chunk in enumerate(chunks):
92 | print(f"Chunk {chunk_idx} ready.")
93 | p = Process(target=process_chunk, args=(copy.copy(config), chunk, chunk_idx))
94 | processes.append(p)
95 | p.start()
96 |
97 | for p in processes:
98 | p.join()
99 |
100 |
101 | def process_chunk(config, chunk, chunk_idx):
102 | save_dir = os.path.join(config["save_dir"], f"chunk_{chunk_idx}")
103 | os.makedirs(save_dir, exist_ok=True)
104 | config["save_dir"] = save_dir
105 |
106 | pipeline = ReasonRAGPipeline(config, prompt_template=None, max_iter=7, max_children=2, max_rollouts=64)
107 |
108 | i = 0
109 | start_time = time.time()
110 |
111 | for item in tqdm(chunk, desc=f"Chunk {chunk_idx}"):
112 | try:
113 | pipeline.search(item)
114 | except Exception as e:
115 | print(f"Chunk {chunk_idx} Error at item {i}: {e}")
116 | continue
117 | finally:
118 | i += 1
119 |
120 | save_data(save_dir, chunk, file_name=f"final_{chunk_idx}.json")
121 | print(f"Chunk {chunk_idx} processed {len(chunk)} items in {time.time() - start_time:.2f}s")
122 |
123 |
124 | if __name__ == "__main__":
125 |     parallel_process_dataset(config, test_data, num_processes=10)
--------------------------------------------------------------------------------
/images/main_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wlzhang2020/ReasonRAG/225e249a26fa0d1d2fd3c9bf60855d725b7cd435/images/main_results.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import yaml
4 | from flashrag.config import Config
5 | from flashrag.utils import get_dataset
6 | from pipeline.reasonrag_pipeline import ReasonRAGPipeline
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--dataset_name", type=str)
10 | parser.add_argument("--model", type=str)
11 | parser.add_argument("--max_iter", default=8, type=int)
12 | parser.add_argument("--retrieval_top_k", default=3, type=int)
13 | args = parser.parse_args()
14 |
15 | root_dir = 'output'
16 |
17 | def load_config_from_yaml(yaml_file):
18 | try:
19 | with open(yaml_file, "r") as file:
20 | return yaml.safe_load(file)
21 | except Exception as e:
22 | print(f"Error loading YAML file: {e}")
23 | return {}
24 |
25 | default_config = load_config_from_yaml("my_config.yaml")
26 |
27 | config_dict = {
28 | "data_dir": "dataset/",
29 | "dataset_name": args.dataset_name,
30 | "index_path": "indexes/bge_Flat.index",
31 | "retrieval_method": "bge",
32 | "corpus_path": "indexes/wiki18_100w.jsonl",
33 | "model2path": {
34 | "bge": "BAAI/bge-base-en-v1.5",
35 | "e5": "intfloat/e5-base-v2",
36 | "qwen2.5": "Qwen/Qwen2.5-7B",
37 | "qwen2.5-instruct": "Qwen/Qwen2.5-7B-Instruct",
38 | },
39 | "generator_model": args.model,
40 | "generator_batch_size": 1,
41 | "framework": "vllm",
42 | "gpu_id": "0, 1, 2, 3",
43 | "faiss_gpu": True,
44 | "retrieval_batch_size": 256,
45 | "gpu_memory_utilization": 0.5,
46 | "metrics": ["em", "f1", "acc", "recall", "precision"],
47 | "retrieval_topk": args.retrieval_top_k,
48 | "save_intermediate_data": True,
49 | "save_note": args.model + f"_iter{args.max_iter}",
50 | }
51 |
52 | answer_format = "<answer>answer</answer>"  # reconstructed: the prompts ask the model to wrap final answers in <answer> tags
53 | max_iter = 10
54 |
55 | config_dict = {**default_config, **config_dict}
56 | config = Config(config_dict=config_dict)
57 |
58 | dataset_path = config["dataset_path"]
59 | split_path = os.path.join(dataset_path, "test.jsonl")
60 | data_split = "test"
61 | if not os.path.exists(split_path):
62 | if os.path.exists(os.path.join(dataset_path, "dev.jsonl")):
63 | data_split = "dev"
64 | elif os.path.exists(os.path.join(dataset_path, "val.jsonl")):
65 | data_split = "val"
66 |     else:
67 |         raise FileNotFoundError(f"No test, dev, or val split found in {dataset_path}")
68 |
69 | all_split = get_dataset(config)
70 | test_data = all_split[data_split]
71 |
72 | pipeline = ReasonRAGPipeline(config, prompt_template=None, answer_format=answer_format, max_iter=args.max_iter, max_children=2, max_rollouts=64)
73 | output_dataset = pipeline.run(test_data, batch_size=1000, do_eval=True)
74 |
--------------------------------------------------------------------------------
/my_config.yaml:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------Global Paths------------------------------------------------#
2 | # Paths to models
3 | model2path:
4 | e5: "intfloat/e5-base-v2"
5 | bge: "BAAI/bge-base-en-v1.5"
6 | contriever: "facebook/contriever"
7 | llama2-7B-chat: "meta-llama/Llama-2-7b-chat-hf"
8 | llama2-7B: "meta-llama/Llama-2-7b-hf"
9 | llama2-13B: "meta-llama/Llama-2-13b-hf"
10 | llama2-13B-chat: "meta-llama/Llama-2-13b-chat-hf"
11 | llama3-8B-instruct: "meta-llama/Meta-Llama-3-8B-Instruct"
12 |   qwen2.5-7b-instruct: "Qwen/Qwen2.5-7B-Instruct"
13 |   qwen2.5-3b-instruct: "Qwen/Qwen2.5-3B-Instruct"
14 |
15 |
16 | # Pooling methods for each embedding model
17 | model2pooling:
18 | e5: "mean"
19 | bge: "cls"
20 | contriever: "mean"
21 | jina: 'mean'
22 | dpr: cls
23 |
24 | # Indexes path for retrieval models
25 | method2index:
26 | e5: ~
27 | bm25: ~
28 | bge: ~
29 | contriever: ~
30 |
31 | # ------------------------------------------------Environment Settings------------------------------------------------#
32 | # Directory paths for data and outputs
33 | data_dir: "dataset/"
34 | save_dir: "output/"
35 |
36 | gpu_id: "0"
37 | dataset_name: "nq" # name of the dataset in data_dir
38 | split: [ "test", "train", "dev", "val"] # dataset split to load (e.g. train,dev,test)
39 |
40 | # Sampling configurations for testing
41 | test_sample_num: ~ # number of samples to test (only work in dev/test split), if None, test all samples
42 | random_sample: False # whether to randomly sample the test samples
43 |
44 | # Seed for reproducibility
45 | seed: 2024
46 |
47 | # Whether save intermediate data
48 | save_intermediate_data: True
49 | save_note: 'experiment'
50 |
51 | # -------------------------------------------------Retrieval Settings------------------------------------------------#
52 | # If a model name is set, its path will be found in the global paths above
53 | retrieval_method: "bge" # name or path of the retrieval model.
54 | index_path: "indexes/bge_Flat.index" # set automatically if not provided.
55 | faiss_gpu: True # whether to use GPU to hold the index
56 | corpus_path: "indexes/wiki18_100w.jsonl" # path to the corpus in '.jsonl' format that stores the documents
57 |
58 | instruction: ~ # instruction for retrieval model
59 | retrieval_topk: 5 # number of retrieved documents
60 | retrieval_batch_size: 256 # batch size for retrieval
61 | retrieval_use_fp16: True # whether to use fp16 for retrieval model
62 | retrieval_query_max_length: 128 # max length of the query
63 | save_retrieval_cache: True # whether to save the retrieval cache
64 | use_retrieval_cache: False # whether to use the retrieval cache
65 | retrieval_cache_path: ~ # path to the retrieval cache
66 | retrieval_pooling_method: ~ # set automatically if not provided
67 |
68 | use_reranker: False # whether to use reranker
69 | rerank_model_name: ~ # same as retrieval_method
70 | rerank_model_path: ~ # path to the reranker model; will be found automatically in `retriever_model2path` if not set
71 | rerank_pooling_method: ~
72 | rerank_topk: 5 # number of remain documents after reranking
73 | rerank_max_length: 512
74 | rerank_batch_size: 256 # batch size for reranker
75 | rerank_use_fp16: True
76 |
77 | ## -------------------------------------------------Generator Settings------------------------------------------------#
78 | framework: vllm # inference framework of the LLM; supports 'hf', 'vllm', 'fschat'
79 | generator_model: "Qwen/Qwen2.5-7B-Instruct" # name or path of the generator model
80 | model_path: "Qwen/Qwen2.5-7B-Instruct"
81 | generator_max_input_len: 8192 # max length of the input
82 | generator_batch_size: 2 # batch size for generation; ignored for vllm
83 | generation_params:
84 | do_sample: False
85 | max_tokens: 256
86 | use_fid: False # whether to use FID, only valid in encoder-decoder model
87 |
88 | # -----------------------------------------OpenAI Generator Settings------------------------------------------------#
89 | #framework: openai # inference frame work of LLM, supporting: 'hf','vllm','fschat'
90 | #generator_model: "" # name or path of the generator model
91 | #generator_max_input_len: 16000 # max length of the input
92 | #generator_batch_size: 1 # batch size for generation, invalid for vllm
93 | #generation_params:
94 | # do_sample: False
95 | # max_tokens: 16000
96 | # temperature: 1
97 | #use_fid: False # whether to use FID, only valid in encoder-decoder model
98 |
99 | # -------------------------------------------------Refiner Settings------------------------------------------------#
100 | # If set, the refiner will be used to refine the retrieval documents.
101 | refiner_name: ~
102 | refiner_model_path: ~
103 |
104 | # Used for extractive method (e.g embedding models)
105 | refiner_topk: 5 # number of remain sentence after refiner
106 | refiner_pooling_method: 'mean' # pooling method of refiner model
107 | refiner_encode_max_length: 256
108 | # Used for abstractive method (e.g. generation models like bart-large-cnn)
109 | refiner_max_input_length: 1024
110 | refiner_max_output_length: 512
111 |
112 | # Specify settings for llmlingua
113 | llmlingua_config:
114 | rate: 0.55
115 | condition_in_question: 'after_condition'
116 | reorder_context: 'sort'
117 | dynamic_context_compression_ratio: 0.3
118 | condition_compare: True
119 | context_budget: "+100"
120 | rank_method: 'longllmlingua'
121 | sc_config:
122 | 'reduce_ratio': 0.5
123 |
124 | # -------------------------------------------------Evaluation Settings------------------------------------------------#
125 | # Metrics to evaluate the result
126 | metrics: ['em','f1','acc','precision','recall']
127 | # Specify setting for metric, will be called within certain metrics
128 | metric_setting:
129 | retrieval_recall_topk: 5
130 | tokenizer_name: Qwen/Qwen2.5-7B-Instruct
131 | save_metric_score: True # whether to save the metric score into txt file
132 |
133 |
134 |
135 |
--------------------------------------------------------------------------------
/pipeline/reasonrag_pipeline.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import asyncio
3 | import time
4 | import os
5 | import pickle
6 | import psutil
7 | from copy import deepcopy
9 | from multiprocessing import Pool
11 | from tqdm import tqdm
12 | import multiprocessing
13 | from functools import partial
14 | from concurrent.futures import ProcessPoolExecutor, as_completed
15 | from multiprocessing import Manager
16 | import re
17 | import math
19 | import numpy as np
20 | import logging
21 | import copy
22 | import queue
23 | from typing import List
24 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
25 | from flashrag.utils import get_retriever, get_generator, selfask_pred_parse, ircot_pred_parse
26 | from flashrag.pipeline import BasicPipeline
27 | from flashrag.dataset import get_batch_dataset, merge_batch_dataset, Dataset
28 | from flashrag.evaluator.metrics import F1_Score
29 | from flashrag.prompt import PromptTemplate
30 | from abc import ABC, ABCMeta, abstractmethod
31 | from typing import Any, Generic, TypeVar, Union
32 |
33 |
34 | def save_data(save_dir, data, file_name="intermediate_data.json"):
35 | """Save the evaluated data, including the raw data and the score of each data
36 | sample on each metric."""
37 | save_path = os.path.join(save_dir, file_name)
38 | data.save(save_path)
39 |
40 | class MCTSRAGNode:
41 | def __init__(self, index, S: Any, parent: MCTSRAGNode | None = None, parent_id: int = -1, step: int = 0,
42 | next_state: str | None = None, question="RAG question", thoughts=None, golden_answers=None,
43 | node_dict: dict[str, Any] | None = None):
44 | self.id = index
45 | self.S = S
46 | self.parent = parent
47 | self.parent_id = parent_id
48 | self.children = []
49 | self.children_ids = []
50 | self.N = 0
51 | self.step = step
52 | self.Q = 0.0
53 | self.Q_list = []
54 | self.reward_list = []
55 | self.next_state = next_state
56 | self.question = question
57 | self.thoughts = thoughts
58 | self.golden_answers = golden_answers
59 | self.node_dict = node_dict if node_dict is not None else {}
60 | self.max_tokens_reached = False
61 |
62 | def add_child(self, child_node: MCTSRAGNode):
63 | self.children.append(child_node)
64 | self.children_ids.append(child_node.id)
65 |
66 | def update_Q(self, reward: float):
67 | self.N += 1
68 | self.Q_list.append(reward)
69 | self.Q = np.mean(self.Q_list)
70 | self.node_dict['Q'] = self.Q
71 | self.node_dict['step'] = self.step
72 | self.node_dict['N'] = self.N
73 |
74 | def update_reward(self, reward: float):
75 | self.reward_list.append(reward)
76 | self.node_dict["reward"] = np.mean(self.reward_list)
77 |
78 | def is_fully_expanded(self, max_children) -> bool:
79 | return self.next_state is None or len(self.children) == max_children
80 |
81 |
82 | def extract_answer(pred, prefix="So the answer is"):
83 |     if prefix in pred:
84 |         pred = pred.split(prefix)[1].strip()
85 |     answer_matches = re.findall(r'<answer>(.*?)</answer>', pred)
86 |     pred = answer_matches[-1] if answer_matches else pred
87 |     # Strip any remaining special tags from the extracted answer.
88 |     pred = re.sub(r'</?(answer|query|evidence)>', '', pred)
89 |     return pred.strip()
90 |
91 |
92 | def extract_last_number(text):
93 |     # Reconstructed helper (the original body was lost during extraction):
94 |     # returns the last number in the evaluator response, e.g. the score in
95 |     # "So the score is 85", defaulting to 0 if no number is found.
96 |     matches = re.findall(r'\d+(?:\.\d+)?', text)
97 |     return matches[-1] if matches else 0
98 |
99 |
100 | def process_text(text):
101 |     # Normalize a generator response: wrap the content after "So the next
102 |     # query is" in <query> tags and the content after "So the answer is" in
103 |     # <answer> tags when the tags are missing. Parts of this function were
104 |     # lost during extraction and have been reconstructed from the surviving lines.
105 |     query_matches = re.finditer(r'So the next query is (.*?)(?= So the next query is|$)', text, re.DOTALL)
106 |     for next_match in query_matches:
107 |         next_content = next_match.group(1).strip()
108 |         if not re.search(r'<query>.*</query>', next_content):
109 |             next_content = f'<query>{next_content}</query>'
110 |         text = text.replace(next_match.group(1), next_content)
111 |
112 |     answer_matches = re.finditer(r'So the answer is (.*?)(?= So the answer is|$)', text, re.DOTALL)
113 |     for match in answer_matches:
114 |         answer_content = match.group(1).strip()
115 |         answer_content = re.sub(r'[\.\"]+$', '', answer_content)
116 |         if answer_content.startswith('“') and answer_content.endswith('”'):
117 |             answer_content = answer_content[1:-1]
118 |         if not re.search(r'<answer>.*</answer>', answer_content):
119 |             answer_content = f'<answer>{answer_content}</answer>'
120 |         text = text.replace(match.group(1), answer_content)
121 |
122 |     return text
153 |
154 | class ReasonRAGPipeline(BasicPipeline):
155 | BEGIN_REASONING_PROMPT = """You are an assistant for question answering with access to a retrieval tool. Upon receiving a question, your task is to:
156 | * Analyze and Decompose the Question: Break the question into smaller, manageable sub-questions to ensure all aspects are addressed.
157 | * Evaluate Your Knowledge: Assess each sub-question or component:
158 | - Identify parts you can confidently answer based on your existing knowledge.
159 | - Pinpoint parts that require additional information or verification through retrieval tools.
160 | * Conciseness: Ensure both queries and answers are concise, using nouns or short phrases whenever possible.
161 | * Response Format:
162 | If your knowledge is sufficient to answer the question, conclude with:
163 | "So the answer is {answer_format}"
164 | If retrieval is necessary to provide a complete answer, conclude with:
165 | "So the next query is <query>query</query>"
166 | """
167 |
168 | DOCUMENT_ANALYSIS_PROMPT = """You are an information retrieval assistant. Your task is to extract relevant evidence from the provided Wikipedia documents based on the latest query.
169 |
170 | Instructions:
171 |
172 | * Identify key terms or concepts in the query.
173 | * Search the documents for evidence that supports the query.
174 | * Response format:
175 | If relevant evidence is found, output:
176 | Based on the query, the relevant evidence is <evidence>evidence</evidence>.
177 | If no relevant evidence is found, output:
178 | <evidence>None</evidence>.
179 | """
180 |
181 | REASONING_PROMPT = """You are a question-answering assistant with access to a retrieval tool. Your goal is to provide a concise and accurate reasoning process.
182 | Instructions:
183 | * Error Reflection: If errors exist in previous thoughts, identify and correct them. Skip this step if no errors are present.
184 | * Information Sufficiency: Evaluate whether the current information is sufficient to fully and accurately answer the question. If additional retrieval is needed, deconstruct the question and generate the next query. Avoid repeating previous queries. If no meaningful new query can be generated, explain why and provide an answer based on the current information.
185 | * Conciseness: Ensure both queries and answers are concise, using nouns or short phrases whenever possible.
186 | * Conclusion:
187 | If generating an answer:
188 | "So the answer is {answer_format}".
189 | If more retrieval is needed:
190 | "So the next query is <query>query</query>".
191 | """
192 |
193 |     ANSWER_GENERATION_PROMPT = """You are a reasoning assistant with retrieval. Give a precise and very concise final answer to the given question, concluding with 'So the answer is {answer_format}'. Keep your final answer brief and to the point, without any further explanation.
194 | """
195 |
196 | EVALUATION_PROMPT = """An agent is tasked with answering a question using a retrieval tool.
197 | Critically assess its intermediate reasoning process to determine if it leads to the correct answer.
198 | Identify all flaws, inconsistencies, and mistakes in the thought process.
199 | Every imperfection, no matter how small, must be acknowledged.
200 | Evaluate how effectively the reasoning supports the final answer and the overall accuracy of the response.
201 | Ensure the evaluation is extremely harsh, leaving no leniency.
202 | Even if the answer seems close to correct, do not award full marks to maintain strict grading standards.
203 | Assign a score between [0, 100] based on the severity of flaws and the reasoning’s accuracy in leading to the golden answer.
204 | Respond briefly and conclude with: So the score is [Score].
205 | """
206 |
207 | def __init__(self, config, prompt_template=None, answer_format="answer", retriever=None, generator=None, max_iter=8, max_children=3,
208 | max_rollouts: int = 64, c: float = 1.414, default_uct_score: float = float("inf"), beta=0.95,
209 | batch_size=50, max_workers=50):
210 | self.begin_reasoning_prompt = PromptTemplate(
211 | config=config,
212 | system_prompt=f"{self.BEGIN_REASONING_PROMPT.format(answer_format=answer_format)}",
213 | user_prompt="Question: {question}",
214 | enable_chat=True
215 | )
216 |
217 | self.document_analysis_prompt = PromptTemplate(
218 | config=config,
219 | system_prompt=f"{self.DOCUMENT_ANALYSIS_PROMPT}",
220 | user_prompt="Question: {question}. Reference: {reference}",
221 | reference_template="Wikipedia Title: {title}\n{text}\n\n",
222 | enable_chat=True
223 | )
224 |
225 | self.reasoning_prompt = PromptTemplate(
226 | config=config,
227 | system_prompt=f"{self.REASONING_PROMPT.format(answer_format=answer_format)}",
228 | user_prompt="Question: {question}",
229 | enable_chat=True
230 | )
231 |
232 | self.evaluation_prompt = PromptTemplate(
233 | config=config,
234 | system_prompt=f"{self.EVALUATION_PROMPT}",
235 | user_prompt="Question: {question}",
236 | enable_chat=True
237 | )
238 |
239 | self.answer_generation_prompt = PromptTemplate(
240 | config=config,
241 | system_prompt=f"{self.ANSWER_GENERATION_PROMPT.format(answer_format=answer_format)}",
242 | user_prompt="Question: {question}",
243 | enable_chat=True
244 | )
245 |
246 | super().__init__(config, self.begin_reasoning_prompt)
247 | self.generator = get_generator(config) if generator is None else generator
248 | self.retriever = get_retriever(config) if retriever is None else retriever
249 | self.max_iter = max_iter
250 | self.max_children = max_children
251 | self.max_rollouts = max_rollouts
252 | self.c = c
253 | self.default_uct_score = default_uct_score
254 | self.index = 0
255 | self.beta = beta
256 |
257 | self.batch_size = batch_size
258 | self.max_workers = max_workers or os.cpu_count()
259 |
260 |         self.stop_tokens = ["<|im_end|>", "<|endoftext|>", "</answer>", "</query>", "</evidence>"]
261 |
262 | def initialize(self, item) -> MCTSRAGNode:
263 | self.index = 0
264 | thoughts = []
265 | question = item.question
266 | iter_num = 0
267 | next_state = "begin_reasoning"
268 | node = MCTSRAGNode(self.index, question, None, -1, step=iter_num,
269 | next_state=next_state, question=question,
270 | thoughts=thoughts, golden_answers=item.golden_answers)
271 | return node
272 |
273 | def search(self, item):
274 | root = self.initialize(item)
275 | for _ in range(self.max_rollouts):
276 | leaf = self.select(root)
277 | child = self.expand(leaf)
278 | Q, reward = self.simulate(child)
279 | self.backpropagate(child, Q, reward)
280 |
281 | self.tree2log(root, item)
282 |
283 | return root
284 |
285 | def tree2log(self, root, item):
286 | q = queue.Queue()
287 | q.put(root)
288 |
289 | while not q.empty():
290 | node = q.get()
291 | self.node2log(node, item)
292 |
293 | for child in node.children:
294 | q.put(child)
295 |
296 | def node2log(self, node, item):
297 | node_dict = node.node_dict
298 | node_dict["parent_id"] = node.parent_id
299 | node_dict["children_ids"] = node.children_ids
300 |
301 | item.update_output(
302 | f"intermediate_node_{node.id}",
303 | node_dict
304 | )
305 |
306 | def select(self, node: MCTSRAGNode) -> MCTSRAGNode:
307 | while not self.is_terminal(node):
308 | if not node.is_fully_expanded(self.max_children):
309 | return node
310 | node = self._select_child(node)
311 | return node
312 |
313 | def expand(self, node: MCTSRAGNode) -> MCTSRAGNode:
314 | if node.is_fully_expanded(self.max_children):
315 | return node
316 |
317 | child = self.get_next_state(node, True)
318 | return child
319 |
320 | def _select_child(self, node: MCTSRAGNode) -> MCTSRAGNode:
321 | return max(node.children, key=lambda child: self._uct(child))
322 |
323 | def _uct(self, node: MCTSRAGNode) -> float:
324 | if node.N == 0:
325 | return self.default_uct_score
326 | return node.Q + self.c * math.sqrt(math.log(node.parent.N) / node.N)
327 |
328 | def _best_child(self, node: MCTSRAGNode) -> MCTSRAGNode:
329 | return max(node.children, key=lambda child: child.N)
330 |
331 | def simulate(self, node):
332 | Q = self.evaluate_thoughts(node)
333 | reward = self.get_reward(node) if self.is_terminal(node) else None
334 | return Q, reward
335 |
336 | def is_terminal(self, node: MCTSRAGNode) -> bool:
337 | if node.next_state is None or node.max_tokens_reached:
338 | return True
339 |
340 | return False
341 |
342 | def _simulate_policy(self, node: MCTSRAGNode):
343 | action = node.next_state
344 | return action
345 |
346 | def get_next_state(self, parent_node: MCTSRAGNode, use_index=False):
347 | action = parent_node.next_state
348 | if action is None:
349 | return None
350 |
351 | response, node_dict = "", {}
352 | thoughts = copy.copy(parent_node.thoughts)
353 | if action == "begin_reasoning":
354 | response, node_dict = self.handle_begin_reasoning(parent_node.question, thoughts)
355 | elif action == "document_analysis":
356 | response, node_dict = self.handle_document_analysis(parent_node.question, thoughts)
357 | elif action == "reasoning":
358 | response, node_dict = self.handle_reasoning(parent_node.question, thoughts)
359 | elif action == "answer_generation":
360 | response, node_dict = self.handle_answer_generation(parent_node.question, thoughts)
361 |
362 | new_child_node_id = -1
363 | if use_index:
364 | self.index += 1
365 | new_child_node_id = self.index
366 |
367 | node_dict["id"] = new_child_node_id
368 | node_dict["parent_id"] = parent_node.id
369 | next_action_state_for_child = self.next_state(action, response, parent_node.step)
370 |
371 | child_node = MCTSRAGNode(
372 | new_child_node_id,
373 | parent_node.S + " " + response,
374 | parent_node,
375 | parent_node.id,
376 | parent_node.step + 1,
377 | next_action_state_for_child,
378 | parent_node.question,
379 | thoughts,
380 | parent_node.golden_answers,
381 | node_dict
382 | )
383 |
384 | if self.prompt_template.check_prompt_length(node_dict["input_prompt"]):
385 | child_node.max_tokens_reached = True
386 |
387 | if use_index:
388 | parent_node.add_child(child_node)
389 |
390 | return child_node
391 |
392 | def get_reward(self, node: MCTSRAGNode):
393 | pred = extract_answer(node.node_dict["response"])
394 | golden_answers = node.golden_answers
395 | evaluator = F1_Score(self.config)
396 | reward = evaluator.token_level_scores(pred, golden_answers)
397 | reward = reward['f1'] * self.beta ** node.step
398 | return reward
399 |
400 | def backpropagate(self, node, Q, reward):
401 | while node:
402 | node.update_Q(Q)
403 | if reward is not None:
404 | node.update_reward(reward)
405 | node = node.parent
406 |
407 | def next_state(self, current_action: str, response: str, iter_num: int):
408 | action_transitions = {
409 |             "begin_reasoning": lambda res: None if re.search(r'<answer>.*?</answer>', res, re.DOTALL) else "document_analysis",
410 |             "reasoning": lambda res: None if re.search(r'<answer>.*?</answer>', res, re.DOTALL) else "document_analysis",
413 | "document_analysis": lambda res: "reasoning",
414 | "answer_generation": lambda res: None,
415 | }
416 |
417 | # Retrieve the function for the current action
418 | transition_function = action_transitions.get(current_action)
419 |
420 | if iter_num == self.max_iter - 1:
421 | return "answer_generation"
422 |
423 | if iter_num < self.max_iter-1 and transition_function:
424 | return transition_function(response)
425 |
426 | return None
427 |
428 | def extract_answer(self, response: str) -> str:
429 | match = re.search(r'So the answer is\s*(.*?)(?=\n|$)', response, re.IGNORECASE | re.DOTALL)
430 | if not match:
431 | return ""
432 |
433 | text = match.group(1).strip()
434 | # Remove special tokens
435 |         text = re.sub(r'</?(answer|query|evidence)>', '', text)
436 | return text.strip()
437 |
438 | def extract_query(self, response: str) -> str:
439 | match = re.search(r'So the next query is\s*(.*?)(?=\n|$)', response, re.IGNORECASE | re.DOTALL)
440 | if not match:
441 | return ""
442 |
443 | text = match.group(1).strip()
444 | # Remove special tokens
445 |         text = re.sub(r'</?(answer|query|evidence)>', '', text)
446 | return text.strip()
447 |
448 |     def delete_tokens(self, response: str) -> str:
449 |         cleaned_response = re.sub(r'</?(answer|query|evidence)>', '', response, flags=re.DOTALL)  # pattern reconstructed; original partially lost during extraction
450 |         return cleaned_response
451 |
463 | def test_item(self, item):
464 | root = self.initialize(item)
465 | node = root
466 |
467 | while not self.is_terminal(node):
468 | next_node = self.get_next_state(node, use_index=True)
469 | node = next_node
470 |
471 | self.tree2log(root, item)
472 | if "answer" in node.node_dict and node.node_dict["answer"] is not None:
473 | item.update_output("pred", node.node_dict["answer"])
474 | else:
475 | item.update_output("pred", "none")
476 | return node
477 |
478 | def process_annotation(self, dataset, do_eval=True, pred_process_fun=extract_answer):
479 | for item in tqdm(dataset, desc="Inference: "):
480 | self.search(item)
481 |
482 | save_data(self.config["save_dir"], dataset)
483 | dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
484 | return dataset
485 |
486 | def run(self, dataset, do_eval=True, batch_size=16, pred_process_fun=None):
487 | all_dataset_list = []
488 | for batch_dataset in tqdm(get_batch_dataset(dataset, batch_size=batch_size), desc="Batch dataset: "):
489 | batch_dataset = self.run_batch(batch_dataset)
490 | all_dataset_list.append(batch_dataset)
491 | dataset = merge_batch_dataset(all_dataset_list)
492 | save_data(self.config["save_dir"], dataset)
493 | dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
494 | return dataset
495 |
496 | def get_flags(self, responses):
497 | flags = []
498 | for response in responses:
499 | if "So the next query is" in response:
500 | flag = "query"
501 | elif "So the answer is" in response:
502 | flag = "answer"
503 |             elif "<evidence>" in response:
504 | flag = "evidence"
505 | else:
506 | flag = "None"
507 | flags.append(flag)
508 | return flags
509 |
510 | def get_querys(self, responses):
511 | querys = []
512 | for response in responses:
513 | querys.append(self.extract_query(response))
514 | return querys
515 |
516 | def get_answers(self, responses):
517 | answers = []
518 | for response in responses:
519 | answers.append(self.extract_answer(response))
520 | return answers
521 |
522 | def all_finished(self, next_flags):
523 | return all(flag in ["answer", "finish"] for flag in next_flags)
524 |
525 | def run_batch(self, dataset):
526 | for item in dataset:
527 | item.update_output('finish_flag', False)
528 | item.update_output('iteration_count', 0)
529 | item.update_output('previous_thoughts', [])
530 | item.update_output('flag', None)
531 | item.update_output('query', None)
532 | item.update_output('answer', None)
533 |
534 | input_prompts = [self.begin_reasoning_prompt.get_string(question=item.question) for item in dataset]
535 | responses = self.generator.generate(input_prompts, stop=self.stop_tokens)
536 |
537 | for i, item in enumerate(dataset):
538 | response = responses[i]
539 | item.previous_thoughts.append(response)
540 | item.flag = self.get_flags([response])[0]
541 | item.query = self.get_querys([response])[0]
542 | item.answer = self.get_answers([response])[0]
543 |
544 | node_dict = {
545 | "action_name": "begin_reasoning",
546 | "input_prompt": input_prompts[i],
547 | "response": response,
548 | "query": item.query,
549 | "answer": item.answer,
550 | }
551 | item.update_output(f"intermediate_node_0", node_dict)
552 | if item.flag in ["finish", "answer"]:
553 | item.finish_flag = True
554 |
555 | for step in range(self.max_iter):
556 | exist_items = [item for item in dataset if not item.finish_flag]
557 | if not exist_items:
558 | break
559 |
560 | active_questions = [item.question for item in exist_items]
561 | active_previous_thoughts = [item.previous_thoughts for item in exist_items]
562 | active_flags = [item.flag for item in exist_items]
563 | active_querys = [item.query for item in exist_items]
564 |
565 | question_thoughts_list = [
566 | q + "\nPrevious Thoughts: " + " ".join(thoughts)
567 | for q, thoughts in zip(active_questions, active_previous_thoughts)
568 | ]
569 |
570 | retrieval_results = self.retriever.batch_search(active_querys)
571 |
572 | input_prompts = []
573 | for i, item in enumerate(exist_items):
574 | if item.iteration_count >= self.max_iter - 1:
575 | input_prompts.append(self.answer_generation_prompt.get_string(question=question_thoughts_list[i]))
576 | elif "query" in item.flag:
577 | input_prompts.append(self.document_analysis_prompt.get_string(
578 | question=question_thoughts_list[i], retrieval_result=retrieval_results[i]
579 | ))
580 | elif "evidence" in item.flag:
581 | input_prompts.append(self.reasoning_prompt.get_string(question=question_thoughts_list[i]))
582 | else:
583 | input_prompts.append(self.answer_generation_prompt.get_string(question=question_thoughts_list[i]))
584 |
585 | responses = self.generator.generate(input_prompts, stop=self.stop_tokens)
586 | for i, item in enumerate(exist_items):
587 | response = responses[i]
588 | item.previous_thoughts.append(response)
589 | item.iteration_count += 1
590 | item.flag = self.get_flags([response])[0]
591 | item.query = self.get_querys([response])[0]
592 | item.answer = self.get_answers([response])[0]
593 |
594 | node_dict = {
595 | "action_name": "unknown",
596 | "input_prompt": input_prompts[i],
597 | "response": response,
598 | "query": item.query,
599 | "answer": item.answer,
600 | }
601 | if "evidence" in item.flag:
602 | node_dict["action_name"] = "document_analysis"
603 | node_dict["retrieval_result"] = retrieval_results[i]
604 | elif "query" in item.flag:
605 | node_dict["action_name"] = "query_generation"
606 | elif "answer" in item.flag:
607 | node_dict["action_name"] = "answer_generation"
608 | elif "finish" in item.flag:
609 | node_dict["action_name"] = "finish"
610 |
611 | item.update_output(f"intermediate_node_{item.iteration_count}", node_dict)
612 |
613 | if item.flag in ["finish", "answer"] or item.iteration_count >= self.max_iter:
614 | item.finish_flag = True
615 |
616 | for i in range(len(dataset)):
617 | dataset[i].pred = dataset[i].answer
618 |
619 | return dataset
620 |
621 | def handle_begin_reasoning(self, question, thoughts):
622 | begin_reasoning = [self.begin_reasoning_prompt.get_string(question=question)]
623 | begin_response = self.generator.generate(begin_reasoning)[0]
624 | begin_response = process_text(begin_response)
625 | thoughts.append(begin_response)
626 |
627 |         query_matches = re.findall(r'<query>(.*?)</query>', begin_response)
628 | extracted_query = self.delete_tokens(query_matches[-1]) if query_matches else None
629 |
630 |         answer_matches = re.findall(r'<answer>(.*?)</answer>', begin_response)
631 | extracted_answer = self.delete_tokens(answer_matches[-1]) if answer_matches else None
632 |
633 | node_dict = {
634 | "action_name": "begin_reasoning",
635 | "input_prompt": begin_reasoning[0],
636 | "response": begin_response,
637 | "next_query": extracted_query,
638 | "answer": extracted_answer,
639 | }
640 |
641 | return begin_response, node_dict
642 |
643 | def handle_reasoning(self, question, thoughts):
644 | question_thoughts = question + "\nPrevious Thoughts: " + " ".join(thoughts)
645 | reasoning_prompt = [self.reasoning_prompt.get_string(question=question_thoughts)]
646 | response = self.generator.generate(reasoning_prompt)[0]
647 | response = process_text(response)
648 | thoughts.append(response)
649 |
650 |         query_matches = re.findall(r'<query>(.*?)</query>', response)
651 | extracted_query = self.delete_tokens(query_matches[-1]) if query_matches else None
652 |
653 |         answer_matches = re.findall(r'<answer>(.*?)</answer>', response)
654 | extracted_answer = self.delete_tokens(answer_matches[-1]) if answer_matches else None
655 |
656 | node_dict = {
657 | "action_name": "reasoning",
658 | "input_prompt": reasoning_prompt[0],
659 | "response": response,
660 | "next_query": extracted_query,
661 | "answer": extracted_answer,
662 | }
663 |
664 | return response, node_dict
665 |
666 | def handle_document_analysis(self, question, thoughts):
667 | extracted_query = self.delete_tokens(self.extract_query(thoughts[-1]))
668 | id2doc = {}
669 | doc2score = {}
670 | retrieval_result, scores = self.retriever.search(extracted_query, return_score=True)
671 | for doc_item, score in zip(retrieval_result, scores):
672 | id2doc[doc_item["id"]] = doc_item
673 | doc_id = doc_item["id"]
674 | if doc_id in doc2score:
675 | doc2score[doc_id] = max(doc2score[doc_id], score)
676 | else:
677 | doc2score[doc_id] = score
678 |
679 | sorted_doc_score = sorted(doc2score.items(), key=lambda x: x[1], reverse=False)
680 | sorted_doc_id = [t[0] for t in sorted_doc_score]
681 | retrieval_result = [id2doc[id] for id in sorted_doc_id]
682 |
683 | question_thoughts = question + "\nPrevious Thoughts: " + " ".join(thoughts)
684 | grounding_prompt = [self.document_analysis_prompt.get_string(
685 | question=question_thoughts, retrieval_result=retrieval_result
686 | )]
687 | response = self.generator.generate(grounding_prompt)[0]
688 | thoughts.append(response)
689 |
690 | node_dict = {
691 | "action_name": "document_analysis",
692 | "input_prompt": grounding_prompt[0],
693 | "query": extracted_query,
694 | "response": response,
695 | "retrieval_result": retrieval_result,
696 | }
697 |
698 | return response, node_dict
699 |
700 | def handle_answer_generation(self, question, thoughts):
701 | question_thoughts = question + "\nPrevious Thoughts: " + " ".join(thoughts)
702 | answer_generation = [self.answer_generation_prompt.get_string(question=question_thoughts)]
703 | answer = self.generator.generate(answer_generation)[0]
704 | answer = process_text(answer)
705 | thoughts.append(answer)
706 | pred = extract_answer(answer)
707 | node_dict = {
708 | "action_name": "answer_generation",
709 | "input_prompt": answer_generation[0],
710 | "response": answer,
711 | "pred": pred,
712 | }
713 |
714 | return answer, node_dict
715 |
716 | def evaluate_thoughts(self, node):
717 | question_thoughts = node.question + "\nGolden Answer: " + " or ".join(
718 | node.golden_answers) + "\nAgent Reasoning Process: " + " ".join(node.thoughts)
719 | evaluation_prompt = [self.evaluation_prompt.get_string(question=question_thoughts)]
720 | evaluation_response = self.generator.generate(evaluation_prompt)[0]
721 | Q = extract_last_number(evaluation_response)
722 | Q = float(Q) / 100
723 | return Q
724 |
--------------------------------------------------------------------------------
/preference_data_generatoin.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from collections import Counter, defaultdict, deque
4 | from typing import Dict, List
5 | import numpy as np
6 | import re
7 | import tiktoken
8 | import string
9 | from glob import glob
10 |
11 | BETA = 0.9
12 | ENCODER = tiktoken.get_encoding("cl100k_base")
13 |
14 | def count_tokens(text: str) -> int:
15 | return len(ENCODER.encode(text))
16 |
17 | def get_action_type(response: str) -> str:
18 |     if "<query>" in response:
19 |         return "Query Generation"
20 |     elif "<evidence>" in response:
21 |         return "Evidence Extraction"
22 |     elif "<answer>" in response:
23 |         return "Answer Generation"
24 |     return "Other"
25 |
26 | def get_content(response: str) -> str:
27 | for tag in ["answer", "query", "evidence"]:
28 |         match = re.search(rf"<{tag}>(.*?)</{tag}>", response, re.DOTALL)
29 | if match:
30 | return match.group(1).strip()
31 | return ""
32 |
33 | def get_action_type_and_content(response: str) -> tuple[str, str]:
34 | action_type = get_action_type(response)
35 | content = get_content(response)
36 | return action_type, content
37 |
38 | def process_node(data: Dict, node_id: int, results: List, action_type_counts: Dict[str, int], layer: int = 0):
39 | current_node = data[f"intermediate_node_{node_id}"]
40 | children_ids = current_node.get("children_ids", [])
41 | if len(children_ids) < 2:
42 | return
43 |
44 | valid_children = []
45 | prompt = ""
46 | system_prompt = ""
47 | user_prompt = ""
48 | for child_id in children_ids:
49 | child_node = data.get(f"intermediate_node_{child_id}")
50 | if child_node and "reward" in child_node and child_node["reward"] is not None and "input_prompt" in child_node:
51 | valid_children.append(child_node)
52 | input_prompt = child_node["input_prompt"]
53 | system_prompt = input_prompt[0]["content"]
54 | user_prompt = input_prompt[1]["content"]
55 | prompt = f"System: {system_prompt}\nUser: {user_prompt}"
56 |
57 | for i in range(len(valid_children)):
58 | for j in range(i + 1, len(valid_children)):
59 | node_a, node_b = valid_children[i], valid_children[j]
60 | reward_a, reward_b = float(node_a["reward"]), float(node_b["reward"])
61 | if abs(reward_a - reward_b) < 0.01:
62 | continue
63 |
64 | chosen, rejected = (node_a, node_b) if reward_a > reward_b else (node_b, node_a)
65 | chosen_action_type, chosen_content = get_action_type_and_content(chosen["response"])
66 | rejected_action_type, rejected_content = get_action_type_and_content(rejected["response"])
67 |
68 | if chosen_content == rejected_content or chosen_action_type == "Other":
69 | continue
70 |
71 | action_type_counts[chosen_action_type] += 1
72 | results.append({
73 | "instruction": system_prompt,
74 | "input": user_prompt,
75 | "prompt": prompt,
76 | "chosen": chosen["response"],
77 | "rejected": rejected["response"],
78 | "chosen_action_type": chosen_action_type,
79 | "chosen_content": chosen_content,
80 | "rejected_action_type": rejected_action_type,
81 | "rejected_content": rejected_content,
82 | "layer": layer
83 | })
84 |
85 | for child in valid_children:
86 | process_node(data, child["id"], results, action_type_counts, layer + 1)
87 |
88 | def process_json_files(folder_list: List[str], output_file: str = "dpo_data.json"):
89 | all_results = []
90 | action_type_counts = defaultdict(int)
91 |
92 | for folder in folder_list:
93 | pattern = os.path.join(folder, "chunk_*", "reward_*.json")
94 | for json_path in glob(pattern, recursive=True):
95 | try:
96 | with open(json_path, "r") as f:
97 | items = json.load(f)
98 | except Exception:
99 | continue
100 |
101 | for item in items:
102 | if "output" not in item or "golden_answers" not in item or item["golden_answers"] is None:
103 | continue
104 |
105 | output_data = item["output"]
106 | root_node = output_data.get("intermediate_node_0")
107 | if not root_node or root_node.get("reward") is None:
108 | continue
109 |
110 | process_node(output_data, 0, all_results, action_type_counts)
111 |
112 | with open(output_file, "w") as f:
113 | json.dump(all_results, f, indent=2)
114 |
115 | def normalize_answer(s: str) -> str:
116 | def white_space_fix(text):
117 | return " ".join(text.split())
118 | def remove_punc(text):
119 | exclude = set(string.punctuation)
120 | return "".join(ch for ch in text if ch not in exclude)
121 | def lower(text):
122 | return text.lower()
123 | return white_space_fix(remove_punc(lower(s))).strip()
124 |
125 | def compute_f1_single(prediction: str, ground_truth: str) -> float:
126 | pred, gt = normalize_answer(prediction), normalize_answer(ground_truth)
127 | if pred in ["yes", "no", "noanswer"] and pred != gt:
128 | return 0.0
129 | if gt in ["yes", "no", "noanswer"] and pred != gt:
130 | return 0.0
131 |
132 | pred_tokens, gt_tokens = pred.split(), gt.split()
133 | common = sum((Counter(pred_tokens) & Counter(gt_tokens)).values())
134 | if common == 0:
135 | return 0.0
136 | precision = common / len(pred_tokens) if pred_tokens else 0.0
137 | recall = common / len(gt_tokens) if gt_tokens else 0.0
138 | return (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
139 |
140 | def compute_f1_multiple_gt(prediction: str, ground_truths: List[str]) -> float:
141 | return max(compute_f1_single(prediction, gt) for gt in ground_truths)
142 |
143 | def calculate_mcts_reward_bottom_up(data: Dict, golden_answers: List[str]) -> Dict:
144 | unprocessed_children_count = {
145 | k: len(data[k].get("children_ids", [])) for k in data if "intermediate_node_" in k
146 | }
147 | queue = deque(k for k, count in unprocessed_children_count.items() if count == 0)
148 | rewards = {}
149 | updated_nodes = data.copy()
150 |
151 | while queue:
152 | node_key = queue.popleft()
153 | node = updated_nodes[node_key]
154 |
155 | if not node.get("children_ids"):
156 | reward = None
157 | if "answer" in node and node.get("response") and node["answer"] is not None:
158 | f1 = compute_f1_multiple_gt(node["answer"], golden_answers)
159 | step = node.get("step", 0)
160 | reward = f1 * (BETA ** step)
161 | else:
162 | children_keys = [f"intermediate_node_{cid}" for cid in node["children_ids"]]
163 | valid_children = [ck for ck in children_keys if ck in rewards and rewards[ck] is not None]
164 | reward = None
165 | if valid_children:
166 | total_n = sum(updated_nodes[ck].get("N", 0) for ck in valid_children)
167 | if total_n > 0:
168 | reward = sum(rewards[ck] * updated_nodes[ck].get("N", 0) for ck in valid_children) / total_n
169 |
170 | rewards[node_key] = reward
171 | if reward is not None:
172 | updated_nodes[node_key]["reward"] = reward
173 | elif "reward" in updated_nodes[node_key]:
174 | del updated_nodes[node_key]["reward"]
175 |
176 | parent_id = node.get("parent_id")
177 | if parent_id is not None:
178 | parent_key = f"intermediate_node_{parent_id}"
179 | if parent_key in updated_nodes:
180 | unprocessed_children_count[parent_key] -= 1
181 | if unprocessed_children_count[parent_key] == 0:
182 | queue.append(parent_key)
183 |
184 | return updated_nodes
185 |
186 | def process_json_file_and_calculate_reward_save(json_path: str):
187 | try:
188 | with open(json_path, "r") as f:
189 | items = json.load(f)
190 | updated_items = []
191 | for item in items:
192 | if "output" not in item or item.get("golden_answers") is None:
193 | updated_items.append(item)
194 | continue
195 |
196 | mcts_data = item["output"]
197 | if "intermediate_node_0" not in mcts_data:
198 | updated_items.append(item)
199 | continue
200 |
201 | updated_nodes = calculate_mcts_reward_bottom_up(mcts_data, item["golden_answers"])
202 | modified_item = item.copy()
203 | modified_item["output"] = updated_nodes
204 | updated_items.append(modified_item)
205 |
206 | chunk_folder = os.path.dirname(json_path)
207 | base_name = os.path.splitext(os.path.basename(json_path))[0]
208 | sequence_number = base_name.split("_")[-1]
209 | reward_output_path = os.path.join(chunk_folder, f"reward_{sequence_number}.json")
210 | with open(reward_output_path, "w") as f:
211 | json.dump(updated_items, f, indent=2)
212 |
213 | except Exception:
214 | pass
215 |
216 | def process_folder_and_calculate_rewards_save(folder_path: str):
217 | pattern = os.path.join(folder_path, "chunk_*", "intermediate_*.json")
218 | for json_path in glob(pattern, recursive=True):
219 | process_json_file_and_calculate_reward_save(json_path)
220 |
221 | def main():
222 | folder_list = ["output/2wikimultihopqa_2025_02_04_01_49_experiment"]
223 | for folder in folder_list:
224 | process_folder_and_calculate_rewards_save(folder)
225 | process_json_files(folder_list, output_file="training_data/RAG_ProGuide.json")
226 |
227 | if __name__ == "__main__":
228 | main()
229 |
--------------------------------------------------------------------------------
/training_config/qwen_dpo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen2.5-7B-Instruct
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: dpo
7 | do_train: true
8 | finetuning_type: full
9 | pref_beta: 0.3
10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
11 |
12 | ### dataset
13 | dataset: RAG_ProGuide
14 | template: qwen
15 | cutoff_len: 10000
16 | max_samples: 1000000
17 | overwrite_cache: true
18 | preprocessing_num_workers: 8
19 |
20 | ### output
21 | output_dir: saves/qwen2.5-7b-instruct/full/dpo
22 | logging_steps: 10
23 | save_steps: 500
24 | plot_loss: true
25 | overwrite_output_dir: true
26 |
27 | ### train
28 | per_device_train_batch_size: 1
29 | gradient_accumulation_steps: 2
30 | learning_rate: 1.0e-6
31 | num_train_epochs: 1.0
32 | lr_scheduler_type: cosine
33 | warmup_ratio: 0.2
34 | bf16: true
35 | ddp_timeout: 180000000
36 |
37 | ### eval
38 | val_size: 0.1
39 | per_device_eval_batch_size: 1
40 | eval_strategy: steps
41 | eval_steps: 500
42 |
43 |
--------------------------------------------------------------------------------