├── requirements.txt
├── img
├── fig-intro.png
└── tab-main.png
├── src
├── RAMP
│ ├── model.py
│ └── gen_qa.py
└── multi_agent
│ ├── model.py
│ ├── dataset.py
│ ├── web_news_get.py
│ ├── utils.py
│ ├── cot_construct.py
│ └── prompt.py
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | dashscope
2 | openai
3 | tqdm
4 | hanziconv
5 | transformers
--------------------------------------------------------------------------------
/img/fig-intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-NLP/MaskSearch/HEAD/img/fig-intro.png
--------------------------------------------------------------------------------
/img/tab-main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-NLP/MaskSearch/HEAD/img/tab-main.png
--------------------------------------------------------------------------------
/src/RAMP/model.py:
--------------------------------------------------------------------------------
1 | from dashscope import Generation
2 | from http import HTTPStatus
3 | import time
4 | import dashscope
5 |
6 | DASHSCOPE_API_KEY = "YOUR_API_KEY"
7 |
def call_with_messages(model, messages, max_retries=None, retry_delay=1):
    """Send a chat request through ``Generation.call`` and return the reply text.

    Args:
        model: Model name forwarded to ``Generation.call``.
        messages: Chat messages forwarded to ``Generation.call``.
        max_retries: Maximum number of retries after a failed call. ``None``
            (the default) retries forever, which is the historical behavior.
        retry_delay: Seconds to sleep between attempts (default 1).

    Returns:
        The content of the first choice on success, or ``None`` when the
        service answers with HTTP 400 or the retry budget is exhausted.
    """
    failures = 0
    while True:
        response = Generation.call(
            model=model,
            messages=messages,
            result_format='message',
            api_key=DASHSCOPE_API_KEY,
        )
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0].message.content

        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        # A 400 is not retried: the caller gets None straight away.
        if response.status_code == HTTPStatus.BAD_REQUEST:
            return None

        failures += 1
        if max_retries is not None and failures > max_retries:
            return None
        time.sleep(retry_delay)
28 |
29 |
--------------------------------------------------------------------------------
/src/multi_agent/model.py:
--------------------------------------------------------------------------------
1 | from dashscope import Generation
2 | from http import HTTPStatus
3 | import time
4 | import dashscope
5 |
6 | DASHSCOPE_API_KEY = "YOUR_API_KEY"
7 |
def call_with_messages(model, messages, max_retries=None, retry_delay=1):
    """Send a chat request through ``Generation.call`` and return the reply text.

    Args:
        model: Model name forwarded to ``Generation.call``.
        messages: Chat messages forwarded to ``Generation.call``.
        max_retries: Maximum number of retries after a failed call. ``None``
            (the default) retries forever, which is the historical behavior.
        retry_delay: Seconds to sleep between attempts (default 1).

    Returns:
        The content of the first choice on success, or ``None`` when the
        service answers with HTTP 400 or the retry budget is exhausted.
    """
    failures = 0
    while True:
        response = Generation.call(
            model=model,
            messages=messages,
            result_format='message',
            api_key=DASHSCOPE_API_KEY,
        )
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0].message.content

        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        # A 400 is not retried: the caller gets None straight away.
        if response.status_code == HTTPStatus.BAD_REQUEST:
            return None

        failures += 1
        if max_retries is not None and failures > max_retries:
            return None
        time.sleep(retry_delay)
28 |
29 |
--------------------------------------------------------------------------------
/src/multi_agent/dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import pandas as pd
4 |
5 |
class BaseDataset:
    """List-backed container giving datasets length, indexing and iteration."""

    def __init__(self):
        # Records live in a plain list; subclasses replace or fill it.
        self.data = list()

    def __getitem__(self, index):
        """Return the record (or slice of records) at ``index``."""
        records = self.data
        return records[index]

    def __len__(self):
        """Return the number of stored records."""
        records = self.data
        return len(records)

    def __iter__(self):
        """Iterate over the stored records in order."""
        records = self.data
        return iter(records)
18 |
19 |
20 |
class HotpotDataset(BaseDataset):
    """Question/answer dataset loaded from a JSON file of records.

    Each input record must provide ``question`` and ``answer`` keys.

    Args:
        data_path: Path of the JSON file. Defaults to ``""`` as before, so the
            path can still be configured by editing this default.
    """

    def __init__(self, data_path=""):
        super().__init__()
        self.data_path = data_path
        self.data = self.load_data()

    def load_data(self):
        """Read ``self.data_path`` and return the normalized record list.

        Returns:
            A list of dicts with ``query``, ``answer`` and an ``ext`` dict that
            carries the answer again under ``golden_answer``.

        Raises:
            FileNotFoundError: If no path is configured or the file is missing.
        """
        if not self.data_path:
            # Same exception type open("") would raise, with an actionable message.
            raise FileNotFoundError(
                "HotpotDataset.data_path is empty; pass data_path=... or set it in dataset.py"
            )

        with open(self.data_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)

        return [
            {
                'query': item['question'],
                'answer': item['answer'],
                'ext': {
                    'golden_answer': item['answer'],
                },
            }
            for item in json_data
        ]
45 |
46 |
47 |
--------------------------------------------------------------------------------
/src/multi_agent/web_news_get.py:
--------------------------------------------------------------------------------
1 | from serpapi import GoogleSearch
2 | import time
3 |
4 | GOOGLE_API_KEY = "your_api_key"
5 | retry_attempt = 10
def google(text):
    """Run a Google search via ``GoogleSearch`` and return formatted results.

    Args:
        text: The search query.

    Returns:
        A list of strings, one per organic result, each built as a quoted
        ``title`` + ``\\n`` + ``snippet``. Returns ``[]`` when every attempt
        fails.
    """
    params = {
        "engine": "google",
        "q": text,
        "api_key": GOOGLE_API_KEY,
        "num": 10,
    }

    for i in range(retry_attempt):
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            organic_results = results.get("organic_results", [])
            # Build the list fresh on every attempt so a failed attempt can
            # never leave partial (and later duplicated) entries behind, and
            # tolerate results without a title or snippet rather than
            # raising and wasting the remaining retries.
            return [
                '\"' + (doc.get('title') or '') + '\\n' + (doc.get("snippet") or '') + '\"'
                for doc in organic_results
            ]
        except Exception as e:
            print(f"Attempt {i+1} failed: {e}")
            if i < retry_attempt - 1:
                time.sleep(2)
            else:
                print("All retries failed.")
    return []
31 |
32 |
def merge_news_insert(lists, num=10):
    """Interleave several result lists round-robin, keeping at most ``num`` items.

    The first item of every list is taken in list order, then the second item
    of every list, and so on; lists that run out are skipped.
    """
    merged = []
    deepest = max((len(items) for items in lists), default=0)

    for depth in range(deepest):
        for items in lists:
            # Stop as soon as the quota is reached (also covers num <= 0).
            if len(merged) >= num:
                return merged
            if depth < len(items):
                merged.append(items[depth])

    return merged
--------------------------------------------------------------------------------
/src/multi_agent/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import random
3 | from datetime import datetime, timedelta
4 |
5 |
def get_random_date():
    """Return a random date from the last 30 days as a formatted string.

    The result looks like ``"2024 year 05 month 17 day,Friday"``.
    """
    day_names = ("Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday")

    window_start = datetime.now() - timedelta(days=30)
    offset_days = random.randint(0, 30)
    picked = window_start + timedelta(days=offset_days)

    date_part = picked.strftime("%Y year %m month %d day,")
    return date_part + day_names[picked.weekday()]
16 |
17 |
18 |
19 | knowledge_prompt = """:
20 |
21 | {content}
22 |
23 | """
24 |
def init_data(query, date):
    """Build the initial chatml record for one query.

    The record holds a system turn stamped with ``date``, the user turn with
    ``query``, and the declaration of the ``web_search`` function.
    """
    system_turn = {
        "role": "system",
        "content": f"You are a helpful assistant.\n\n current time:{date}",
    }
    user_turn = {
        "content": query,
        "role": "user",
    }

    # Schema of the single "queries" argument accepted by web_search.
    queries_schema = {
        "type": "array",
        "items": {
            "type": "string",
            "description": "The search query.",
        },
        "description": "The list of search queries.",
    }
    web_search_tool = {
        "name_for_human": "web_search",
        "name_for_model": "web_search",
        "description_for_model": "Utilize the web search engine to retrieve relevant information based on multiple queries.",
        "parameters": {
            "type": "object",
            "properties": {"queries": queries_schema},
            "required": ["queries"],
        },
    }

    return {
        "type": "chatml",
        "messages": [system_turn, user_turn],
        "functions": [web_search_tool],
    }
64 |
def formate_data(data, extra_data, action):
    """Record one agent step in ``data["messages"]`` and return ``data``.

    ``action`` is matched case-insensitively:
      * ``thought``     - append an assistant turn with an unfilled web_search call
      * ``rewrite``     - fill the arguments of the last turn's function call
      * ``observation`` - append the web_search function result
      * ``finish``      - append the final assistant answer
    Any other action leaves ``data`` untouched.
    """
    step = action.lower()

    if step == "rewrite":
        # Completes the call opened by the preceding "thought" turn.
        pending_call = data["messages"][-1]["function_call"]
        pending_call["arguments"] = extra_data
        return data

    if step == "thought":
        new_message = {
            "role": "assistant",
            "content": extra_data,
            "function_call": {
                "name": "web_search",
                "arguments": None,
            },
        }
    elif step == "observation":
        new_message = {
            "role": "function",
            "name": "web_search",
            "content": extra_data,
        }
    elif step == "finish":
        new_message = {
            "content": extra_data,
            "finish_reason": "stop",
            "function_call": {},
            "response_role": "assistant",
            "role": "assistant",
        }
    else:
        return data

    data["messages"].append(new_message)
    return data
97 |
98 |
def formate_check(messages):
    """Return True when no message in the conversation is malformed.

    A message is malformed when it carries a non-final function call whose
    arguments are still ``None``, or when its content contains the raw
    ``Act: {`` or ``Observation:`` markers.
    """

    def _is_malformed(message):
        unresolved_call = False
        if 'function_call' in message and 'finish_reason' not in message:
            unresolved_call = message['function_call']['arguments'] is None

        leaked_marker = 'Act: {' in message['content'] or 'Observation:' in message['content']
        return unresolved_call or leaked_marker

    # A list (not a generator) so every message is inspected, as before.
    verdicts = [_is_malformed(message) for message in messages]
    return not any(verdicts)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MaskSearch: A Universal Pre-Training Framework to Enhance Agentic Search Capability
2 |
3 | [PyTorch](https://pytorch.org/) [arXiv:2505.20285](https://arxiv.org/abs/2505.20285)
4 |
5 | ## 🚀 Introduction
6 |
7 | - We propose **MaskSearch**, a novel pre-training framework to further enhance the **universal search capability of agents**.
8 | - We introduce the **Retrieval Augmented Mask Prediction (RAMP) task**, where the model learns to leverage search tools to fill masked spans on a large number of pre-training data, thus acquiring universal retrieval and reasoning capabilities for LLMs.
9 | - We combine agent-based and distillation-based methods to generate training data, starting with a multi-agent system consisting of a planner, rewriter, observer, and followed by a self-evolving teacher model.
10 | - Extensive experiments demonstrate that **MaskSearch** significantly enhances the performance of LLM-based search agents on both in-domain and out-of-domain downstream tasks.
11 |
12 | ![MaskSearch overview](img/fig-intro.png)
13 |
14 |
15 | ## 💡 Performance
16 | ![Main results](img/tab-main.png)
17 |
18 | ## 🛠 Running MaskSearch
19 |
20 | Before running, replace the placeholder API keys with your own: the DashScope (Qwen) key in `src/RAMP/model.py` and `src/multi_agent/model.py`, and the Google search (SerpApi) key in `src/multi_agent/web_news_get.py`.
21 | ```python
22 | DASHSCOPE_API_KEY = "YOUR_API_KEY"
23 | GOOGLE_API_KEY = "YOUR_API_KEY"
24 | ```
25 |
26 | Dependencies
27 |
28 | ```bash
29 | pip install -r requirements.txt
30 | ```
31 |
32 | ### Step 1. Generate RAMP QA through Wikipedia
33 | The first step is to generate RAMP QA data using Wikipedia as the data source.
34 |
35 | The Wikipedia data can be downloaded from [here](https://dumps.wikimedia.org/enwiki/).
36 |
37 | ```bash
38 | python gen_qa.py \
39 | --model "$model" \
40 | --corpus "Wikipedia Directory"\
41 | --output_path "output_path"
42 | ```
43 |
44 | ### Step 2. CoT Trajectory Construction
45 | The second step is to generate CoT trajectories for QA through a Multi Agent approach to construct SFT data.
46 |
47 | You can customize your own dataset and configure the data path in `src/multi_agent/dataset.py`
48 | ```bash
49 | python cot_construct.py \
50 | --model "$model" \
51 | --dataset "dataset"\
52 | --output_path "output_path"
53 | ```
54 |
55 | ### Step 3. Training with SFT/RL
56 | After generating the data, the third step is to use the data for training. For SFT, you can refer to the training process of [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory); for RL, you can refer to [Search-R1](https://github.com/PeterGriffinJin/Search-R1) and [ZeroSearch](https://github.com/Alibaba-NLP/ZeroSearch).
57 |
58 |
59 | ## 🙏 Acknowledgements
60 | This work is implemented based on [ChineseWiki](https://github.com/mattzheng/ChineseWiki), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), [Search-R1](https://github.com/PeterGriffinJin/Search-R1), and [verl](https://github.com/volcengine/verl). We greatly appreciate their valuable contributions to the community.
61 |
62 | ## 📝 Citation
63 |
64 | ```bibtex
65 | @article{wu2025masksearchuniversalpretrainingframework,
66 | title={MaskSearch: A Universal Pre-Training Framework to Enhance Agentic Search Capability},
67 | author={Weiqi Wu and Xin Guan and Shen Huang and Yong Jiang and Pengjun Xie and Fei Huang and Jiuxin Cao and Hai Zhao and Jingren Zhou},
68 | year={2025},
69 | eprint={2505.20285},
70 | archivePrefix={arXiv},
71 | primaryClass={cs.CL},
72 | url={https://arxiv.org/abs/2505.20285},
73 | }
74 | ```
--------------------------------------------------------------------------------
/src/multi_agent/cot_construct.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from tqdm import tqdm
4 | from model import call_with_messages
5 | from web_news_get import merge_news_insert, google
6 | from dataset import HotpotDataset
7 | from utils import init_data, formate_data, knowledge_prompt, get_random_date, formate_check
8 | from prompt import Planner_Agent, Searcher_Agent, Observation_Agent
9 | import concurrent.futures
10 |
11 |
def prediect_check(question, answer, model_response ,args):
    """Ask the model to judge ``model_response`` against the reference answer.

    Args:
        question: The original question.
        answer: The reference (correct) answer.
        model_response: The answer produced by the pipeline.
        args: Parsed CLI arguments; only ``args.model`` is read.

    Returns:
        False when the judge's reply contains "incorrect" (any casing) or when
        no reply could be obtained; True otherwise.
    """
    prompt = [{'role': 'system', 'content': 'You are a helpful assistant.'},
              {'role': 'user', 'content': f'Given the correct answer to a question, determine if the model\'s response is correct. If correct, output "correct"; if incorrect, output "incorrect". Do not include unrelated content.\nQuestion: {question}\nCorrect Answer: {answer}\nModel Response: {model_response}'}]

    response = call_with_messages(args.model, prompt)

    # call_with_messages returns None on a rejected request; an unjudged
    # answer must not be counted as correct (and must not raise TypeError).
    if response is None:
        return False

    return 'incorrect' not in response.lower()
23 |
24 |
def handle_thought_response(response, args):
    """
    Classify a model reply and extract its thought text.

    return:
        next_action, current_thought
    """
    out_put = response

    # NOTE(review): the marker literal below is an empty string in this copy of
    # the file, so this condition is true for every reply and the branches
    # after it are unreachable. It looks like the final-answer marker text was
    # lost; restore the intended marker here and in the fallback split below
    # (splitting on '' raises ValueError) - TODO confirm against the prompts.
    if '' in out_put:
        try:
            return 'finish', out_put.split('Thought:')[1].strip()
        except IndexError:
            # No 'Thought:' prefix present; fall back to the marker split.
            return 'finish', '' + out_put.split('')[1].strip()

    elif 'Thought:' in out_put:
        return 'thought', out_put.split('Thought:')[1].strip()

    else:
        return 'thought', out_put
44 |
45 |
def _parse_queries(raw):
    """Parse the searcher's ``{"queries": [...]}`` reply without executing it."""
    import ast  # local import keeps this block self-contained

    try:
        parsed = json.loads(raw)
    except ValueError:
        # Fall back for replies written as a Python literal (single quotes).
        parsed = ast.literal_eval(raw)
    return parsed['queries']


def construct_data(data, args):
    """Build one chain-of-thought search trajectory for a dataset item.

    The loop alternates between planning ('thought'), query rewriting
    ('rewrite'), web search ('observation') and a final answer ('finish').

    Args:
        data: Dataset item with ``query``, ``answer`` and ``ext`` keys.
        args: Parsed CLI arguments; only ``args.model`` is read here.

    Returns:
        ``(formated_data, status)``. ``status`` is True only when the
        trajectory reached 'finish'; it is False on errors, on an 'error'
        action, or after more than 4 observation turns.
    """
    current_date = get_random_date()

    system_prompt = {'role': 'system', 'content': f'You are a helpful assistant. current time: {current_date}'}

    formated_data = init_data(data['query'], current_date)
    Thought_list = []
    count = 0  # number of completed observation turns

    try:
        # The first planner call sits inside the try so that a failed request
        # is reported as a failed item instead of aborting the whole run.
        first_plan = Planner_Agent.replace('{input}', data['query'])
        data_prompt = [system_prompt, {'role': 'user', 'content': first_plan}]

        response = call_with_messages(args.model, data_prompt)
        action, out_data = handle_thought_response(response, args)
        Thought_list.append(out_data)
        formated_data = formate_data(formated_data, out_data, action)

        while True:
            if action == 'thought':
                data_prompt = [system_prompt, {'role': 'user', 'content': Searcher_Agent.replace('{input}', Thought_list[-1])}]
                response = call_with_messages(args.model, data_prompt)

                # Keep only the outermost {...} span of the reply.
                start_index = response.find("{")
                end_index = response.rfind("}") + 1
                response = response[start_index:end_index]

                Thought_list[-1] = Thought_list[-1] + '\n' + response

                action = 'rewrite'
                formated_data = formate_data(formated_data, response, action)
                # Model output is untrusted: parse it as data, never eval() it.
                queries = _parse_queries(response)
                if len(queries) == 0:
                    print('rewrite fail')
                    raise ValueError('rewrite fail')

            elif action == 'rewrite':
                news_list = [google(query) for query in queries]
                news_list = merge_news_insert(news_list, 10)

                if len(news_list) == 0:
                    print('search fail')
                    raise ValueError('search fail')

                observation_data = knowledge_prompt.format(content="\n\n".join(news_list))
                action = 'observation'
                formated_data = formate_data(formated_data, observation_data, action)

            elif action == 'observation':
                previous_all_thought = "\n".join(['Thought: ' + thought for thought in Thought_list])

                data_prompt = [system_prompt, {'role': 'user', 'content': Observation_Agent.replace('{input}', data['query']).replace('{Thought}', previous_all_thought).replace('{Observation}', observation_data)}]

                response = call_with_messages(args.model, data_prompt)
                action, out_data = handle_thought_response(response, args)
                Thought_list.append(out_data)
                formated_data = formate_data(formated_data, out_data, action)

                count += 1

            elif action == 'finish':
                formated_data['ext'] = data['ext']
                # The trajectory is kept either way; only the accuracy flag differs.
                formated_data['ext']['response_acc'] = bool(
                    prediect_check(data['query'], data['answer'], out_data, args)
                )
                return formated_data, True

            elif action == 'error':
                formated_data['ext'] = data['ext']
                print('error in thought')
                return formated_data, False

            if count > 4:
                formated_data['ext'] = data['ext']
                data['ext']['response_acc'] = 'Max_turn'
                print('Max Thought turn')
                return formated_data, False

    except Exception as e:
        # Report the cause instead of dropping it; the item is still returned
        # so the caller can write it to the error file.
        print(f'construct_data failed: {e!r}')
        formated_data['ext'] = data['ext']
        return formated_data, False
140 |
141 |
142 |
def _save_result(out_file, error_path, data, status):
    """Write one trajectory to the output file, or to the error file if invalid."""
    if status == True:
        status = formate_check(data['messages'])

    line = json.dumps(data, ensure_ascii=False) + '\n'
    if status:
        out_file.write(line)
        out_file.flush()
    else:
        with open(error_path, 'a', encoding='utf-8') as error_file:
            error_file.write(line)


def main():
    """Parse CLI arguments and build trajectories for the selected dataset."""
    # Named ``parser`` so the ``argparse`` module is not shadowed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='qwen-max')
    parser.add_argument('--dataset', type=str, default='hotpot')
    # ``action_store=True`` is not a valid keyword and raised TypeError at startup.
    parser.add_argument('--multi_thread', action='store_true')
    parser.add_argument('--num_threads', type=int, default=20)
    parser.add_argument('--output_path', type=str, default='hotpot_cot.jsonl')
    parser.add_argument('--start_index', type=int, default=0)
    parser.add_argument('--end_index', type=int, default=20000)
    args = parser.parse_args()

    if args.dataset == 'hotpot':
        data_test = HotpotDataset()
    else:
        # Previously this fell through to a NameError on ``data_test``.
        parser.error(f'unsupported dataset: {args.dataset}')

    # Both modes now process the same [start_index, end_index) window; the
    # sequential loop used to be off by one relative to this slice.
    data_test = data_test[args.start_index:args.end_index]
    error_path = args.output_path.split('.jsonl')[0] + '_error.jsonl'

    print(f'start handle dataset {args.dataset}, total length is {len(data_test)}')
    end_index = min(args.end_index, len(data_test) + args.start_index)
    print(f'start index is {args.start_index}, end index is {end_index}')

    with open(args.output_path, 'a', encoding='utf-8') as f:
        if not args.multi_thread:
            for item in tqdm(data_test, total=len(data_test)):
                data, status = construct_data(item, args)
                _save_result(f, error_path, data, status)
        else:
            with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_threads) as executor:
                results = {executor.submit(construct_data, item, args): item for item in data_test}

                # Results are written from this thread only, as they complete.
                for future in tqdm(concurrent.futures.as_completed(results), total=len(data_test)):
                    item = results[future]
                    try:
                        data, status = future.result()
                        _save_result(f, error_path, data, status)
                    except Exception as e:
                        print(f'error in thread:{item["query"]}')
                        print(e)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src/multi_agent/prompt.py:
--------------------------------------------------------------------------------
1 |
2 | Planner_Agent = """Your task is to provide the steps for solving a multi-hop search problem.
3 |
4 | The output format should be:
5 | : {Overall thought process} \n [{\"query\": \"{query}\", \"intent\": 1}]"
6 | query should be in sentence format.
7 |
8 | Here are some examples:
9 |
10 | Question: What is the undergraduate school of the director of the movie "Sense and Sensibility"?
11 | Thought: To answer this question, I will take the following steps:
12 | 1. First, find out who directed the movie "Sense and Sensibility".
13 | 2. Investigate the educational background of the director, particularly their undergraduate school.
14 | 3. Identify the specific institution where the director completed their undergraduate studies.
15 | Now, I will start with the first step and search for the director of the movie "Sense and Sensibility".
16 | [{"query": "Who is the director of the movie ’Sense and Sensibility’?", "intent": 1}]
17 |
18 |
19 | Question: When did the birthplace of the performer of Live and Beyond become the capital of the state where Knowles is located?
20 | Thought: To determine when the birthplace of the performer of Live and Beyond
21 | became the capital of the state where Knowles is located, I will take the following steps:
22 | 1. First, identify who the performer of Live and Beyond is.
23 | 2. Then find out the birthplace of this performer.
24 | 3. Next, search for which state Knowles is from.
25 | 4. Finally, determine when the birth city of the performer of Live and Beyond became the capital of Knowles’ state.
26 | Now, I will start with the first step and search online to determine who the performer of Live and
27 | Beyond is.
28 | [{"query": "Find out who the performer of Live and Beyond is", "intent": 1}]
29 |
30 |
31 |
32 | Question: {input}
33 | """
34 |
35 |
36 | Searcher_Agent = """
37 | Given a piece of content containing SUBQUERIES to search, rewrite the SUBQUERIES in order to obtain more comprehensive search results. Please provide at least three rewritten queries.
38 |
39 | The output format should be the following JSON structure:
40 | `{"queries": ["query 1", "query 2"]}`
41 |
42 | Here are some examples:
43 |
44 | **Content:**
45 | To find out which undergraduate school the director of the movie "Sense and
46 | Sensibility" attended, I will take the following steps:
47 | 1. First, determine who the director of the movie "Sense and Sensibility" is.
48 | 2. Then, search for educational background of this director, particularly undergraduate education.
49 | Now, I’ll proceed with the first step by using online searches to identify the director of the movie "Sense and Sensibility".
50 | [{"query": "Find out who the director of the movie ’Sense and Sensibility’ is",
51 | "intent": 1}]
52 |
53 | **Rewritten Queries:**
54 | {"queries": ["Sense and Sensibility director", "Sense and Sensibility 1995 director", "Sense and Sensibility Filmmaker"]}
55 |
56 | **Content:**
57 | After analyzing the search results in detail, I concluded that the director of the movie "Sense and Sensibility" is Ang Lee. Therefore, I will proceed with the next step, where I need to search for his detailed undergraduate education.
58 | [{"query": "Search for the undergraduate school of Ang Lee", "intent": 1}]
59 |
60 |
61 | **Rewritten Queries:**
62 | {"queries": ["Ang Lee education background", "Ang Lee undergraduate school", "Ang Lee biography"]}
63 |
64 | **Content:**
65 | {input}
66 |
67 | **Rewritten Queries:** """
68 |
69 |
70 |
71 | Observation_Agent = """
72 | Given a multi-hop search problem, the steps already taken, and the search results obtained from the last executed step, your task is to:
73 |
74 | 1. Carefully analyze the search results to determine if they resolve the previous step, and provide a summary.
75 | 2. Update the execution steps and propose new steps based on the search results. You may encounter the following scenarios:
76 | a. If the current search results are sufficient to arrive at the Final Answer to the problem, provide the final answer directly. Use the format: " {thought process} {answer} ". Please strictly adhere to this format.
77 | b. If the current search results resolve the previous step, proceed according to the original plan and provide the next new step. Use the format: " {thought process} [{"query": "query", "intent": 1}] "; the query should be in sentence format.
78 | c. If the current search results do not resolve the previous step, reflect on and update the previous execution steps based on the search results. Use the format: " {reflection} [{"query": "query", "intent": 1}] "; the query should be in sentence format.
79 | d. If after multiple reflections (at least twice), the search still fails to resolve the issue, supplement based on your knowledge and provide the next new step. Use the format: " {thought process} [{"query": "query", "intent": 1}] "; the query should be in sentence format.
80 |
81 | Below are some examples:
82 |
83 | Question: What is the undergraduate school of the director of the movie "Sense and Sensibility"?
84 | Thought: To find out the undergraduate school of the director of the movie "Sense and Sensibility", I will take the following steps:\\1. First, determine who directed the movie "Sense and Sensibility". \\2. Then, search for this director’s educational background, focusing on undergraduate education. \\I will now begin with step one by searching online to identify the director of the movie "Sense and Sensibility".
85 | [{"query": "Find out who directed the movie 'Sense and Sensibility'", "intent": 1}]
86 | Observation:
87 | 1. The movie "Sense and Sensibility" was released in 1995 and is directed by Ang Lee, a renowned director known for his work in both Eastern and Western cinema.
88 | 2. Ang Lee is recognized for directing major films such as "Crouching Tiger, Hidden Dragon" and "Life of Pi".
89 | 3. Ang Lee's background in filmmaking and storytelling has earned him multiple Academy Awards and international acclaim.
90 |
91 | Thought: After a detailed analysis of the search results, I concluded that Ang Lee is the director of the movie "Sense and Sensibility". Thus, I will proceed to the next step and search for Ang Lee’s educational background, focusing on his undergraduate education.
92 | [{"query": "Search for the undergraduate school of Ang Lee", "intent": 1}]
93 |
94 |
95 |
96 | Question: What is the undergraduate school of the director of the movie "Sense and Sensibility"?
97 | Thought: To find out the undergraduate school of the director of the movie "Sense and Sensibility", I will take the following steps:\\1. First, determine who directed the movie "Sense and Sensibility". \\2. Then, search for this director’s educational background, focusing on undergraduate education. \\I will now begin with step one by searching online to identify the director of the movie "Sense and Sensibility".
98 | [{"query": "Find out who directed the movie "Sense and Sensibility", "intent": 1}]
99 | Thought: After a detailed analysis of the search results, I concluded that Ang Lee is the director of the movie "Sense and Sensibility". Thus, I will proceed to the next step and search for Ang Lee’s educational background, focusing on his undergraduate education.
100 | [{"query": "Search for the undergraduate school of Ang Lee", "intent": 1}]
101 | Observation:
102 | 1. Ang Lee graduated from the National Taiwan College of Arts in 1975 and then came to the U.S. to receive a B.F.A. Degree in Theatre/Theater Direction at the University of Illinois at Urbana-Champaign, and a Masters Degree in Film Production at New York University.
103 | 2. After high school Lee attended the Taiwan Academy of Art, where he became interested in acting. He moved to the United States in 1978 to study theater at the University of Illinois at Urbana-Champaign.
104 | Thought: After a thorough analysis of the search results, I found out that Ang Lee graduated from the National Taiwan University with a degree in French literature, as well as a B.F.A. Degree from the University of Illinois at Urbana-Champaign, therefore, the answer is National Taiwan College of Arts and University of Illinois at Urbana-Champaign.
105 |
106 |
107 |
108 | Question: Who is the director of the suspense drama that premiered on MBC?
109 | Thought: To solve this problem, I need to follow a few steps in my search: \n\nFirst, I need to identify which suspense drama premiered on MBC.\nThen, I will look up who directed that suspense drama.\nFirst, I will search for recent suspense dramas that premiered on MBC.
110 | [{ "query": "What recent suspense dramas premiered on MBC?", "intent": 1}]
111 | Observation:
112 | 1. Premiering soon! A 12-episode suspense drama is coming, starring Wang Zhiwen and Yu Ai... From the trailer, 'The Vanishing Elephant' combines suspense and comedy, bringing an amusing story with humor and absurdity."
113 | 2. Chen Weiting's suspense drama is airing, featuring time travel and criminal investigation, with top actors like award-winning actors joining."
114 | 3. The best 10 suspense dramas of the past 4 years, with 'The Big Squad' at the bottom. Though the series is framed as a historical drama, it unfolds as a large-scale suspense espionage drama."
115 | 4. Premiering soon! An 18-episode suspense drama is coming, starring Sun Li, with a strong cast. However, recent suspense dramas such as 'Dim Flames' and 'Rebirth' have not been well-received."
116 | 5. Lee Jong-suk's return! New drama 'The Black Lawyer' premiered on MBC, achieving high viewership. The story is a suspenseful crime drama about a vulnerable lawyer who is framed and battles powerful forces."
117 | 6. The top 10 suspense dramas to watch in 2024, featuring leading actors like Liao Fan, Duan Yihong, and Zhang Songwen..."
118 | 7. Premiering tonight! Another suspenseful drama is coming! 'Thirteen Years in Dust' will air on iQIYI's Mist Theater, starring Chen Jianbin and Chen Xiao."
119 | 8. Chen Weiting's new drama 'The Sun and Stars' debuted with impressive suspenseful plotlines.
120 |
121 | Thought: From the search results, I found several relevant pieces of information:\n1. The suspense drama 'The Black Lawyer' premiered on MBC, starring Lee Jong-suk and Im Yoon-ah, and it is a crime suspense drama. 2. The premiere of 'The Black Lawyer' achieved high ratings, telling a story about a vulnerable lawyer who is framed and fights back against powerful figures. 3. However, the search results did not provide the director's information for 'The Black Lawyer.' \nFilter out unrelated content: The search results also mention suspense dramas on other platforms like iQIYI and Mist Theater, and other non-MBC dramas, such as 'The Sun and Stars.' These are not directly related to the question and can be excluded. \nCurrent progress summary: Based on the retrieved information, 'The Black Lawyer' is a recent suspense drama premiered on MBC, but the director’s information is still missing. \nNext plan: I will continue with the second step of my plan, searching for the director of 'The Black Lawyer.'
122 | [{"query": "Who is the director of 'The Black Lawyer'?", "intent": 1}]
123 |
124 |
125 | Question: {input}
126 | {Thought}
127 | Observation: {Observation}
128 | """
129 |
--------------------------------------------------------------------------------
/src/RAMP/gen_qa.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | from model import call_with_messages
5 | import argparse
6 | from tqdm import tqdm
7 | from concurrent.futures import ThreadPoolExecutor, as_completed
8 | from hanziconv import HanziConv
9 |
10 | query_prompt_zh = """{paragraph}
11 |
12 | 补充所有的[mask],并重新输出,要求格式不能发生变化。"""
13 |
14 | query_prompt_en = """{paragraph}
15 |
16 | Fill in all the [mask] and output the whole paragraph without changing its format."""
17 |
# Chinese few-shot prompt for knowledge-entity extraction, used by ner() when
# the corpus name contains 'zh'. It is rendered with
# str.format(paragraph=...), so `{paragraph}` must remain the only brace
# field in the text. The instructions and the four worked examples ask the
# model to answer with a list literal of entity strings, which ner() parses.
ner_prompt_zh = """
从以下文本中提取所有知识实体(最好5个以上),知识实体是指文本中提到的全部具体中文人物、地点、组织、会议期刊、短学术名词概念,以及唯一出现的短数字等。如果多个实体连续出现,只提取最细粒度的实体。重复出现的实体不要提取。
请将提取的实体以列表的形式直接输出,无需输出其他内容。每个实体不要超过10个字。
确保你提取的实体不能通过该段落中的上下文直接推理出来,需要额外的信息搜索才能推理。
你的实体只能是中文或者数字。

##例子
段落:1980年2月27日,林义雄妻子方素敏前往军事情报局看守所并且代替国际特赦组织大坂市办事处传达讯息。而隔天军事法庭召开第一次调查庭,包括方素敏等待审党外运动人士的家属纷纷前往看守所探视和旁听。不过方素敏则对于家中小孩感到不放心,担任林义雄议员秘书的田秋堇则搭乘公车先行前往林义雄在台北市信义住家。发现林义雄9岁的长女林奂均身中6刀重伤后田秋堇赶紧向外求助,之后林浊水和康文雄等党外运动人士在接到消息赶来帮忙后,在住处地下室分别发现其中身中14刀的林义雄母亲林游阿妹,以及各自因为从后背贯穿前胸的1刀丧命的双胞胎女儿林亮均与林亭均。
实体:["方素敏", "国际特赦组织大坂市办事处", "台北市信义", "6", "14", "林奂均", "林浊水", "康文雄", "林游阿妹", "林亮均", "林亭均"]

段落:近日,上海科技大学信息科学与技术学院后摩尔中心(PMICC)寇煦丰、祝智峰团队,利用分子束外延技术设计制备了基于2英寸磁性拓扑异质结Bi2Te3/CrTe2薄膜,实现了能同时具备类脑突触和神经元功能的自旋轨道矩器件阵列(spin-orbit torque device array),并集成了批量归一化算法和可训练激活函数,相关研究成果以"Integrated Artificial Neural Network with Trainable Activation Function Enabled byTopological Insulator-based Spin-Orbit Torque Devices"为题在线发表于知名学术期刊ACS Nano。
实体:["上海科技大学", "寇煦丰", "祝智峰", "自旋轨道矩器件阵列", "ACS Nano"]

段落:对于开放定址法,荷载因子是特别重要因素,应严格限制在0.7-0.8以下。超过0.8,查表时的CPU缓存不命中(cache missing)按照指数曲线上升。因此,一些采用开放定址法的hash库,如Java的hash库限制了荷载因子为0.75,超过此值将resize散列表。
实体:["荷载因子", "指数曲线", "Java", "散列表", "0.75"]

段落:大连公交车以12米为主,目前最小为6米级。纯电动逐渐占据主导地位。目前的主力车型是比亚迪K9系列纯电动客车,目前大连公交开始大批量采购比亚迪纯电动客车。柴油、天然气及混合动力车辆逐渐退出大连公交的舞台。
实体:["比亚迪K9", "6", "12"]

段落:{paragraph}
实体:
"""
40 |
# English few-shot prompt for knowledge-entity extraction, used by ner() for
# every corpus whose name does not contain 'zh'. It is rendered with
# str.format(paragraph=...), so `{paragraph}` must remain the only brace
# field in the text. The worked examples show the expected reply shape: a
# list literal of entity strings, which ner() parses.
ner_prompt_en = """
Extract all the knowledge entities (more than 5 entities if exist) from the following text. Knowledge entities refer to specific individuals, locations, organizations, conferences, journals, short academic concepts, and unique short numbers mentioned in the text. If multiple entities appear consecutively, only extract the finest-grained entity. Please output the extracted entities as a list directly, without any other content. Each entity should not exceed 10 characters.
If an entity repeatedly appears, you should not extract it. You need extract a whole word like [American] instead of [America]n.
Ensure that the entities you extract cannot be directly infer from the context of the paragraph. You must need extra information search on the Internet to infer the entities.

##Example
Paragraph: On February 27, 1980, Lin Yixiong's wife, Fang Sumin, went to the military intelligence bureau detention center and conveyed messages on behalf of the International Amnesty Organization Osaka office. The next day, the military court held its first investigation session, and families waiting for the trial of opposition movement activists, including Fang Sumin, went to the detention center to visit and listen. However, Fang Sumin was worried about her children at home, and Tian Qiujin, who served as Lin Yixiong's secretary, took the bus to Lin Yixiong's residence in Xinyi, Taipei City. Upon discovering that Lin Yixiong's 9-year-old eldest daughter, Lin Huanjun, was seriously injured with six stab wounds, Tian Qiujin quickly sought help. Subsequently, Lin Zhuoshui and Kang Wenxiong, among other opposition movement activists, rushed to help after receiving the news and found Lin Yixiong's mother, Lin You'a, stabbed 14 times, and her twin daughters, Lin Liangjun and Lin Tingjun, who died from a single stab wound that pierced their backs and exited their chests, in the basement of the residence.
Entities: ["Lin Yixiong", "Fang Sumin", "International Amnesty Organization Osaka office", "Tian Qiujin", "Xinyi, Taipei City", "6", "14", "Lin Huanjun", "Lin Zhuoshui", "Kang Wenxiong", "Lin You'a", "Lin Liangjun", "Lin Tingjun"]

Paragraph: Recently, the team of Kou Xufen and Zhu Zhifeng from the Post-Moore Center (PMICC) at the School of Information Science and Technology, ShanghaiTech University, designed and prepared a 2-inch magnetic topological heterojunction Bi2Te3/CrTe2 thin film based on molecular beam epitaxy technology. They achieved a spin-orbit torque device array capable of possessing both brain-like synapse and neuron functions, and integrated batch normalization algorithms and trainable activation functions. The relevant research results were published online in the renowned academic journal ACS Nano under the title "Integrated Artificial Neural Network with Trainable Activation Function Enabled by Topological Insulator-based Spin-Orbit Torque Devices".
Entities: ["ShanghaiTech University", "Kou Xufen", "Zhu Zhifeng", "spin-orbit torque device array", "ACS Nano"]

Paragraph: For open addressing, the load factor is a particularly important factor and should be strictly limited below 0.7-0.8. Beyond 0.8, CPU cache misses (cache missing) increase exponentially when looking up tables. Therefore, some hash libraries that use open addressing, such as Java's hash libraries, limit the load factor to 0.75, and the hash table will be resized when this value is exceeded.
Entities: ["load factor", "exponential curve", "Java", "hash table", "0.75"]

Paragraph: Dalian buses are mainly 12 meters long, with the smallest currently being 6 meters. Pure electric vehicles are gradually taking the dominant position. The main model is the BYD K9 series pure electric bus, and Dalian public transportation has begun to purchase a large number of BYD pure electric buses. Diesel, natural gas, and hybrid vehicles are gradually exiting the stage of Dalian public transportation.
Entities: ["BYD K9", "6", "12"]

Paragraph: {paragraph}
Entities:
"""
62 |
63 |
def ner(para, args):
    """Ask the LLM for the knowledge entities in ``para`` and return them.

    Args:
        para: Paragraph text inserted into the few-shot NER prompt.
        args: Namespace providing ``corpus`` (a value containing ``'zh'``
            selects the Chinese prompt, anything else the English one) and
            ``model`` (forwarded to ``call_with_messages``).

    Returns:
        A list of entity strings parsed from the model reply. Numeric
        items are converted to strings; other item types are dropped.

    Raises:
        ValueError: If the model returned no text, or the reply does not
            contain a parseable list literal.
    """
    # Local import keeps this block self-contained; ``ast`` is stdlib.
    import ast

    template = ner_prompt_zh if 'zh' in args.corpus else ner_prompt_en
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': template.format(paragraph=para)},
    ]
    response = call_with_messages(args.model, messages)
    if not response:
        # call_with_messages returns None when the request is rejected.
        raise ValueError('empty NER response')

    # The model sometimes wraps the list in extra text or a code fence, so
    # only the outermost [...] span is parsed.
    start, end = response.find('['), response.rfind(']')
    if start == -1 or end < start:
        raise ValueError(f'no list found in NER response: {response!r}')

    # SECURITY: the reply is untrusted model output. It used to be passed to
    # eval(), which would execute arbitrary code; literal_eval only accepts
    # Python literals.
    try:
        parsed = ast.literal_eval(response[start:end + 1])
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as err:
        raise ValueError(f'unparseable NER response: {response!r}') from err
    if not isinstance(parsed, (list, tuple)):
        raise ValueError(f'NER response is not a list: {response!r}')

    entities = []
    for item in parsed:
        if isinstance(item, str):
            entities.append(item)
        elif isinstance(item, (int, float)) and not isinstance(item, bool):
            # Downstream code calls str methods on every entity.
            entities.append(str(item))
    return entities
75 |
def generate_qa(paragraph, args):
    """Build a masked-paragraph QA pair from one ``(text, title)`` item.

    Entities are extracted with ``ner``; a weighted random subset of them is
    replaced by ``[mask]`` in the paragraph. The masked text (wrapped in the
    query template) is the query and the unmasked text is the answer.

    Args:
        paragraph: Sequence whose first item is the paragraph text and whose
            second item is the page title.
        args: Namespace providing ``corpus`` (``'zh'`` in the value selects
            the Chinese path) and whatever ``ner`` needs.

    Returns:
        A list holding zero or one dict with keys ``"query"``, ``"answer"``
        and ``"ext"``. ``ext["masks"]`` lists the entities that were really
        replaced and ``ext["mask_num"]`` is their count.
    """
    is_zh = "zh" in args.corpus
    try:
        para, title = paragraph[0], paragraph[1]
        if is_zh:
            title = HanziConv.toSimplified(title)
        entities = ner(para, args)
        if not isinstance(entities, (list, tuple)):
            raise TypeError(f"expected a list of entities, got {type(entities).__name__}")
    except Exception as err:  # best effort: a failed NER call just yields no pairs
        print(f"NER error: {err!r}")
        return []

    # Keep only usable entities. A fresh list is built because the previous
    # code removed items from the list it was iterating, which skipped the
    # element after each removal and let repeated entities through.
    candidates = []
    for entity in entities:
        if not isinstance(entity, str) or not entity:
            continue
        if entity == title or entity in candidates:
            continue
        # Exactly one occurrence: repeated entities leak the answer, and
        # entities absent from the text cannot be masked at all.
        if para.count(entity) == 1:
            candidates.append(entity)
    entities = candidates
    if not entities:
        return []

    # One or two entities: mask a random number of them; otherwise mask up
    # to four.
    max_masks = 4
    if len(entities) < 3:
        num_mask = random.randint(1, len(entities))
    else:
        num_mask = min(len(entities), max_masks)

    def text_weight(entity):
        # Longer entities are more likely to be masked: character count for
        # Chinese, word count otherwise.
        return len(entity) if is_zh else len(entity.split(" "))

    text_weights = [text_weight(e) for e in entities if not e.isdigit()]
    # Pure-digit entities get the largest non-numeric weight (1 if none).
    digit_weight = max(text_weights) if text_weights else 1
    weights = [digit_weight if e.isdigit() else text_weight(e) for e in entities]

    # random.choices samples with replacement; duplicates are collapsed so
    # every mask is a distinct entity, kept in entity-list order.
    picked = random.choices(range(len(entities)), weights=weights, k=num_mask)
    masks = [entities[i] for i in sorted(set(picked))]

    # Make sure the page subject is named in the text: substitute the first
    # occurrence of each pronoun, or prepend the title if it is still absent.
    subject = title.split(' (')[0]
    if subject not in para:
        para = (para.replace("他们", subject, 1)
                    .replace("她", subject, 1)
                    .replace("他", subject, 1)
                    .replace("它", subject, 1))
    if subject not in para:
        para = subject + " " + para

    masked_paragraph = para
    applied = []
    for mask in masks:
        # An earlier replacement may have swallowed this entity.
        if mask in masked_paragraph:
            masked_paragraph = masked_paragraph.replace(mask, '[mask]', 1)
            applied.append(mask)

    # Samples that begin with a mask are discarded, as before.
    if not applied or masked_paragraph.startswith('[mask]'):
        return []

    template = query_prompt_zh if is_zh else query_prompt_en
    return [{
        "query": template.format(paragraph=masked_paragraph),
        "answer": para,
        "ext": {
            "mask_num": len(applied),
            "masks": applied,
            "entities": entities,
            "source": "wiki",
            "title": title,
        },
    }]
135 |
136 |
def process_file(fp, existing_title, corpus=None):
    """Read one wiki page file and return its usable paragraphs.

    Args:
        fp: File name inside the corpus directory; also used as the title
            paired with every returned paragraph.
        existing_title: Container of titles that were already processed;
            the file is skipped when ``fp`` is in it.
        corpus: Corpus directory. Defaults to ``args.corpus`` from the
            command line, which keeps existing two-argument calls working.

    Returns:
        A list of ``(paragraph, fp)`` tuples in random order, or ``[]`` when
        the file is skipped or has no usable line.
    """
    # NOTE(review): for 'zh' corpora generate_qa stores the title after
    # HanziConv.toSimplified, so a file name with traditional characters
    # will not match here when resuming -- confirm whether that matters.
    if fp in existing_title:
        return []
    if corpus is None:
        corpus = args.corpus  # module-level CLI namespace

    is_zh = 'zh' in corpus
    # Lines containing this marker are treated as list/chart lines.
    list_marker = '图表' if is_zh else 'List'

    # Explicit UTF-8 so decoding does not depend on the platform default.
    with open(f'{corpus}/{fp}', 'r', encoding='utf-8') as f:
        lines = f.read().split('\n')

    para = [
        c for c in lines
        if c.strip()                       # blank lines are not paragraphs
        and list_marker not in c
        and c.count('/') < 2               # skip path/URL-like lines
        and not c.startswith('*')          # skip bullet items
        and not c.endswith(':')            # skip heading-like lines
    ]  # Filter Pages
    if not para:
        return []

    keep = len(para)
    if is_zh and keep > 2:
        # The Chinese branch has always dropped one randomly chosen
        # paragraph on pages with more than two; preserved as-is.
        keep -= 1
    return [(p, fp) for p in random.sample(para, keep)]
162 |
if __name__ == '__main__':
    # Command-line entry point: collect paragraphs from every file in the
    # corpus directory, then generate masked QA pairs concurrently and
    # append them to the output file as JSON lines.
    #
    # The parser gets its own name so the ``argparse`` module is not shadowed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_threads', type=int, default=250)
    parser.add_argument('--model', type=str, default='qwen-turbo')

    parser.add_argument('--corpus', type=str, default='wiki-en', help='Directory of Wiki Pages')
    parser.add_argument('--output_path', type=str, default='qa_en.jsonl')
    # Was ``action_store=True``, which is not a valid add_argument keyword
    # and made the script fail before parsing anything.
    parser.add_argument('--continue_gen', action='store_true',
                        help='Skip titles already present in output_path')

    # ``args`` must stay a module-level name: process_file reads args.corpus.
    args = parser.parse_args()

    files = os.listdir(args.corpus)
    random.seed(42)
    random.shuffle(files)

    # Titles already written by a previous run (only with --continue_gen).
    existing_title = set()
    if args.continue_gen and os.path.exists(args.output_path):
        with open(args.output_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f):
                if not line.strip():
                    continue
                data = json.loads(line)
                existing_title.add(data["ext"]["title"])

    paragraphs = []
    with ThreadPoolExecutor(max_workers=100) as executor:
        future_to_file = {executor.submit(process_file, fp, existing_title): fp for fp in files}
        for future in tqdm(as_completed(future_to_file), total=len(files), desc='Processing files'):
            try:
                paragraphs.extend(future.result())
            except Exception as e:
                # Still best effort, but unreadable files are now reported
                # instead of being dropped silently.
                print(f'error reading file:{future_to_file[future]}, {e}')

    # UTF-8 on write matches the UTF-8 read used when resuming; records are
    # dumped with ensure_ascii=False and may contain non-ASCII text.
    with open(args.output_path, 'a', encoding='utf-8') as f:
        with ThreadPoolExecutor(max_workers=args.num_threads) as executor:
            results = {executor.submit(generate_qa, item, args): item for item in paragraphs}

            for future in tqdm(as_completed(results), total=len(paragraphs)):
                item = results[future]
                try:
                    data = future.result()
                    for d in data:
                        f.write(json.dumps(d, ensure_ascii=False) + '\n')
                    f.flush()

                except Exception as e:
                    print(f'error in thread:{item}, {e}')
208 |
--------------------------------------------------------------------------------