├── image.png
├── requirements.txt
├── LICENSE
├── scripts
│   ├── Qwen_MATH.sh
│   ├── Qwen_GSM8K.sh
│   ├── GPT_GSM8K.sh
│   └── Gemini_GSM8K.sh
├── prompts
│   ├── AIME.py
│   ├── GSM8K.py
│   ├── GSM_Hard.py
│   ├── MATH.py
│   ├── GPQA.py
│   └── MMLU.py
├── readme.md
├── model.py
├── main.py
├── eval_csv_cost.py
├── eval_csv_N.py
└── dataset.py
/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MraDonkey/rethinking_prompting/HEAD/image.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opentelemetry-proto==1.26.0
2 | pyarrow==18.0.0
3 | multiprocess==0.70.16
4 | numpy==1.26.4
5 | datasets==3.2.0
6 | openai==1.53.0
7 | vllm==0.8.4
8 | protobuf==4.25.8
9 | google-generativeai==0.7.2
10 | openpyxl
11 | matplotlib
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 donkey
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/Qwen_MATH.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --model_name Qwen/Qwen2.5-7B-Instruct \
3 | --model_type vllm \
4 | --split test \
5 | --dataset MATH \
6 | --reasoning DiP \
7 | --shot 0 \
8 | --batchsize 10 \
9 | --range_begin 0 \
10 | --range_end 32 \
11 | --gpu 0,1
12 |
13 | python main.py \
14 | --model_name Qwen/Qwen2.5-7B-Instruct \
15 | --model_type vllm \
16 | --split test \
17 | --dataset MATH \
18 | --reasoning CoT \
19 | --shot 0 \
20 | --batchsize 10 \
21 | --range_begin 0 \
22 | --range_end 32 \
23 | --gpu 0,1
24 |
25 | python main.py \
26 | --model_name Qwen/Qwen2.5-7B-Instruct \
27 | --model_type vllm \
28 | --split test \
29 | --dataset MATH \
30 | --reasoning L2M \
31 | --shot 1 \
32 | --batchsize 10 \
33 | --range_begin 0 \
34 | --range_end 32 \
35 | --gpu 0,1
36 |
37 | python main.py \
38 | --model_name Qwen/Qwen2.5-7B-Instruct \
39 | --model_type vllm \
40 | --split test \
41 | --dataset MATH \
42 | --reasoning SBP \
43 | --shot 0 \
44 | --batchsize 10 \
45 | --range_begin 0 \
46 | --range_end 32 \
47 | --gpu 0,1
48 |
49 | python main.py \
50 | --model_name Qwen/Qwen2.5-7B-Instruct \
51 | --model_type vllm \
52 | --split test \
53 | --dataset MATH \
54 | --reasoning AnP \
55 | --shot 1 \
56 | --batchsize 10 \
57 | --range_begin 0 \
58 | --range_end 32 \
59 | --gpu 0,1
60 |
61 | python main.py \
62 | --model_name Qwen/Qwen2.5-7B-Instruct \
63 | --model_type vllm \
64 | --split test \
65 | --dataset MATH \
66 | --reasoning ToT \
67 | --shot 3 \
68 | --batchsize 10 \
69 | --range_begin 0 \
70 | --range_end 5 \
71 | --gpu 0,1
72 |
73 | python main.py \
74 | --model_name Qwen/Qwen2.5-7B-Instruct \
75 | --model_type vllm \
76 | --split test \
77 | --dataset MATH \
78 | --reasoning ToT \
79 | --shot 5 \
80 | --batchsize 10 \
81 | --range_begin 0 \
82 | --range_end 5 \
83 | --gpu 0,1
84 |
85 | python main.py \
86 | --model_name Qwen/Qwen2.5-7B-Instruct \
87 | --model_type vllm \
88 | --split test \
89 | --dataset MATH \
90 | --reasoning ToT \
91 | --shot 10 \
92 | --batchsize 10 \
93 | --range_begin 0 \
94 | --range_end 5 \
95 | --gpu 0,1
96 |
97 | python main.py \
98 | --model_name Qwen/Qwen2.5-7B-Instruct \
99 | --model_type vllm \
100 | --split test \
101 | --dataset MATH \
102 | --reasoning S-RF \
103 | --shot 0 \
104 | --batchsize 5 \
105 | --range_begin 0 \
106 | --range_end 1 \
107 | --gpu 0,1
108 |
109 | python main.py \
110 | --model_name Qwen/Qwen2.5-7B-Instruct \
111 | --model_type vllm \
112 | --split test \
113 | --dataset MATH \
114 | --reasoning MAD \
115 | --shot 0 \
116 | --batchsize 1 \
117 | --range_begin 0 \
118 | --range_end 1 \
119 | --gpu 0,1
120 |
--------------------------------------------------------------------------------
/scripts/Qwen_GSM8K.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --model_name Qwen/Qwen2.5-7B-Instruct \
3 | --model_type vllm \
4 | --split test \
5 | --dataset GSM8K \
6 | --reasoning DiP \
7 | --shot 0 \
8 | --batchsize 10 \
9 | --range_begin 0 \
10 | --range_end 32 \
11 | --gpu 0,1
12 |
13 | python main.py \
14 | --model_name Qwen/Qwen2.5-7B-Instruct \
15 | --model_type vllm \
16 | --split test \
17 | --dataset GSM8K \
18 | --reasoning CoT \
19 | --shot 0 \
20 | --batchsize 10 \
21 | --range_begin 0 \
22 | --range_end 32 \
23 | --gpu 0,1
24 |
25 | python main.py \
26 | --model_name Qwen/Qwen2.5-7B-Instruct \
27 | --model_type vllm \
28 | --split test \
29 | --dataset GSM8K \
30 | --reasoning L2M \
31 | --shot 1 \
32 | --batchsize 10 \
33 | --range_begin 0 \
34 | --range_end 32 \
35 | --gpu 0,1
36 |
37 | python main.py \
38 | --model_name Qwen/Qwen2.5-7B-Instruct \
39 | --model_type vllm \
40 | --split test \
41 | --dataset GSM8K \
42 | --reasoning SBP \
43 | --shot 0 \
44 | --batchsize 10 \
45 | --range_begin 0 \
46 | --range_end 32 \
47 | --gpu 0,1
48 |
49 | python main.py \
50 | --model_name Qwen/Qwen2.5-7B-Instruct \
51 | --model_type vllm \
52 | --split test \
53 | --dataset GSM8K \
54 | --reasoning AnP \
55 | --shot 1 \
56 | --batchsize 10 \
57 | --range_begin 0 \
58 | --range_end 32 \
59 | --gpu 0,1
60 |
61 | python main.py \
62 | --model_name Qwen/Qwen2.5-7B-Instruct \
63 | --model_type vllm \
64 | --split test \
65 | --dataset GSM8K \
66 | --reasoning ToT \
67 | --shot 3 \
68 | --batchsize 10 \
69 | --range_begin 0 \
70 | --range_end 5 \
71 | --gpu 0,1
72 |
73 | python main.py \
74 | --model_name Qwen/Qwen2.5-7B-Instruct \
75 | --model_type vllm \
76 | --split test \
77 | --dataset GSM8K \
78 | --reasoning ToT \
79 | --shot 5 \
80 | --batchsize 10 \
81 | --range_begin 0 \
82 | --range_end 5 \
83 | --gpu 0,1
84 |
85 | python main.py \
86 | --model_name Qwen/Qwen2.5-7B-Instruct \
87 | --model_type vllm \
88 | --split test \
89 | --dataset GSM8K \
90 | --reasoning ToT \
91 | --shot 10 \
92 | --batchsize 10 \
93 | --range_begin 0 \
94 | --range_end 5 \
95 | --gpu 0,1
96 |
97 | python main.py \
98 | --model_name Qwen/Qwen2.5-7B-Instruct \
99 | --model_type vllm \
100 | --split test \
101 | --dataset GSM8K \
102 | --reasoning S-RF \
103 | --shot 0 \
104 | --batchsize 5 \
105 | --range_begin 0 \
106 | --range_end 1 \
107 | --gpu 0,1
108 |
109 | python main.py \
110 | --model_name Qwen/Qwen2.5-7B-Instruct \
111 | --model_type vllm \
112 | --split test \
113 | --dataset GSM8K \
114 | --reasoning MAD \
115 | --shot 0 \
116 | --batchsize 1 \
117 | --range_begin 0 \
118 | --range_end 1 \
119 | --gpu 0,1
120 |
--------------------------------------------------------------------------------
/scripts/GPT_GSM8K.sh:
--------------------------------------------------------------------------------
1 | ## You can control the maximum concurrency level by adjusting the "max_num_workers" parameter.
2 |
3 | python main.py \
4 | --model_name gpt-4o-mini \
5 | --model_type openai \
6 | --split test \
7 | --dataset GSM8K \
8 | --reasoning DiP \
9 | --shot 0 \
10 | --batchsize 10 \
11 | --range_begin 0 \
12 | --range_end 16 \
13 | --max_num_workers 10 \
14 | --openai_api_key your_api_key \
15 | --openai_base_url your_base_url
16 |
17 | python main.py \
18 | --model_name gpt-4o-mini \
19 | --model_type openai \
20 | --split test \
21 | --dataset GSM8K \
22 | --reasoning CoT \
23 | --shot 0 \
24 | --batchsize 10 \
25 | --range_begin 0 \
26 | --range_end 16 \
27 | --max_num_workers 10 \
28 | --openai_api_key your_api_key \
29 | --openai_base_url your_base_url
30 |
31 | python main.py \
32 | --model_name gpt-4o-mini \
33 | --model_type openai \
34 | --split test \
35 | --dataset GSM8K \
36 | --reasoning L2M \
37 | --shot 1 \
38 | --batchsize 10 \
39 | --range_begin 0 \
40 | --range_end 16 \
41 | --max_num_workers 10 \
42 | --openai_api_key your_api_key \
43 | --openai_base_url your_base_url
44 |
45 | python main.py \
46 | --model_name gpt-4o-mini \
47 | --model_type openai \
48 | --split test \
49 | --dataset GSM8K \
50 | --reasoning SBP \
51 | --shot 0 \
52 | --batchsize 10 \
53 | --range_begin 0 \
54 | --range_end 16 \
55 | --max_num_workers 10 \
56 | --openai_api_key your_api_key \
57 | --openai_base_url your_base_url
58 |
59 | python main.py \
60 | --model_name gpt-4o-mini \
61 | --model_type openai \
62 | --split test \
63 | --dataset GSM8K \
64 | --reasoning AnP \
65 | --shot 1 \
66 | --batchsize 10 \
67 | --range_begin 0 \
68 | --range_end 16 \
69 | --max_num_workers 10 \
70 | --openai_api_key your_api_key \
71 | --openai_base_url your_base_url
72 |
73 | python main.py \
74 | --model_name gpt-4o-mini \
75 | --model_type openai \
76 | --split test \
77 | --dataset GSM8K \
78 | --reasoning ToT \
79 | --shot 3 \
80 | --batchsize 10 \
81 | --range_begin 0 \
82 | --range_end 5 \
83 | --max_num_workers 10 \
84 | --openai_api_key your_api_key \
85 | --openai_base_url your_base_url
86 |
87 | python main.py \
88 | --model_name gpt-4o-mini \
89 | --model_type openai \
90 | --split test \
91 | --dataset GSM8K \
92 | --reasoning ToT \
93 | --shot 5 \
94 | --batchsize 10 \
95 | --range_begin 0 \
96 | --range_end 5 \
97 | --max_num_workers 10 \
98 | --openai_api_key your_api_key \
99 | --openai_base_url your_base_url
100 |
101 | python main.py \
102 | --model_name gpt-4o-mini \
103 | --model_type openai \
104 | --split test \
105 | --dataset GSM8K \
106 | --reasoning ToT \
107 | --shot 10 \
108 | --batchsize 10 \
109 | --range_begin 0 \
110 | --range_end 5 \
111 | --max_num_workers 10 \
112 | --openai_api_key your_api_key \
113 | --openai_base_url your_base_url
114 |
115 | python main.py \
116 | --model_name gpt-4o-mini \
117 | --model_type openai \
118 | --split test \
119 | --dataset GSM8K \
120 | --reasoning S-RF \
121 | --shot 0 \
122 | --batchsize 5 \
123 | --range_begin 0 \
124 | --range_end 1 \
125 | --max_num_workers 10 \
126 | --openai_api_key your_api_key \
127 | --openai_base_url your_base_url
128 |
129 | python main.py \
130 | --model_name gpt-4o-mini \
131 | --model_type openai \
132 | --split test \
133 | --dataset GSM8K \
134 | --reasoning MAD \
135 | --shot 0 \
136 | --batchsize 1 \
137 | --range_begin 0 \
138 | --range_end 1 \
139 | --max_num_workers 1 \
140 | --openai_api_key your_api_key \
141 | --openai_base_url your_base_url
142 |
--------------------------------------------------------------------------------
/prompts/AIME.py:
--------------------------------------------------------------------------------
1 | prompt_format = " Your final result should be in the form \\boxed{answer}, at the end of your response."
2 |
3 | directly_answer = "{question} You should only answer the final result with a single numerical number. Do not say other words."
4 |
5 | io = "{question}" + prompt_format
6 |
7 | io_briefly = io + " You should answer with no more than 200 words."
8 |
9 | cot_pre = "Please answer the given question." + prompt_format + '\n\n'
10 |
11 | cot_0_shot = cot_pre + '''Question: {question}
12 | Answer: Let's think step by step.'''
13 |
14 | # cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
15 | # Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}.
16 |
17 | # Question: {question}
18 | # Answer:
19 | # '''
20 |
21 | Least_to_Most_0_shot = cot_pre + ''' In order to solve the question more conveniently and efficiently, break down the question into progressive sub-questions. Answer the sub-questions and get the final result according to sub-questions and their answers.
22 |
23 | Question: {question}
24 | Answer:
25 | '''
26 |
27 | tot_post = '''
28 | Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.'''
29 |
30 | tot_3_solutions = '''
31 | Question: {question}
32 |
33 | Solution 1: {solution1}
34 | Solution 2: {solution2}
35 | Solution 3: {solution3}'''
36 |
37 | tot_5_solutions = tot_3_solutions + '''
38 | Solution 4: {solution4}
39 | Solution 5: {solution5}'''
40 |
41 | tot_10_solutions = tot_5_solutions + '''
42 | Solution 6: {solution6}
43 | Solution 7: {solution7}
44 | Solution 8: {solution8}
45 | Solution 9: {solution9}
46 | Solution 10: {solution10}'''
47 |
48 | anologous_1_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
49 |
50 | # Initial Problem:
51 | {question}
52 |
53 | # Instructions:
54 | ## Relevant Problems:
55 | Recall an example of the math problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem:
56 | - After "Q: ", describe the problem.
57 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.
58 |
59 | ## Solve the Initial Problem:
60 | Q: Copy and paste the initial problem here.
61 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
62 | '''
63 |
64 | anologous_3_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
65 |
66 | # Initial Problem:
67 | {question}
68 |
69 | # Instructions:
70 | ## Relevant Problems:
71 | Recall three examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
72 | - After "Q: ", describe the problem.
73 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.
74 |
75 | ## Solve the Initial Problem:
76 | Q: Copy and paste the initial problem here.
77 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
78 | '''
79 |
80 | anologous_5_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
81 |
82 | # Initial Problem:
83 | {question}
84 |
85 | # Instructions:
86 | ## Relevant Problems:
87 | Recall five examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
88 | - After "Q: ", describe the problem.
89 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.
90 |
91 | ## Solve the Initial Problem:
92 | Q: Copy and paste the initial problem here.
93 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
94 | '''
95 |
96 | SBP_extract = '''You are an expert at mathematics. Your task is to extract the mathematics concepts and principles involved in solving the problem.
97 | Question:
98 | {question}
99 |
100 | Principles involved:
101 | '''
102 |
103 | SBP_answer = "You are an expert at mathematics. You are given a mathematics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + prompt_format + '''
104 | Question:
105 | {question}
106 |
107 | Principles:
108 | {principles}
109 |
110 | Answer:
111 | '''
112 |
113 |
--------------------------------------------------------------------------------
/scripts/Gemini_GSM8K.sh:
--------------------------------------------------------------------------------
1 | ## You can control the maximum concurrency level by adjusting the "max_num_workers" parameter.
2 |
3 | python main.py \
4 | --model_name gemini-1.5-flash \
5 | --model_type gemini \
6 | --split test \
7 | --dataset GSM8K \
8 | --reasoning DiP \
9 | --shot 0 \
10 | --batchsize 10 \
11 | --range_begin 0 \
12 | --range_end 8 \
13 | --max_num_workers 10 \
14 | --gemini_api_key your_api_key
15 |
16 | python main.py \
17 | --model_name gemini-1.5-flash \
18 | --model_type gemini \
19 | --split test \
20 | --dataset GSM8K \
21 | --reasoning CoT \
22 | --shot 0 \
23 | --batchsize 10 \
24 | --range_begin 0 \
25 | --range_end 8 \
26 | --max_num_workers 10 \
27 | --gemini_api_key your_api_key
28 |
29 | python main.py \
30 | --model_name gemini-1.5-flash \
31 | --model_type gemini \
32 | --split test \
33 | --dataset GSM8K \
34 | --reasoning L2M \
35 | --shot 1 \
36 | --batchsize 10 \
37 | --range_begin 0 \
38 | --range_end 8 \
39 | --max_num_workers 10 \
40 | --gemini_api_key your_api_key
41 |
42 | python main.py \
43 | --model_name gemini-1.5-flash \
44 | --model_type gemini \
45 | --split test \
46 | --dataset GSM8K \
47 | --reasoning SBP \
48 | --shot 0 \
49 | --batchsize 10 \
50 | --range_begin 0 \
51 | --range_end 8 \
52 | --max_num_workers 10 \
53 | --gemini_api_key your_api_key
54 |
55 | python main.py \
56 | --model_name gemini-1.5-flash \
57 | --model_type gemini \
58 | --split test \
59 | --dataset GSM8K \
60 | --reasoning AnP \
61 | --shot 1 \
62 | --batchsize 10 \
63 | --range_begin 0 \
64 | --range_end 8 \
65 | --max_num_workers 10 \
66 | --gemini_api_key your_api_key
67 |
68 | ####
69 |
70 | python main.py \
71 | --model_name gemini-1.5-flash \
72 | --model_type gemini \
73 | --split test \
74 | --dataset GSM8K \
75 | --reasoning DiP \
76 | --shot 0 \
77 | --batchsize 10 \
78 | --range_begin 8 \
79 | --range_end 16 \
80 | --max_num_workers 10 \
81 | --gemini_api_key your_api_key
82 |
83 | python main.py \
84 | --model_name gemini-1.5-flash \
85 | --model_type gemini \
86 | --split test \
87 | --dataset GSM8K \
88 | --reasoning CoT \
89 | --shot 0 \
90 | --batchsize 10 \
91 | --range_begin 8 \
92 | --range_end 16 \
93 | --max_num_workers 10 \
94 | --gemini_api_key your_api_key
95 |
96 | python main.py \
97 | --model_name gemini-1.5-flash \
98 | --model_type gemini \
99 | --split test \
100 | --dataset GSM8K \
101 | --reasoning L2M \
102 | --shot 1 \
103 | --batchsize 10 \
104 | --range_begin 8 \
105 | --range_end 16 \
106 | --max_num_workers 10 \
107 | --gemini_api_key your_api_key
108 |
109 | python main.py \
110 | --model_name gemini-1.5-flash \
111 | --model_type gemini \
112 | --split test \
113 | --dataset GSM8K \
114 | --reasoning SBP \
115 | --shot 0 \
116 | --batchsize 10 \
117 | --range_begin 8 \
118 | --range_end 16 \
119 | --max_num_workers 10 \
120 | --gemini_api_key your_api_key
121 |
122 | python main.py \
123 | --model_name gemini-1.5-flash \
124 | --model_type gemini \
125 | --split test \
126 | --dataset GSM8K \
127 | --reasoning AnP \
128 | --shot 1 \
129 | --batchsize 10 \
130 | --range_begin 8 \
131 | --range_end 16 \
132 | --max_num_workers 10 \
133 | --gemini_api_key your_api_key
134 |
135 |
136 |
137 | python main.py \
138 | --model_name gemini-1.5-flash \
139 | --model_type gemini \
140 | --split test \
141 | --dataset GSM8K \
142 | --reasoning ToT \
143 | --shot 3 \
144 | --batchsize 10 \
145 | --range_begin 0 \
146 | --range_end 5 \
147 | --max_num_workers 10 \
148 | --gemini_api_key your_api_key
149 |
150 | python main.py \
151 | --model_name gemini-1.5-flash \
152 | --model_type gemini \
153 | --split test \
154 | --dataset GSM8K \
155 | --reasoning ToT \
156 | --shot 5 \
157 | --batchsize 10 \
158 | --range_begin 0 \
159 | --range_end 5 \
160 | --max_num_workers 10 \
161 | --gemini_api_key your_api_key
162 |
163 | python main.py \
164 | --model_name gemini-1.5-flash \
165 | --model_type gemini \
166 | --split test \
167 | --dataset GSM8K \
168 | --reasoning ToT \
169 | --shot 10 \
170 | --batchsize 10 \
171 | --range_begin 0 \
172 | --range_end 5 \
173 | --max_num_workers 10 \
174 | --gemini_api_key your_api_key
175 |
176 | python main.py \
177 | --model_name gemini-1.5-flash \
178 | --model_type gemini \
179 | --split test \
180 | --dataset GSM8K \
181 | --reasoning S-RF \
182 | --shot 0 \
183 | --batchsize 5 \
184 | --range_begin 0 \
185 | --range_end 1 \
186 | --max_num_workers 10 \
187 | --gemini_api_key your_api_key
188 |
189 | python main.py \
190 | --model_name gemini-1.5-flash \
191 | --model_type gemini \
192 | --split test \
193 | --dataset GSM8K \
194 | --reasoning MAD \
195 | --shot 0 \
196 | --batchsize 1 \
197 | --range_begin 0 \
198 | --range_end 1 \
199 | --max_num_workers 1 \
200 | --gemini_api_key your_api_key
201 |
--------------------------------------------------------------------------------
/prompts/GSM8K.py:
--------------------------------------------------------------------------------
1 | prompt_format = " Your final answer should be a single numerical number, in the form \\boxed{answer}, at the end of your response."
2 |
3 | directly_answer = "{question} You should only answer the final result with a single numerical number. Do not say other words."
4 |
5 | io = "{question}" + prompt_format
6 |
7 | io_briefly = io + " You should answer with no more than 200 words."
8 |
9 | cot_pre = "Please answer the given question." + prompt_format + '\n\n'
10 |
11 | cot_0_shot = cot_pre + '''Question: {question} Let's think step by step.
12 | Answer:'''
13 |
14 | # cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
15 | # Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}.
16 |
17 | # Question: {question}
18 | # Answer:
19 | # '''
20 |
21 | cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
22 | Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}.
23 |
24 | Question: {question}
25 | Answer:
26 | '''
27 |
28 | cot_5_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
29 | Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}.
30 |
31 | Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
32 | Answer: Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute. Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10. The answer is \\boxed{10}.
33 |
34 | Question: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?
35 | Answer: In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50. Betty's grandparents gave her 15 * 2 = $<<15*2=30>>30. This means, Betty needs 100 - 50 - 30 - 15 = $<<100-50-30-15=5>>5 more. The answer is \\boxed{5}.
36 |
37 | Question: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?
38 | Answer: Maila read 12 x 2 = <<12*2=24>>24 pages today. So she was able to read a total of 12 + 24 = <<12+24=36>>36 pages since yesterday. There are 120 - 36 = <<120-36=84>>84 pages left to be read. Since she wants to read half of the remaining pages tomorrow, then she should read 84/2 = <<84/2=42>>42 pages. The answer is \\boxed{42}.
39 |
40 | Question: James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?
41 | Answer: He writes each friend 3*2=<<3*2=6>>6 pages a week. So he writes 6*2=<<6*2=12>>12 pages every week. That means he writes 12*52=<<12*52=624>>624 pages a year. The answer is \\boxed{624}.
42 |
43 | Question: {question}
44 | Answer:
45 | '''
46 |
47 | Least_to_Most_1_shot = cot_pre + '''Question: Elsa has 5 apples. Anna has 2 more apples than Elsa. How many apples do they have together?
48 | Answer: Let's break down this problem: 1. How many apples does Anna have? 2. How many apples do they have together?
49 | 1. Anna has 2 more apples than Elsa. So Anna has 2 + 5 = 7 apples.
50 | 2. Elsa and Anna have 5 + 7 = 12 apples together.
51 | The answer is: \\boxed{12}.
52 |
53 | Question: {question}
54 | Answer:
55 | '''
56 |
57 | tot_post = '''
58 | Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.'''
59 |
60 | tot_3_solutions = '''
61 | Question: {question}
62 |
63 | Solution 1: {solution1}
64 | Solution 2: {solution2}
65 | Solution 3: {solution3}'''
66 |
67 | tot_5_solutions = tot_3_solutions + '''
68 | Solution 4: {solution4}
69 | Solution 5: {solution5}'''
70 |
71 | tot_10_solutions = tot_5_solutions + '''
72 | Solution 6: {solution6}
73 | Solution 7: {solution7}
74 | Solution 8: {solution8}
75 | Solution 9: {solution9}
76 | Solution 10: {solution10}'''
77 |
78 | anologous_1_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
79 |
80 | # Initial Problem:
81 | {question}
82 |
83 | # Instructions:
84 | ## Relevant Problems:
85 | Recall an example of the math problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem:
86 | - After "Q: ", describe the problem.
87 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.
88 |
89 | ## Solve the Initial Problem:
90 | Q: Copy and paste the initial problem here.
91 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
92 | '''
93 |
94 | anologous_3_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
95 |
96 | # Initial Problem:
97 | {question}
98 |
99 | # Instructions:
100 | ## Relevant Problems:
101 | Recall three examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
102 | - After "Q: ", describe the problem.
103 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.
104 |
105 | ## Solve the Initial Problem:
106 | Q: Copy and paste the initial problem here.
107 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
108 | '''
109 |
110 | anologous_5_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
111 |
112 | # Initial Problem:
113 | {question}
114 |
115 | # Instructions:
116 | ## Relevant Problems:
117 | Recall five examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
118 | - After "Q: ", describe the problem.
119 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.
120 |
121 | ## Solve the Initial Problem:
122 | Q: Copy and paste the initial problem here.
123 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
124 | '''
125 |
126 | SBP_extract = '''You are an expert at mathematics. Your task is to extract the mathematics concepts and principles involved in solving the problem.
127 | Question:
128 | {question}
129 |
130 | Principles involved:
131 | '''
132 |
133 | SBP_answer = "You are an expert at mathematics. You are given a mathematics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + prompt_format + '''
134 | Question:
135 | {question}
136 |
137 | Principles:
138 | {principles}
139 |
140 | Answer:
141 | '''
142 |
143 |
--------------------------------------------------------------------------------
/prompts/GSM_Hard.py:
--------------------------------------------------------------------------------
# Prompt templates for the GSM-Hard benchmark (GSM8K questions rewritten with
# large, often "unrealistic" numbers). Each template contains a {question}
# placeholder (plus {principles} for SBP and {solutionN} for ToT) that is
# substituted at inference time elsewhere in the project. The prompt text is
# an experimental artifact and is deliberately left byte-for-byte unchanged.

# Shared answer-format suffix: warns the model that GSM-Hard results may be
# nonsense decimals / negatives and pins the \boxed{...} final-answer format.
prompt_format = " The given information may not conform to common sense and the result may be a nonsense decimal or negative number, it's okay, output it instead of considering it is unreasonable. Your final answer should be a single numerical number, in the form \\boxed{answer}, at the end of your response."

# DiP (Direct Prompting): ask for the bare numeric result only.
directly_answer = "{question} You should only answer the final result with a single numerical number. Do not say other words."

# IO prompting: the question plus the shared format suffix.
io = "{question}" + prompt_format

# IO variant with a hard length budget.
io_briefly = io + " You should answer with no more than 200 words."

# Common instruction prefix shared by the CoT-style prompts below.
cot_pre = "Please answer the given question." + prompt_format + '\n\n'

# Zero-shot Chain-of-Thought.
cot_0_shot = cot_pre + '''Question: {question} Let's think step by step.
Answer:'''

# cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
# Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}.

# Question: {question}
# Answer:
# '''

# One-shot CoT; the demonstration uses GSM-Hard-style enlarged numbers
# (48564 instead of the original GSM8K value 48).
cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48564 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer: Natalia sold 48564/2 = <<48564/2=24282>>24282 clips in May. Natalia sold 48564+24282 = <<48564+24282=72846>>72846 clips altogether in April and May. The answer is \\boxed{72846}.

Question: {question}
Answer:
'''

# Five-shot CoT with GSM-Hard-style demonstrations.
# NOTE(review): the Betty example below ends with a stray '"' after
# "wallet?" — looks like a typo in the few-shot prompt; left unchanged so the
# prompts used in the published experiments stay reproducible.
cot_5_shot = cot_pre + '''Question: Natalia sold clips to 48564 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer: Natalia sold 48564/2 = <<48564/2=24282>>24282 clips in May. Natalia sold 48564+24282 = <<48564+24282=72846>>72846 clips altogether in April and May. The answer is \\boxed{72846}.

Question: Weng earns $1293 an hour for babysitting. Yesterday, she just did 612 minutes of babysitting. How much did she earn?
Answer: Weng earns 1293/60 = $<<1293/60=21.55>>21.55 per minute. Working 612 minutes, she earned 21.55 x 612 = $<<21.55*612=13188.6>>13188.6. The answer is \\boxed{13188.6}.

Question: Betty is saving money for a new wallet which costs $8200. Betty has only half of the money she needs. Her parents decided to give her $1525 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?"
Answer: In the beginning, Betty has only 8200 / 2 = $<<8200/2=4100>>4100. Betty's grandparents gave her 1525 * 2 = $<<1525*2=3050>>3050. This means, Betty needs 8200 - 4100 - 3050 - 1525 = $<<8200-4100-3050-1525=-475>>-475 more. The answer is \\boxed{-475}.

Question: Julie is reading a 12602-page book. Yesterday, she was able to read 3127 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?
Answer: Maila read 3127 x 2 = <<3127*2=6254>>6254 pages today. So she was able to read a total of 3127 + 6254 = <<3127+6254=9381>>9381 pages since yesterday. There are 12602 - 9381 = <<12602-9381=3221>>3221 pages left to be read. Since she wants to read half of the remaining pages tomorrow, then she should read 3221/2 = <<3221/2=1610.5>>1610.5 pages. The answer is \\boxed{1610.5}.

Question: James writes a 312996-page letter to 2143 different friends twice a week. How many pages does he write a year?
Answer: He writes each friend 312996*2143=<<312996*2143=670750428>>670750428 pages a week. So he writes 670750428*2=<<670750428*2=1341500856>>1341500856 pages every week. That means he writes 1341500856*52=<<1341500856*52=69758044512>>69758044512 pages a year. The answer is \\boxed{69758044512}.

Question: {question}
Answer:
'''

# L2M (Least-to-Most): one-shot demonstration of decomposing the question
# into progressive sub-questions.
Least_to_Most_1_shot = cot_pre + '''Question: Elsa has 524866 apples. Anna has 432343 more apples than Elsa. How many apples do they have together?
Answer: Let's break down this problem: 1. How many apples does Anna have? 2. How many apples do they have together?
1. Anna has 432343 more apples than Elsa. So Anna has 524866 + 432343 = 957209 apples.
2. Elsa and Anna have 524866 + 957209 = 1482075 apples together.
The answer is: \\boxed{1482075}.

Question: {question}
Answer:
'''

# ToT (Tree of Thoughts): judge instruction appended after the candidate
# solutions to pick the most promising one.
tot_post = '''
Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.'''

# Candidate-solution scaffolds for ToT voting over 3 / 5 / 10 samples.
tot_3_solutions = '''
Question: {question}

Solution 1: {solution1}
Solution 2: {solution2}
Solution 3: {solution3}'''

tot_5_solutions = tot_3_solutions + '''
Solution 4: {solution4}
Solution 5: {solution5}'''

tot_10_solutions = tot_5_solutions + '''
Solution 6: {solution6}
Solution 7: {solution7}
Solution 8: {solution8}
Solution 9: {solution9}
Solution 10: {solution10}'''

# AnP (Analogous Prompting): recall 1 / 3 / 5 relevant example problems
# before solving the initial one.
anologous_1_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.

# Initial Problem:
{question}

# Instructions:
## Relevant Problems:
Recall an example of the math problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
'''

anologous_3_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.

# Initial Problem:
{question}

# Instructions:
## Relevant Problems:
Recall three examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
'''

anologous_5_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.

# Initial Problem:
{question}

# Instructions:
## Relevant Problems:
Recall five examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
'''

# SBP (Step-Back Prompting), stage 1: extract the principles behind the problem.
SBP_extract = '''You are an expert at mathematics. Your task is to extract the mathematics concepts and principles involved in solving the problem.
Question:
{question}

Principles involved:
'''

# SBP stage 2: solve the problem guided by the extracted {principles}.
SBP_answer = "You are an expert at mathematics. You are given a mathematics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + prompt_format + '''
Question:
{question}

Principles:
{principles}

Answer:
'''
142 |
143 |
--------------------------------------------------------------------------------
/prompts/MATH.py:
--------------------------------------------------------------------------------
# Prompt templates for the MATH benchmark. Each template contains a
# {question} placeholder (plus {principles} for SBP and {solutionN} for ToT)
# that is substituted at inference time elsewhere in the project. The prompt
# text is an experimental artifact and is deliberately left byte-for-byte
# unchanged.

# Shared answer-format suffix pinning the \boxed{...} final-answer format.
prompt_format = " Your final result should be in the form \\boxed{answer}, at the end of your response."

# DiP (Direct Prompting): ask for the bare numeric result only.
directly_answer = "{question} You should only answer the final result with a single numerical number. Do not say other words."

# IO prompting: the question plus the shared format suffix.
io = "{question}" + prompt_format

# IO variant with a hard length budget.
io_briefly = io + " You should answer with no more than 200 words."

# Common instruction prefix shared by the CoT-style prompts below.
cot_pre = "Please answer the given question." + prompt_format + '\n\n'

# Zero-shot Chain-of-Thought (here the trigger phrase starts the answer).
cot_0_shot = cot_pre + '''Question: {question}
Answer: Let's think step by step.'''

# cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
# Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}.

# Question: {question}
# Answer:
# '''

# One-shot CoT with a MATH-style worked example.
cot_1_shot = cot_pre + '''Question: Ryan has 3 red lava lamps and 3 blue lava lamps. He arranges them in a row on a shelf randomly, then turns 3 random lamps on. What is the probability that the leftmost lamp on the shelf is red, and the leftmost lamp which is turned on is also red?
Answer: There are $\\binom{6}{3}=20$ ways for Ryan to arrange the lamps, and $\\binom{6}{3}=20$ ways for him to choose which lamps are on, giving $20\\cdot20=400$ total possible outcomes. There are two cases for the desired outcomes: either the left lamp is on, or it isn't. If the left lamp is on, there are $\\binom{5}{2}=10$ ways to choose which other lamps are on, and $\\binom{5}{2}=10$ ways to choose which other lamps are red. This gives $10\\cdot10=100$ possibilities. If the first lamp isn't on, there are $\\binom{5}{3}=10$ ways to choose which lamps are on, and since both the leftmost lamp and the leftmost lit lamp must be red, there are $\\binom{4}{1}=4$ ways to choose which other lamp is red. This case gives 40 valid possibilities, for a total of 140 valid arrangements out of 400. Therefore, the probability is $\\dfrac{140}{400}=\\boxed{\\dfrac{7}{20}}$.

Question: {question}
Answer:
'''

# Five-shot CoT with MATH-style worked examples across several topics.
cot_5_shot = cot_pre + '''Question: Ryan has 3 red lava lamps and 3 blue lava lamps. He arranges them in a row on a shelf randomly, then turns 3 random lamps on. What is the probability that the leftmost lamp on the shelf is red, and the leftmost lamp which is turned on is also red?
Answer: There are $\\binom{6}{3}=20$ ways for Ryan to arrange the lamps, and $\\binom{6}{3}=20$ ways for him to choose which lamps are on, giving $20\\cdot20=400$ total possible outcomes. There are two cases for the desired outcomes: either the left lamp is on, or it isn't. If the left lamp is on, there are $\\binom{5}{2}=10$ ways to choose which other lamps are on, and $\\binom{5}{2}=10$ ways to choose which other lamps are red. This gives $10\\cdot10=100$ possibilities. If the first lamp isn't on, there are $\\binom{5}{3}=10$ ways to choose which lamps are on, and since both the leftmost lamp and the leftmost lit lamp must be red, there are $\\binom{4}{1}=4$ ways to choose which other lamp is red. This case gives 40 valid possibilities, for a total of 140 valid arrangements out of 400. Therefore, the probability is $\\dfrac{140}{400}=\\boxed{\\dfrac{7}{20}}$.

Question: On the $xy$-plane, the origin is labeled with an $M$. The points $(1,0)$, $(-1,0)$, $(0,1)$, and $(0,-1)$ are labeled with $A$'s. The points $(2,0)$, $(1,1)$, $(0,2)$, $(-1, 1)$, $(-2, 0)$, $(-1, -1)$, $(0, -2)$, and $(1, -1)$ are labeled with $T$'s. The points $(3,0)$, $(2,1)$, $(1,2)$, $(0, 3)$, $(-1, 2)$, $(-2, 1)$, $(-3, 0)$, $(-2,-1)$, $(-1,-2)$, $(0, -3)$, $(1, -2)$, and $(2, -1)$ are labeled with $H$'s. If you are only allowed to move up, down, left, and right, starting from the origin, how many distinct paths can be followed to spell the word MATH?
Answer: From the M, we can proceed to four different As. Note that the letters are all symmetric, so we can simply count one case (say, that of moving from M to the bottom A) and then multiply by four.\n\nFrom the bottom A, we can proceed to any one of three Ts. From the two Ts to the sides of the A, we can proceed to one of two Hs. From the T that is below the A, we can proceed to one of three Hs. Thus, this case yields $2 \\cdot 2 + 3 = 7$ paths.\n\nThus, there are $4 \\cdot 7 = \\boxed{28}$ distinct paths.

Question: Factor the following expression: $55z^{17}+121z^{34}$.
Answer: The greatest common factor of the two coefficients is $11$, and the greatest power of $z$ that divides both terms is $z^{17}$. So, we factor $11z^{17}$ out of both terms:\n\n\\begin{align*}\n55z^{17}+121z^{34} &= 11z^{17}\\cdot 5 +11z^{17}\\cdot 11z^{17}\\\\\n&= \\boxed{11z^{17}(5+11z^{17})}\n\\end{align*}.

Question: Allen and Ben are painting a fence. The ratio of the amount of work Allen does to the amount of work Ben does is $3:5$. If the fence requires a total of $240$ square feet to be painted, how many square feet does Ben paint?
Answer: Between them, Allen and Ben are dividing the work into $8$ equal parts, $3$ of which Allen does and $5$ of which Ben does. Each part of the work requires $\\frac{240}{8} = 30$ square feet to be painted. Since Ben does $5$ parts of the work, he will paint $30 \\cdot 5 = \\boxed{150}$ square feet of the fence.

Question: Suppose $z$ and $w$ are complex numbers such that\n\\[|z| = |w| = z \\overline{w} + \\overline{z} w= 1.\\]Find the largest possible value of the real part of $z + w.$
Answer: Let $z = a + bi$ and $w = c + di,$ where $a,$ $b,$ $c,$ and $d$ are complex numbers. Then from $|z| = 1,$ $a^2 + b^2 = 1,$ and from $|w| = 1,$ $c^2 + d^2 = 1.$ Also, from $z \\overline{w} + \\overline{z} w = 1,$\n\\[(a + bi)(c - di) + (a - bi)(c + di) = 1,\\]so $2ac + 2bd = 1.$\n\nThen\n\\begin{align*}\n(a + c)^2 + (b + d)^2 &= a^2 + 2ac + c^2 + b^2 + 2bd + d^2 \\\\\n&= (a^2 + b^2) + (c^2 + d^2) + (2ac + 2bd) \\\\\n&= 3.\n\\end{align*}The real part of $z + w$ is $a + c,$ which can be at most $\\sqrt{3}.$ Equality occurs when $z = \\frac{\\sqrt{3}}{2} + \\frac{1}{2} i$ and $w = \\frac{\\sqrt{3}}{2} - \\frac{1}{2} i,$ so the largest possible value of $a + c$ is $\\boxed{\\sqrt{3}}.$

Question: {question}
Answer:
'''

# L2M (Least-to-Most): zero-shot decomposition instruction (no worked example,
# unlike the GSM-Hard variant).
Least_to_Most_0_shot = cot_pre + ''' In order to solve the question more conveniently and efficiently, break down the question into progressive sub-questions. Answer the sub-questions and get the final result according to sub-questions and their answers.

Question: {question}
Answer:
'''

# ToT (Tree of Thoughts): judge instruction appended after the candidate
# solutions to pick the most promising one.
tot_post = '''
Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.'''

# Candidate-solution scaffolds for ToT voting over 3 / 5 / 10 samples.
tot_3_solutions = '''
Question: {question}

Solution 1: {solution1}
Solution 2: {solution2}
Solution 3: {solution3}'''

tot_5_solutions = tot_3_solutions + '''
Solution 4: {solution4}
Solution 5: {solution5}'''

tot_10_solutions = tot_5_solutions + '''
Solution 6: {solution6}
Solution 7: {solution7}
Solution 8: {solution8}
Solution 9: {solution9}
Solution 10: {solution10}'''

# AnP (Analogous Prompting): recall 1 / 3 / 5 relevant example problems
# before solving the initial one.
anologous_1_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.

# Initial Problem:
{question}

# Instructions:
## Relevant Problems:
Recall an example of the math problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
'''

anologous_3_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.

# Initial Problem:
{question}

# Instructions:
## Relevant Problems:
Recall three examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
'''

anologous_5_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.

# Initial Problem:
{question}

# Instructions:
## Relevant Problems:
Recall five examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}.

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and enclose the ultimate answer in \\boxed{} here.
'''

# SBP (Step-Back Prompting), stage 1: extract the principles behind the problem.
SBP_extract = '''You are an expert at mathematics. Your task is to extract the mathematics concepts and principles involved in solving the problem.
Question:
{question}

Principles involved:
'''

# SBP stage 2: solve the problem guided by the extracted {principles}.
SBP_answer = "You are an expert at mathematics. You are given a mathematics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + prompt_format + '''
Question:
{question}

Principles:
{principles}

Answer:
'''
138 |
139 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
Rethinking the Role of Prompting Strategies in LLM Test-Time Scaling: A Perspective of Probability Theory
3 |
4 |
12 |
13 | 1MAIS, Institute of Automation, Chinese Academy of Sciences
14 | 2School of Artificial Intelligence, University of Chinese Academy of Sciences
15 | 3University of California, Santa Barbara
16 | 4Beijing Wenge Technology Co., Ltd 5Nanjing University
17 | *Corresponding Author
18 |
19 |
20 |
21 |
22 | ACL 2025 Main 🏆 Outstanding Paper Award
23 |
24 |
25 |
26 | [](https://aclanthology.org/2025.acl-long.1356/) [](https://arxiv.org/abs/2505.10981) [](https://opensource.org/licenses/MIT) [](https://www.python.org/downloads/)
27 |
28 |
29 | ## 📑 Brief Introduction
30 |
31 | ### Abstract
32 | Recently, scaling test-time compute on Large Language Models (LLM) has garnered wide attention. However, there has been limited investigation of how various reasoning prompting strategies perform as scaling. In this paper, we focus on a standard and realistic scaling setting: majority voting. We systematically conduct experiments on 6 LLMs $\times$ 8 prompting strategies $\times$ 6 benchmarks. Experiment results consistently show that as the sampling time and computational overhead increase, complicated prompting strategies with superior initial performance gradually fall behind simple Chain-of-Thought. We analyze this phenomenon and provide theoretical proofs. Additionally, we propose a probabilistic method to efficiently predict scaling performance and identify the best prompting strategy under large sampling times, eliminating the need for resource-intensive inference processes in practical applications. Furthermore, we introduce two ways derived from our theoretical analysis to significantly improve the scaling performance. We hope that our research can promote to re-examine the role of complicated prompting, unleash the potential of simple prompting strategies, and provide new insights for enhancing test-time scaling performance.
33 |
34 | ### Contributions
35 |
36 | 1. **Comprehensive experiments.** Our study covers a wide range - 6 LLMs $\times$ 8 prompting strategies $\times$ 6 benchmarks, providing sufficient evidence and context to fully support the claim.
2. **Valuable findings breaking the conventional wisdom.** Our extensive experiments consistently demonstrate that a complex prompting strategy with higher pass@1 accuracy may not always be better under test-time scaling, while simple CoT/DiP gradually dominates even with an initially inferior performance.
38 | 3. **Rigorous theoretical analysis.** We provide an in-depth probability theoretic backed explanation of what leads to more rapid improvements with scale.
- 3.1 **Definition of easy and hard questions by answer distribution.** The difficulty of a question is not only related to its pass@1 accuracy, but is determined by the probability distribution of all possible answer outputs. The accuracy on easy questions increases with scaling, while the accuracy on hard questions decreases.
40 | - 3.2 **Disturbed peaks of wrong answer distribution.** Scaling performance is affected by enormous answer distribution, and we quantify this with our theory.
41 | 4. **Practical $O(1)$ approach to predict scaling performance without resource-intensive inference**.
42 | 5. **Two effective and general methods to significantly improve scaling performance verified on multiple models and datasets.** Combining the two methods will lead to much more improvements, e.g., improving Majority@10 accuracy from 15.2% to 61.0% with LLaMA-3-8B-Instruct on MATH-500.
43 | - 5.1 **Adaptively scaling based on the question difficulty.**
44 | - 5.2 **Dynamically selecting the optimal prompting strategy based on our theory.**
45 |
46 |
47 | ## 🔍 Features
48 |
49 | - Support for multiple LLM backends (VLLM, Gemini, OpenAI and other API-based models)
50 | - You can specify **any model** according to your needs.
51 | - Various reasoning prompting strategies:
52 | - Non-Iterative:
53 | - **DiP**: Direct Prompting
54 | - **CoT**: [Chain of Thought Prompting](https://proceedings.neurips.cc/paper_files/paper/2022/hash/9d5609613524ecf4f15af0f7b31abca4-Abstract-Conference.html?ref=https://githubhelp.com)
55 | - **L2M**: [Least-to-Most Prompting](https://arxiv.org/abs/2205.10625)
56 | - **SBP**: [Step-Back Prompting](https://arxiv.org/abs/2310.06117)
57 | - **AnP**: [Analogous Prompting](https://arxiv.org/abs/2310.01714)
58 | - Iterative:
59 | - **ToT**: [Tree of Thoughts](https://proceedings.neurips.cc/paper_files/paper/2023/hash/271db9922b8d1f4dd7aaef84ed5ac703-Abstract-Conference.html)
60 | - **S-RF**: [Self-Refine](https://proceedings.neurips.cc/paper_files/paper/2023/hash/91edff07232fb1b55a505a9e9f6c0ff3-Abstract-Conference.html)
61 | - **MAD**: [Multi-Agent Debate](https://dl.acm.org/doi/abs/10.5555/3692070.3692537)
62 | - Extensive dataset support:
63 | - Mathematical reasoning:
64 | - [GSM8K](https://arxiv.org/abs/2110.14168)
65 | - [GSM-Hard](https://proceedings.mlr.press/v202/gao23f)
66 | - [MATH](https://arxiv.org/abs/2103.03874)
67 | - [AIME_2024](https://modelscope.cn/datasets/AI-ModelScope/AIME_2024)
68 | - Scientific reasoning:
69 | - [GPQA](https://arxiv.org/abs/2311.12022)
70 | - [MMLU](https://arxiv.org/abs/2009.03300)
71 | - MMLU-high_school_physics
72 | - MMLU-high_school_chemistry
73 | - MMLU-high_school_biology
74 | - Two different budgets for evaluation:
75 | - Sampling time
76 | - Computation overhead (Cost)
77 |
78 | ## 🛠️ Installation
79 |
80 | 1. Clone the repository:
81 | ```bash
82 | git clone https://github.com/MraDonkey/rethinking_prompting.git
83 | cd rethinking_prompting
84 | ```
85 |
86 | 2. Create conda environment and install dependencies:
87 | ```bash
88 | conda create -n rethinking_prompting python=3.11
89 | conda activate rethinking_prompting
90 | pip install -r requirements.txt
91 | ```
92 |
93 | ## ⚙️ Configuration
94 |
95 | Before running the framework, you need to set up your API keys for different LLM providers:
96 |
97 | - For vllm models:
98 | - You may need to login huggingface to get access to some LLMs.
99 | - For OpenAI models or other OpenAI-like API-based models:
100 | - Set api_key `openai_api_key`
101 | - Set base_url `openai_base_url`
102 | - For Google Gemini:
103 | - Set `google_api_key`
104 |
105 | Complete the variables `hf_token` in `main.py` and `base_path` in `dataset.py`.
106 |
107 | ## 🪛 Usage
108 |
109 | For example, to get the inference results of all prompting strategies with Qwen2.5-7B-Instruct on GSM8K, you can run this script.
110 |
111 | ```bash
112 | bash scripts/Qwen_GSM8K.sh
113 | ```
114 |
115 | You can further customize hyperparameters to suit your specific requirements.
116 |
117 | ## 🔬 Evaluation
118 |
119 | To evaluate the performance of all tested prompting strategies:
120 |
121 | ```bash
122 | python eval_csv_N.py --model_name "your_model" --dataset "your_dataset"
123 | python eval_csv_cost.py --model_name "your_model" --dataset "your_dataset"
124 | ```
125 |
126 | You can customize the variable `sampling_times` to adjust the points in the figure, in the style of Figure 1 and 2 in our [paper](https://arxiv.org/abs/2505.10981).
127 |
128 | 
129 |
130 | ## ✒️ Citation
131 |
132 | Should you find our work beneficial to your research, we would appreciate citations to our paper and GitHub stars to support ongoing development. ⭐
133 |
134 | ```bibtex
135 | @inproceedings{liu-etal-2025-rethinking,
136 | title = "Rethinking the Role of Prompting Strategies in {LLM} Test-Time Scaling: A Perspective of Probability Theory",
137 | author = "Liu, Yexiang and
138 | Li, Zekun and
139 | Fang, Zhi and
140 | Xu, Nan and
141 | He, Ran and
142 | Tan, Tieniu",
143 | editor = "Che, Wanxiang and
144 | Nabende, Joyce and
145 | Shutova, Ekaterina and
146 | Pilehvar, Mohammad Taher",
147 | booktitle = "Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
148 | month = jul,
149 | year = "2025",
150 | address = "Vienna, Austria",
151 | publisher = "Association for Computational Linguistics",
152 | url = "https://aclanthology.org/2025.acl-long.1356/",
153 | doi = "10.18653/v1/2025.acl-long.1356",
154 | pages = "27962--27994",
155 | ISBN = "979-8-89176-251-0"
156 | }
157 | ```
158 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import time
2 | from dataset import parse_best_solution, parse_answer
3 |
4 | import google.generativeai as genai # pip install -q -U google-generativeai
5 | from google.generativeai.types import HarmCategory, HarmBlockThreshold
6 | from openai import OpenAI
7 |
8 | from concurrent.futures import ThreadPoolExecutor, as_completed
9 | import pdb
10 | from tqdm import tqdm
11 |
12 |
def get_messages(args):
    """Build per-sample chat message lists in the format the backend expects.

    Args:
        args: Namespace carrying ``model_name`` and either ``messages`` (a list
            of conversations, each an odd-length list of alternating
            user/assistant turns) or ``query`` (a list of single-turn user
            prompts). ``args.system`` optionally holds a system prompt
            (honored by OpenAI-style backends only).

    Returns:
        A list with one message list per input sample. OpenAI-style backends
        get ``{"role": ..., "content": ...}`` dicts (roles system/user/
        assistant); Gemini gets ``{"role": ..., "parts": ...}`` dicts (roles
        user/model — Gemini has no system role here).
    """
    messages = []
    assert args.messages is not None or args.query is not None
    if "gemini" not in args.model_name:
        if args.messages is not None:
            roles = ['user', 'assistant']
            for turns in args.messages:
                # Odd length => the conversation ends on a user turn.
                assert len(turns) % 2 == 1
                messages_i = []
                if args.system is not None:
                    messages_i.append({"role": "system", "content": args.system})
                for j, message in enumerate(turns):
                    messages_i.append({'role': roles[j % 2], 'content': message})
                messages.append(messages_i)
        else:
            for query in args.query:
                messages_i = []
                if args.system is not None:
                    messages_i.append({"role": "system", "content": args.system})
                messages_i.append({"role": "user", "content": query})
                messages.append(messages_i)
    else:
        if args.messages is not None:
            roles = ['user', 'model']
            for turns in args.messages:
                assert len(turns) % 2 == 1
                messages.append([{'role': roles[j % 2], 'parts': [message]}
                                 for j, message in enumerate(turns)])
        else:
            for query in args.query:
                # NOTE(review): ``parts`` is the raw string here, while the
                # multi-turn branch wraps it in a list; the Gemini SDK accepts
                # both forms, so the original behavior is preserved.
                messages.append([{"role": "user", "parts": query}])
    return messages
47 |
48 |
class Gemini:
    """Thin wrapper around the ``google.generativeai`` chat API.

    Normalizes each sampled completion into ``{"output": str, "usage": dict}``
    and retries failed API calls up to ``N`` times.
    """

    def __init__(self, model="gemini-1.5-flash", N=3):
        self.model = genai.GenerativeModel(model)
        self.count_limit = N  # max retry attempts on API errors
        # Disable all safety blocking so benchmark prompts are never filtered.
        self.safety_settings = {
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

    def get_response(self, messages, system_instruction = None, n = 1) -> list:
        """Generate ``n`` completions for ``messages``.

        Args:
            messages: Gemini-format message list (see ``get_messages``).
            system_instruction: Stored on the instance if given.
                NOTE(review): it is never forwarded to ``generate_content``;
                kept only for interface compatibility — confirm intent.
            n: Number of completions to collect.

        Returns:
            A list of ``n`` records ``{"output": str, "usage": dict}``.

        Raises:
            EOFError: After ``count_limit`` consecutive API failures.
        """
        records = []
        counts = 0
        while len(records) < n and counts < self.count_limit:
            if system_instruction is not None:
                self.system_instruction = system_instruction
            try:
                num = n
                records = []
                # Keep requesting until n normally-finished candidates arrive;
                # candidates with other finish reasons (filtered/truncated) are
                # dropped and re-sampled.
                while (len(records) < n):
                    time_begin = time.time()
                    response = self.model.generate_content(messages,
                                            safety_settings=self.safety_settings,
                                            generation_config=genai.types.GenerationConfig(candidate_count = num))
                    time_completion = time.time()
                    for i in range(0, num):
                        # finish_reason == 1 (STOP) marks a normal completion.
                        if response.candidates[i].finish_reason == 1:
                            output = response.candidates[i].content.parts[0].text.strip()
                            completion_tokens = self.model.count_tokens(output).total_tokens
                            usage = {"prompt_tokens": response.usage_metadata.prompt_token_count,
                                    "completion_tokens": completion_tokens,
                                    "time_prompt": 0,
                                    "time_completion": time_completion - time_begin}
                            record = {"output": output, "usage": usage}
                            records.append(record)
                    # Bug fix: request exactly the number of completions still
                    # missing. The original computed ``n - len(record)`` — the
                    # key count of a single record dict (always 2) — which
                    # mis-sized every follow-up batch.
                    num = n - len(records)

            except Exception as error:
                response = ""
                print(error)
                print('Sleeping for 10 seconds')
                time.sleep(10)
                counts += 1
                if counts == self.count_limit:
                    raise EOFError("")
        return records
97 |
98 |
def load_model(args):
    """Instantiate the inference backend selected by ``args.model_type``.

    Mutates ``args`` in place: sets ``args.model`` for the "vllm" and
    "gemini" backends, or ``args.client`` for the "openai" backend.
    """
    backend = args.model_type
    if backend == "vllm":
        # Imported lazily so non-vLLM runs don't need the GPU stack.
        from vllm import LLM
        num_gpus = len(args.gpu.split(","))
        args.model = LLM(model=args.model_name,
                         trust_remote_code=True,
                         tensor_parallel_size=num_gpus,
                         dtype=args.dtype)
    elif backend == "gemini":
        genai.configure(api_key=args.google_api_key)
        args.model = Gemini(model=args.model_name)
    elif backend == "openai":
        args.client = OpenAI(api_key=args.openai_api_key,
                             base_url=args.openai_base_url)
108 |
109 |
def gpt_parallel_generate(args, message):
    """Request ``args.num`` completions for one chat ``message`` via the OpenAI API.

    Completion length is measured by counting the returned logprob entries
    (one per generated token), since the API's aggregate usage covers all
    ``n`` samples combined.

    Returns:
        A list of ``args.num`` records ``{"output": str, "usage": dict}``.
    """
    response = args.client.chat.completions.create(
        model=args.model_name,
        messages=message,
        n=args.num,
        logprobs=True,
        max_tokens=args.max_new_tokens,
    )
    records = []
    for idx in range(args.num):
        choice = response.choices[idx]
        usage = {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": len(choice.logprobs.content),
        }
        records.append({"output": choice.message.content, "usage": usage})
    return records
132 |
133 |
def _attach_output_keys(args, records):
    # Post-process every generated sample in place: extract a comparable
    # answer key from the raw text and store it under "output_key".
    for sample_records in records:
        for j in range(0, args.num):
            if "tot" in args.reasoning:
                # Tree-of-thought outputs end with "The index of the best solution is x".
                output_key = parse_best_solution(sample_records[j]["output"])
            else:
                output_key = parse_answer(args, sample_records[j]["output"])
            sample_records[j]["output_key"] = output_key


def _collect_responses(task, messages, max_num_workers):
    """Run ``task(message)`` for every message, serially or with a thread pool.

    Results are returned in the same order as ``messages`` regardless of
    completion order. On a worker exception the failing index is printed and
    a debugger session is opened (preserves the original interactive
    debugging behavior).
    """
    if max_num_workers == 1:
        return [task(message) for message in tqdm(messages)]
    records = [None] * len(messages)
    with ThreadPoolExecutor(max_workers=max_num_workers) as executor:
        future_to_index = {
            executor.submit(task, message): index
            for index, message in enumerate(messages)
        }
        with tqdm(total=len(messages)) as pbar:
            for future in as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    records[idx] = future.result()
                    pbar.update(1)
                except Exception as exc:
                    print(f'Index {idx} generated exception: {exc}')
                    pdb.set_trace()
    return records


def LLM_generate(args):
    """
    Generate responses using different LLM backends with parallel processing support.

    Args:
        args: Arguments containing model configuration and generation parameters
            (model_type, model, num, reasoning, max_num_workers, temperature,
            max_new_tokens, and either messages or query).

    Returns:
        List (one entry per input message) of lists (one entry per sample) of
        record dicts with "output", "output_key" and "usage" fields.

    Raises:
        NotImplementedError: if args.model_type is not "vllm", "gemini" or "openai".
    """

    assert args.messages is not None or args.query is not None
    messages = get_messages(args)
    records = []
    if args.model_type == "gemini":
        records = _collect_responses(
            lambda message: args.model.get_response(message, args.system, args.num),
            messages, args.max_num_workers)
        _attach_output_keys(args, records)

    elif args.model_type == "openai":
        records = _collect_responses(
            lambda message: gpt_parallel_generate(args, message),
            messages, args.max_num_workers)
        _attach_output_keys(args, records)

    elif args.model_type == "vllm":
        # Imported lazily so the module remains usable without vllm installed.
        from vllm import SamplingParams
        sampling_params = SamplingParams(temperature=args.temperature, n=args.num,
                                         max_tokens=args.max_new_tokens, top_p=0.9)
        res = args.model.chat(messages=messages, sampling_params=sampling_params)
        for response in res:
            # The prompt token count is shared by all n samples of one request.
            prompt_tokens_i = len(response.prompt_token_ids)
            records_i = []
            for sample in response.outputs:
                output = sample.text
                if "tot" in args.reasoning:
                    output_key = parse_best_solution(output)
                else:
                    output_key = parse_answer(args, output)
                usage = {"prompt_tokens": prompt_tokens_i,
                         "completion_tokens": len(sample.token_ids)}
                records_i.append({"output": output, "output_key": output_key, "usage": usage})
            records.append(records_i)
    else:
        raise NotImplementedError(f"Model type \"{args.model_type}\" not supported (should be \"vllm\", \"gemini\" or \"openai\").")
    return records
243 |
--------------------------------------------------------------------------------
/prompts/GPQA.py:
--------------------------------------------------------------------------------
"""Prompt templates for the GPQA multiple-choice benchmark.

Each constant below is a format string; {question} and {choice1..4} are
filled per example by the caller.
"""

# Shared answer-format instruction: forces the model to end with a
# parseable "The correct answer is (X)" sentence.
prompt_format = '''Please choose the correct choice. Your last sentence should be \"The correct answer is (insert answer here, which is only the letter of the choice)\".'''

# Base question template shared by every prompting strategy in this file.
GPQA_prompt = '''Question: {question}

Choices:
(A) {choice1}
(B) {choice2}
(C) {choice3}
(D) {choice4}

'''

# Direct input-output prompting: question plus the answer-format instruction.
io = GPQA_prompt + prompt_format

# Prefix shared by the few-shot chain-of-thought prompts below.
cot_pre = prompt_format + '\n\n'

# Zero-shot chain-of-thought trigger.
cot_0_shot = GPQA_prompt + prompt_format + " Let's think step by step:"
18 |
# One-shot chain-of-thought: a single worked example (Hardy-Weinberg
# genetics) precedes the target question.
cot_1_shot = cot_pre + '''Question: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer?

Choices:
(A) 1/400
(B) 19/400
(C) 20/400
(D) 38/400

Let's think step by step:
The expected proportion of individuals who carry the b allele but are not expected to develop the cancer equals to the frequency of heterozygous allele in the given population.
According to the Hardy-Weinberg equation p∧2 + 2pq + q∧2 = 1, where p is the frequency of dominant allele frequency, q is the frequency of recessive allele frequency, p∧2 is the frequency of the homozygous dominant allele, q∧2 is the frequency of the recessive allele, and 2pq is the frequency of the heterozygous allele.
Given that q∧2=1/400, hence, q=0.05 and p=1-q=0.95.
The frequency of the heterozygous allele is 2pq=2*0.05*0.95=38/400.
The correct answer is (D).

''' + GPQA_prompt + "Let's think step by step:"
35 |
# Five-shot chain-of-thought: worked examples covering genetics, chemistry
# (titration and chromatography), quantum mechanics, and probability.
cot_5_shot = cot_pre + '''Question: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer?

Choices:
(A) 1/400
(B) 19/400
(C) 20/400
(D) 38/400

Let's think step by step:
The expected proportion of individuals who carry the b allele but are not expected to develop the cancer equals to the frequency of heterozygous allele in the given population.
According to the Hardy-Weinberg equation p∧2 + 2pq + q∧2 = 1, where p is the frequency of dominant allele frequency, q is the frequency of recessive allele frequency, p∧2 is the frequency of the homozygous dominant allele, q∧2 is the frequency of the recessive allele, and 2pq is the frequency of the heterozygous allele.
Given that q∧2=1/400, hence, q=0.05 and p=1-q=0.95.
The frequency of the heterozygous allele is 2pq=2*0.05*0.95=38/400.
The correct answer is (D).

Question: A Fe pellet of 0.056 g is first dissolved in 10 mL of hydrobromic acid HBr (0.1 M). The resulting solution is then titrated by KMnO4 (0.02 M). How many equivalence points are there?

Choices:
(A) Two points, 25 ml and 35 ml
(B) One point, 25 mL
(C) One point, 10 ml
(D) Two points, 25 ml and 30 ml

Let's think step by step:
HBr will react with Fe to produce Fe2+. MnO4- will first react with Fe2+ then Br-.
Two equivalence points will exist 25 ml and 35 ml.
HBr will react with Fe to produce Fe2+. MnO4- will first react with Fe2+ then Br-.
Two equivalence points will exist 25 ml and 35 ml.
In the beaker there is Fe2+ and Br-.
When considering titration with two analytes one will have to consider which reaction will occur first.
Since it is a redox titration consider the reduction potential of:
E0 (Br2 /Br- ) = 1.09 V E0 (MnO4-/ Mn2+) = 1.49 V E0 (Fe3+/Fe2+) =0.77 V
[Fe2+]=m/MV=0.1M.
Reaction 1: MnO4- + 5Fe2+ + 8H+ → Mn2+ + 5Fe3+ + 4H2O
Reaction 2: 2MnO4- + 10Br- + 16H+ → 2Mn2+ + 5Br2 + 8H2O
So MnO4- will first react with Fe2+ with a stoichiometry of 1:5 so Veq1 will be 10 ml.
Then when Fe2+ is used up, MnO4- will react with Br- with a stoichiometry of 2:10 then V added will be 25 ml so Veq2=25+10=35 ml.
The correct answer is (A).

Question: Consider a quantum mechanical system containing a particle of mass $m$ moving in an istropic three dimensional potential of the form $V(r) = 1/2 m \omega^2 r^2$ corresponding to the acted force obeying Hooke’s law. Here, $\omega$ is the angular frequency of oscillation and $r$ is the radial distance of the particle from the origin in spherical polar coordinate. What is the value of energy of the third excited state, and how many linearly independent eigenfunctions are possible for the same energy eigenvalue?

Choices:
(A) 11 \pi^2 \hbar^2 / (2m r^2), 3
(B) (9/2) \hbar \omega , 10
(C) 11 \pi^2 \hbar^2 / (2m r^2), 10
(D) (9/2) \hbar \omega, 3

Let's think step by step:
This problem is nothing but the three dimensional simple harmonic oscillator (SHO) problem.
The energy spectrum of three dimensional SHO is $E_n= (n+3/2)\hbar \omega$ where $n=0,1,2,3….$.
For third excited state n=3.
3+3/2=6/2+3/2=9/2.
Thus the corresponding energy is $(9/2)\hbar \omega$.
The degeneracy of the state is $g_n= (n+1)(n+2)/2$.
For n=3, degeneracy is (3+1)*(3+2)/2=4*5/2=10.
The correct answer is (B).

Question: "Your overhear two chemists talking to each other as they leave a synthetic organic chemistry lab. One asks the other "So, how did it go?" The second chemist replies, "Not well - my compounds are on top of each other." What is the second chemist most likely referring to?"

Choices:
(A) The compounds they are working with have similar polarities.
(B) The compounds they are working with have similar boiling points.
(C) The compounds they are working with are bonding to each other through non-covalent/van der Waals interactions.
(D) The compounds they are working with have similar optical rotations.

Let's think step by step:
"On top of each other" commonly refers to two compounds that have similar Rf values on chromatography (a common operation in synthetic chemistry).
Similar Rf values arise for compounds with similar polarities.
The correct answer is (A).

Question: Two people are playing the following game. A fair coin is tossed into the air. Person A says that in a single toss of the coin, the tail will come. So it's like the first shot or the third shot or the fifth shot. Person B says that the coin will come with a double toss. So like the second, fourth, sixth or eighth shot. Imagine this game played forever. What is the probability that person A wins this game?

Choices:
(A) 1/2
(B) 1/4
(C) 2/3
(D) 1/8

Let's think step by step:
When finding the correct answer, the probability of playing forever and the coin's single-point toss will be calculated.
For example, a tail may appear on the first shot.
This probability is 1/2. if the first toss doesn't come up, it shouldn't come to the second roll either, because the second throw is an even number.
So it can come in the third shot.
This is (1/2)(1/2)(1/2).
So (1/2)^3=1/8.
Or it could come on the fifth shot.
This is (1/2)^5=1/32.
This is actually a geometric series that goes on forever.
We can write this series as follows.
(1/2) + (1/2)^3 + (1/2)^5 + (1/2)^7 + ……….
The solution for this series is as follows : a1/(1-r) where a1 is the first number and r is the sequence or r= a2/a1 or a3/a2 etc.
a1=1/2
r=(1/2)^2=1/4
So a1/(1-r)=(1/2)/(1-1/4)=(1/2)/(3/4)=2/3.
The correct answer is (C).

''' + GPQA_prompt + "Let's think step by step:"
133 |
# Least-to-most prompting: ask the model to decompose the question into
# progressive sub-questions before answering.
Least_to_Most_0_shot = GPQA_prompt + '''Please choose the correct choice. In order to solve the question more conveniently and efficiently, break down the question into progressive sub-questions. Answer the sub-questions and get the final result according to sub-questions and their answers.
''' + "Your last sentence should be \"The correct answer is (insert answer here, which is only the letter of the choice)\"."

# Tree-of-thought voting instruction appended after the candidate solutions;
# the reply must end with "The index of the best solution is x".
tot_post = '''
Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.'''

# Tree-of-thought candidate lists for 3, 5 and 10 sampled solutions; each
# template extends the previous one with more {solutionN} slots.
tot_3_solutions = GPQA_prompt + '''
Solution 1: {solution1}
Solution 2: {solution2}
Solution 3: {solution3}'''

tot_5_solutions = tot_3_solutions + '''
Solution 4: {solution4}
Solution 5: {solution5}'''

tot_10_solutions = tot_5_solutions + '''
Solution 6: {solution6}
Solution 7: {solution7}
Solution 8: {solution8}
Solution 9: {solution9}
Solution 10: {solution10}'''
155 |
# Analogical prompting ("anologous" spelling kept for compatibility with the
# rest of the codebase): recall 1/3/5 self-generated relevant problems before
# solving the target one.
anologous_1_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
''' + GPQA_prompt + '''
# Instructions:
## Relevant Problems:
Recall an example of the {subject} problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".
'''

anologous_3_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
''' + GPQA_prompt + '''
# Instructions:
## Relevant Problems:
Recall three examples of {subject} problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".
'''

anologous_5_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
''' + GPQA_prompt + '''
# Instructions:
## Relevant Problems:
Recall five examples of {subject} problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".
'''

# Step-back prompting, phase 1: extract the underlying principles only.
SBP_extract = '''You are an expert at {subject}. Your task is to extract the {subject} concepts and principles involved in solving the problem.
''' + GPQA_prompt + '''
Principles involved:
'''

# Step-back prompting, phase 2: solve the problem given the extracted
# {principles} from phase 1.
SBP_answer = "You are an expert at {subject}. You are given a {subject} problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + '''
''' + GPQA_prompt + "\nInstruction:\n"+ prompt_format + '''
Principles:
{principles}
''' + '''
Answer:
'''
207 |
208 |
--------------------------------------------------------------------------------
/prompts/MMLU.py:
--------------------------------------------------------------------------------
"""Prompt templates for the MMLU multiple-choice benchmark.

Each constant below is a format string; {question} and {choice1..4} are
filled per example by the caller.
"""

# Shared answer-format instruction: forces the model to end with a
# parseable "The correct answer is (X)" sentence.
prompt_format = '''Please choose the correct choice. Your last sentence should be \"The correct answer is (insert answer here, which is only the letter of the choice)\".'''

# Base question template; note the newline after "Question:" (unlike the
# GPQA template, which puts the question on the same line).
MMLU_prompt = '''Question:
{question}

Choices:
(A) {choice1}
(B) {choice2}
(C) {choice3}
(D) {choice4}

'''

# Direct input-output prompting: question plus the answer-format instruction.
io = MMLU_prompt + prompt_format

# Prefix intended for few-shot chain-of-thought prompts.
cot_pre = prompt_format + '\n\n'

# Zero-shot chain-of-thought trigger.
cot_0_shot = MMLU_prompt + prompt_format + " Let's think step by step:"
19 |
# NOTE(review): the commented-out cot_1_shot / cot_5_shot blocks below appear
# to be copied verbatim from prompts/GPQA.py and were never adapted to MMLU;
# they are kept disabled as reference material.
# cot_1_shot = cot_pre + '''Question: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer?
21 |
22 | # Choices:
23 | # (A) 1/400
24 | # (B) 19/400
25 | # (C) 20/400
26 | # (D) 38/400
27 |
28 | # Let's think step by step:
29 | # The expected proportion of individuals who carry the b allele but are not expected to develop the cancer equals to the frequency of heterozygous allele in the given population.
30 | # According to the Hardy-Weinberg equation p∧2 + 2pq + q∧2 = 1, where p is the frequency of dominant allele frequency, q is the frequency of recessive allele frequency, p∧2 is the frequency of the homozygous dominant allele, q∧2 is the frequency of the recessive allele, and 2pq is the frequency of the heterozygous allele.
31 | # Given that q∧2=1/400, hence, q=0.05 and p=1-q=0.95.
32 | # The frequency of the heterozygous allele is 2pq=2*0.05*0.95=38/400.
33 | # The correct answer is (D).
34 |
35 | # ''' + MMLU_prompt + "Let's think step by step:"
36 |
37 | # cot_5_shot = cot_pre + '''Question: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer?
38 |
39 | # Choices:
40 | # (A) 1/400
41 | # (B) 19/400
42 | # (C) 20/400
43 | # (D) 38/400
44 |
45 | # Let's think step by step:
46 | # The expected proportion of individuals who carry the b allele but are not expected to develop the cancer equals to the frequency of heterozygous allele in the given population.
47 | # According to the Hardy-Weinberg equation p∧2 + 2pq + q∧2 = 1, where p is the frequency of dominant allele frequency, q is the frequency of recessive allele frequency, p∧2 is the frequency of the homozygous dominant allele, q∧2 is the frequency of the recessive allele, and 2pq is the frequency of the heterozygous allele.
48 | # Given that q∧2=1/400, hence, q=0.05 and p=1-q=0.95.
49 | # The frequency of the heterozygous allele is 2pq=2*0.05*0.95=38/400.
50 | # The correct answer is (D).
51 |
52 | # Question: A Fe pellet of 0.056 g is first dissolved in 10 mL of hydrobromic acid HBr (0.1 M). The resulting solution is then titrated by KMnO4 (0.02 M). How many equivalence points are there?
53 |
54 | # Choices:
55 | # (A) Two points, 25 ml and 35 ml
56 | # (B) One point, 25 mL
57 | # (C) One point, 10 ml
58 | # (D) Two points, 25 ml and 30 ml
59 |
60 | # Let's think step by step:
61 | # HBr will react with Fe to produce Fe2+. MnO4- will first react with Fe2+ then Br-.
62 | # Two equivalence points will exist 25 ml and 35 ml.
63 | # HBr will react with Fe to produce Fe2+. MnO4- will first react with Fe2+ then Br-.
64 | # Two equivalence points will exist 25 ml and 35 ml.
65 | # In the beaker there is Fe2+ and Br-.
66 | # When considering titration with two analytes one will have to consider which reaction will occur first.
67 | # Since it is a redox titration consider the reduction potential of:
68 | # E0 (Br2 /Br- ) = 1.09 V E0 (MnO4-/ Mn2+) = 1.49 V E0 (Fe3+/Fe2+) =0.77 V
69 | # [Fe2+]=m/MV=0.1M.
70 | # Reaction 1: MnO4- + 5Fe2+ + 8H+ → Mn2+ + 5Fe3+ + 4H2O
71 | # Reaction 2: 2MnO4- + 10Br- + 16H+ → 2Mn2+ + 5Br2 + 8H2O
72 | # So MnO4- will first react with Fe2+ with a stoichiometry of 1:5 so Veq1 will be 10 ml.
73 | # Then when Fe2+ is used up, MnO4- will react with Br- with a stoichiometry of 2:10 then V added will be 25 ml so Veq2=25+10=35 ml.
74 | # The correct answer is (A).
75 |
76 | # Question: Consider a quantum mechanical system containing a particle of mass $m$ moving in an istropic three dimensional potential of the form $V(r) = 1/2 m \omega^2 r^2$ corresponding to the acted force obeying Hooke’s law. Here, $\omega$ is the angular frequency of oscillation and $r$ is the radial distance of the particle from the origin in spherical polar coordinate. What is the value of energy of the third excited state, and how many linearly independent eigenfunctions are possible for the same energy eigenvalue?
77 |
78 | # Choices:
79 | # (A) 11 \pi^2 \hbar^2 / (2m r^2), 3
80 | # (B) (9/2) \hbar \omega , 10
81 | # (C) 11 \pi^2 \hbar^2 / (2m r^2), 10
82 | # (D) (9/2) \hbar \omega, 3
83 |
84 | # Let's think step by step:
85 | # This problem is nothing but the three dimensional simple harmonic oscillator (SHO) problem.
86 | # The energy spectrum of three dimensional SHO is $E_n= (n+3/2)\hbar \omega$ where $n=0,1,2,3….$.
87 | # For third excited state n=3.
88 | # 3+3/2=6/2+3/2=9/2.
89 | # Thus the corresponding energy is $(9/2)\hbar \omega$.
90 | # The degeneracy of the state is $g_n= (n+1)(n+2)/2$.
91 | # For n=3, degeneracy is (3+1)*(3+2)/2=4*5/2=10.
92 | # The correct answer is (B).
93 |
94 | # Question: "Your overhear two chemists talking to each other as they leave a synthetic organic chemistry lab. One asks the other "So, how did it go?" The second chemist replies, "Not well - my compounds are on top of each other." What is the second chemist most likely referring to?"
95 |
96 | # Choices:
97 | # (A) The compounds they are working with have similar polarities.
98 | # (B) The compounds they are working with have similar boiling points.
99 | # (C) The compounds they are working with are bonding to each other through non-covalent/van der Waals interactions.
100 | # (D) The compounds they are working with have similar optical rotations.
101 |
102 | # Let's think step by step:
103 | # "On top of each other" commonly refers to two compounds that have similar Rf values on chromatography (a common operation in synthetic chemistry).
104 | # Similar Rf values arise for compounds with similar polarities.
105 | # The correct answer is (A).
106 |
107 | # Question: Two people are playing the following game. A fair coin is tossed into the air. Person A says that in a single toss of the coin, the tail will come. So it's like the first shot or the third shot or the fifth shot. Person B says that the coin will come with a double toss. So like the second, fourth, sixth or eighth shot. Imagine this game played forever. What is the probability that person A wins this game?
108 |
109 | # Choices:
110 | # (A) 1/2
111 | # (B) 1/4
112 | # (C) 2/3
113 | # (D) 1/8
114 |
115 | # Let's think step by step:
116 | # When finding the correct answer, the probability of playing forever and the coin's single-point toss will be calculated.
117 | # For example, a tail may appear on the first shot.
118 | # This probability is 1/2. if the first toss doesn't come up, it shouldn't come to the second roll either, because the second throw is an even number.
119 | # So it can come in the third shot.
120 | # This is (1/2)(1/2)(1/2).
121 | # So (1/2)^3=1/8.
122 | # Or it could come on the fifth shot.
123 | # This is (1/2)^5=1/32.
124 | # This is actually a geometric series that goes on forever.
125 | # We can write this series as follows.
126 | # (1/2) + (1/2)^3 + (1/2)^5 + (1/2)^7 + ……….
127 | # The solution for this series is as follows : a1/(1-r) where a1 is the first number and r is the sequence or r= a2/a1 or a3/a2 etc.
128 | # a1=1/2
129 | # r=(1/2)^2=1/4
130 | # So a1/(1-r)=(1/2)/(1-1/4)=(1/2)/(3/4)=2/3.
131 | # The correct answer is (C).
132 |
133 | # ''' + MMLU_prompt + "Let's think step by step:"
134 |
# Least-to-most prompting: ask the model to decompose the question into
# progressive sub-questions before answering.
Least_to_Most_0_shot = MMLU_prompt + '''Please choose the correct choice. In order to solve the question more conveniently and efficiently, break down the question into progressive sub-questions. Answer the sub-questions and get the final result according to sub-questions and their answers.
''' + "Your last sentence should be \"The correct answer is (insert answer here, which is only the letter of the choice)\"."

# Tree-of-thought voting instruction appended after the candidate solutions;
# the reply must end with "The index of the best solution is x".
tot_post = '''
Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.'''

# Tree-of-thought candidate lists for 3, 5 and 10 sampled solutions; each
# template extends the previous one with more {solutionN} slots.
tot_3_solutions = MMLU_prompt + '''
Solution 1: {solution1}
Solution 2: {solution2}
Solution 3: {solution3}'''

tot_5_solutions = tot_3_solutions + '''
Solution 4: {solution4}
Solution 5: {solution5}'''

tot_10_solutions = tot_5_solutions + '''
Solution 6: {solution6}
Solution 7: {solution7}
Solution 8: {solution8}
Solution 9: {solution9}
Solution 10: {solution10}'''
156 |
# Analogical prompting ("anologous" spelling kept for compatibility): recall
# 1/3/5 self-generated relevant problems before solving the target one.
#
# Fix: MMLU_prompt begins with "Question:\n{question}" (newline after the
# colon), so the previous .replace("Question: {question}", ...) target —
# written with a space, matching the GPQA template — never matched and the
# relabeling to "Initial problem" silently did nothing. Replacing just the
# "Question:" label (it occurs exactly once in MMLU_prompt) performs the
# intended relabeling while leaving the rest of the template untouched.
anologous_1_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
''' + MMLU_prompt.replace("Question:", "Initial problem:") + '''
# Instructions:
## Relevant Problems:
Recall an example of the {subject} problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".
'''

anologous_3_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
''' + MMLU_prompt.replace("Question:", "Initial problem:") + '''
# Instructions:
## Relevant Problems:
Recall three examples of {subject} problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".
'''

anologous_5_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem.
''' + MMLU_prompt.replace("Question:", "Initial problem:") + '''
# Instructions:
## Relevant Problems:
Recall five examples of {subject} problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem:
- After "Q: ", describe the problem.
- After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".

## Solve the Initial Problem:
Q: Copy and paste the initial problem here.
A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\".
'''
195 |
# Step-back prompting, phase 1: extract the underlying principles only.
SBP_extract = '''You are an expert at {subject}. Your task is to extract the {subject} concepts and principles involved in solving the problem.
''' + MMLU_prompt + '''
Principles involved:
'''

# Step-back prompting, phase 2: solve the problem given the extracted
# {principles} from phase 1.
SBP_answer = "You are an expert at {subject}. You are given a {subject} problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + '''
''' + MMLU_prompt + "\nInstruction:\n"+ prompt_format + '''
Principles:
{principles}
''' + '''
Answer:
'''
208 |
# Subject-specific step-back variants with a worked one-shot example
# (physics: charge distribution on a conductor; chemistry: precipitation
# titration). Structure mirrors SBP_extract / SBP_answer above.
SBP_extract_physics = '''You are an expert at Physics. Your task is to extract the Physics concepts and principles involved in solving the problem. Here is an example.

Question:
A spherical conductor carries a net charge. How is this charge distributed on the sphere?

Choices:
(A) The charge is evenly distributed on the surface.
(B) The charge resides on the surface only; the distribution of charge on the surface depends on what other charged objects are near the sphere.
(C) The charge moves continually within the sphere.
(D) The charge is distributed uniformly throughout the sphere.

Principles Involved:
Coulomb's Law: the force between two charged particles is proportional to the product of their charges and inversely proportional to the square of the distance between them, F = k * q1 * q2 / r, where F is the electric force, k is a constant, q1 and q2 are the charges the particles carry, and r is the distance between them.

''' + MMLU_prompt + '''Principles Involved:
'''

SBP_answer_physics = "You are an expert at Physics. You are given a Physics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles. " + MMLU_prompt + ''' Here is an example.

Question:
A spherical conductor carries a net charge. How is this charge distributed on the sphere?

Choices:
(A) The charge is evenly distributed on the surface.
(B) The charge resides on the surface only; the distribution of charge on the surface depends on what other charged objects are near the sphere.
(C) The charge moves continually within the sphere.
(D) The charge is distributed uniformly throughout the sphere.

Principles:
Coulomb's Law: the force between two charged particles is proportional to the product of their charges and inversely proportional to the square of the distance between them, F = k * q1 * q2 / r, where F is the electric force, k is a constant, q1 and q2 are the charges the particles carry, and r is the distance between them.

Answer:
Using the Principles of Coulomb's Law, we can solve the problem as following:
Step 1: Apply Coulomb's Law to find out how charges are distributed on the surface.
In the case of a spherical conductor, the charges on the surface will repel each other. The further apart the charges are, the less force they will exert on each other. Therefore, the charges will distribute themselves evenly on the surface of the sphere, as this is the configuration that minimizes the repulsive force between them.

Step 2: Apply Coulomb's Law to find out what happens if there are other charges present.
The distribution of charge on the surface may also be affected by the presence of other charged objects near the sphere. For example, if a negatively charged object is brought near a positively charged sphere, the negative charges on the sphere will be repelled and will move to the opposite side of the sphere. This will result in a non-uniform distribution of charge on the surface of the sphere.

Therefore, the correct answer is (B) The charge resides on the surface only; the distribution of charge on the surface depends on what other charged objects are near the sphere.

''' + MMLU_prompt + '''Principles:
{principles}

Answer:
'''

SBP_extract_chemistry = '''You are an expert at Chemistry. Your task is to extract the Chemistry concepts and principles involved in solving the problem. Here is an example.

Question:
A sample of an unknown chloride compound was dissolved in water, and then titrated with excess Pb(NO3)2 to create a precipitate. After drying, it is determined there are 0.0050 mol of precipitate present. What mass of chloride is present in the original sample?

Choices:
(A) 0.177 g
(B) 0.355 g
(C) 0.522 g
(D) 0.710 g

Principles Involved:
Precipitation reactions: Precipitation reactions occur when two soluble salts are mixed and form an insoluble product, called a precipitate. The precipitate can be separated from the solution by filtration or centrifugation.
Molar mass: The molar mass of a substance is the mass of one mole of that substance. The molar mass is expressed in grams per mole (g/mol).
Limiting reactant: The limiting reactant is the reactant that is completely consumed in a chemical reaction. The amount of product formed is determined by the amount of limiting reactant.

''' + MMLU_prompt + '''Principles Involved:
'''
274 |
# SBP answering-phase prompt for Chemistry: the extracted principles are
# filled into the {principles} placeholder and the model must solve the
# question step by step; includes one fully worked example.
SBP_answer_chemistry = "You are an expert at Chemistry. You are given a Chemistry problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles. " + MMLU_prompt + ''' Here is an example.

Question:
A sample of an unknown chloride compound was dissolved in water, and then titrated with excess Pb(NO3)2 to create a precipitate. After drying, it is determined there are 0.0050 mol of precipitate present. What mass of chloride is present in the original sample?

Choices:
(A) 0.177 g
(B) 0.355 g
(C) 0.522 g
(D) 0.710 g

Principles:
Precipitation reactions: Precipitation reactions occur when two soluble salts are mixed and form an insoluble product, called a precipitate. The precipitate can be separated from the solution by filtration or centrifugation.
Molar mass: The molar mass of a substance is the mass of one mole of that substance. The molar mass is expressed in grams per mole (g/mol).
Limiting reactant: The limiting reactant is the reactant that is completely consumed in a chemical reaction. The amount of product formed is determined by the amount of limiting reactant.

Answer:
Assuming the unknown chloride compound is MCl, where M represents the metal cation, the balanced chemical equation for the precipitation reaction is:
Pb(NO3)2(aq) + 2MCl(aq) −→ PbCl2(s) + 2MNO3(aq)

Since Pb(NO3)2 is in excess, MCl is the limiting reactant. The stoichiometry of the reaction indicates that 2 moles of MCl produce 1 mole of PbCl2 precipitate. Therefore, 0.0050 mol of PbCl2 corresponds to 0.010 mol of MCl.

The mass of chloride in the original sample can be calculated using the molar mass of chloride (35.45 g/mol):
0.010 mol Cl x 35.45 g/mol = 0.355 g Cl

The correct answer is (B) 0.355 g.

''' + MMLU_prompt + '''Principles:
{principles}

Answer:
'''
307 |
308 |
309 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from dataset import *
2 | from model import load_model, LLM_generate
3 | from dataset import create_prompt
4 |
5 | import os
6 | import argparse
7 | import copy
8 | from concurrent.futures import ThreadPoolExecutor
9 | from tqdm import tqdm
10 |
11 | from huggingface_hub import login
12 |
# Hugging Face authentication for gated model downloads.
# NOTE(review): never commit a real token — read it from an environment
# variable (e.g. os.environ["HF_TOKEN"]) instead of hard-coding it here.
hf_token = "hf_YourTokenHere" # Replace with your token
login(token=hf_token)
15 |
16 |
# Prompts reused on every S-RF (self-refine) round: first ask the model to
# critique its previous answer, then ask for a revision based on the critique.
refine1_feeback_prompt = "Review your previous answer and find problems with your answer."
refine1_refine_prompt = "Based on the problems you found, improve your answer."
19 |
20 |
def concat_refine_records(records, output, string):
    """Store each generated entry of *output* into the matching record dict under key *string*."""
    for row_idx, row in enumerate(output):
        for col_idx, item in enumerate(row):
            records[row_idx][col_idx][string] = item
25 |
26 |
def concat_refine_messages(messages, output):
    """Append the next conversation turn to every message list in *messages*.

    ``output`` is either a plain prompt string (appended verbatim to every
    conversation) or a list of per-conversation generation records, in which
    case each conversation receives its own first sample's generated text.
    """
    if isinstance(output, str):
        # Bug fix: the original looped over ``range(len(output))`` — the
        # length of the *string* — and indexed ``messages`` far out of range;
        # the resulting IndexError was silently swallowed because callers
        # submit this helper to a ThreadPoolExecutor without checking the
        # future. Iterate the message lists themselves instead.
        for message in messages:
            message.append(output)
    else:
        for i in range(len(output)):
            messages[i].append(output[i][0]["output"])
33 |
34 |
def get_model_outputs(args):
    """Generate model outputs for one batch under the strategy in ``args.reasoning``.

    Batch inputs (questions, choiceses, subjects, records_tot, ...) and the
    sampling settings are read from ``args``; prompts are built by
    ``create_prompt`` and sent through ``LLM_generate``.

    Returns:
        For DiP/CoT/AnP/L2M/S-RF/ToT/SBP: a list with one entry per question,
        each a list with one record dict per sample in ``args.range``.
        For MAD: ``[[records]]`` where ``records`` maps ``"round{k}"`` to that
        debate round's per-agent record dicts (MAD runs with batchsize 1).
    """
    if args.reasoning in ["DiP", "CoT", "AnP", "L2M"]:
        # Single-pass strategies: one prompt per question, sampled args.num times.
        args.query = create_prompt(args)
        if args.verbal:
            print(args.query)
        records = LLM_generate(args)

    elif args.reasoning == "S-RF":
        # Self-refine: start from a plain DiP answer, then alternate
        # critique ("problems{k}") and revision ("output{k}") rounds.
        args.reasoning = "DiP"
        queries = create_prompt(args)
        args.reasoning = "S-RF"
        args.messages = [[query] for query in queries]
        origin_output = LLM_generate(args)
        records = []
        for i in range(0, len(origin_output)):
            record_ = []
            for j in range(0, len(origin_output[i])):
                record = {}
                record["output0"] = origin_output[i][j]
                record_.append(record)
            records.append(record_)
        # NOTE(review): the submitted futures are never awaited or checked, so
        # any exception raised inside the helpers is silently dropped; the
        # `with` block only guarantees the submitted work has finished.
        with ThreadPoolExecutor(max_workers=5) as executor:
            parameters = (args.messages, origin_output)
            executor.submit(concat_refine_messages, *parameters)
        for j in tqdm(range(0, args.rounds)):
            # Ask the model to critique its previous answer.
            with ThreadPoolExecutor(max_workers=5) as executor:
                parameters = (args.messages, refine1_feeback_prompt)
                executor.submit(concat_refine_messages, *parameters)
            output = LLM_generate(args)
            if args.verbal:
                print(output)
            # Record the critique, then ask for a revised answer in the
            # dataset's required format (PROMPT_FORMAT is a module-level
            # global assigned in the __main__ block).
            with ThreadPoolExecutor(max_workers=5) as executor:
                parameters = (records, output, f"problems{j+1}")
                executor.submit(concat_refine_records, *parameters)
                parameters = (args.messages, output)
                executor.submit(concat_refine_messages, *parameters)
                parameters = (args.messages, refine1_refine_prompt + " " + PROMPT_FORMAT)
                executor.submit(concat_refine_messages, *parameters)
            output = LLM_generate(args)
            if args.verbal:
                print(output)
            with ThreadPoolExecutor(max_workers=5) as executor:
                parameters = (records, output, f"output{j+1}")
                executor.submit(concat_refine_records, *parameters)
                parameters = (args.messages, output)
                executor.submit(concat_refine_messages, *parameters)

    elif args.reasoning == "ToT":
        # Tree of Thoughts: candidate CoT solutions were loaded beforehand
        # into args.records_tot; here the model only chooses the best one.
        records = []
        args.query = create_prompt(args)
        if args.verbal:
            print(args.query)
        l = len(args.questions)
        output_choices = LLM_generate(args)
        for i in range(0, l):
            record_ = []
            record = {}
            record["solutions"] = args.records_tot[i]
            # NOTE(review): the same `record` dict is appended args.num times,
            # so every entry of record_ aliases one dict whose "choose" field
            # ends up holding the *last* sample — confirm whether one fresh
            # dict per sample was intended.
            for j in range(0, args.num):
                record["choose"] = output_choices[i][j]
                record_.append(record)
            records.append(record_)

    elif args.reasoning == "SBP":
        # Step-back prompting, two phases: extract principles first, then
        # solve each question conditioned on its extracted principles.
        records = []
        args.query = create_prompt(args)
        if args.verbal:
            print(args.query)
        principles = LLM_generate(args)
        num = args.num
        l = len(args.questions)
        # Phase 2 issues one flat query per (question, principles-sample)
        # pair, so sampling is temporarily forced to 1.
        args.num = 1
        args.query = []
        for i in range(0, l):
            for j in range(0, num):
                record = {}
                record["principles"] = principles[i][j]
                args.principles = record["principles"]["output"]
                args.query.append(create_prompt(args, i)[0])
        del args.principles
        solutions = LLM_generate(args)
        for i in range(0, l):
            record_ = []
            for j in range(0, num):
                record = {}
                record["principles"] = principles[i][j]
                # Flat solutions list is indexed back to (question, sample).
                record["solution"] = solutions[num * i + j][0]
                record_.append(record)
            records.append(record_)
        args.num = num

    elif args.reasoning == "MAD":
        # Multi-agent debate with 3 agents; from round 1 onward each agent
        # sees the other agents' latest messages (built by construct_message).
        records = {}
        args.reasoning = "DiP"
        agent_contexts = [create_prompt(args) for agent in range(0, 3)]
        args.reasoning = "MAD"
        if args.verbal:
            print(args.query)
        if args.continue_:
            # Resume an interrupted debate: replay the stored rounds from
            # args.records, then keep generating until args.rounds is reached.
            for round in range(args.rounds):
                if round < len(args.records):
                    outputs = [output["output"] for output in args.records[f"round{round+1}"]]
                    records[f"round{round+1}"] = args.records[f"round{round+1}"]
                else:
                    records[f"round{round+1}"] = []
                for i, agent_context in enumerate(agent_contexts):
                    if round != 0:
                        agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:]
                        message = construct_message(args, agent_contexts_other, args.question, 2*round - 1)
                        agent_context.append(message)
                        args.messages = [agent_context]
                        if round < len(args.records):
                            assistant_message = outputs[i]
                        else:
                            record = LLM_generate(args)[0][0]
                            records[f"round{round+1}"].append(record)
                            assistant_message = record["output"]
                    else:
                        # Round 0 is always replayed from the stored outputs.
                        assistant_message = outputs[i]
                    agent_context.append(assistant_message)
        else:
            for round in range(args.rounds):
                records[f"round{round+1}"] = []
                for i, agent_context in enumerate(agent_contexts):
                    if round != 0:
                        agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:]
                        message = construct_message(args, agent_contexts_other, args.question, 2*round - 1)
                        agent_context.append(message)
                    args.messages = [agent_context]
                    record = LLM_generate(args)[0][0]
                    records[f"round{round+1}"].append(record)
                    assistant_message = record["output"]
                    agent_context.append(assistant_message)
        # Wrap to match the list-of-lists shape of the other strategies.
        records = [[records]]

    return records
171 |
172 |
def handle_tot_reasoning(args):
    """Collect previously generated CoT logs to serve as ToT candidate paths.

    Temporarily switches ``args`` to 0-shot CoT, loads one log file per
    requested path, then restores the ToT settings before returning the logs.
    """
    requested_paths = args.shot
    args.shot = 0
    args.reasoning = "CoT"

    logs_tot = []
    for path_index in range(requested_paths):
        args.n = path_index
        logs_tot.append(read_logs(args))

    args.reasoning = "ToT"
    args.shot = requested_paths
    return logs_tot
187 |
188 |
def setup_mad_reasoning(args):
    """Setup for MAD (Multi-Agent Debate) reasoning, reusing the results of DiP.

    Loads the three DiP log files (n = 0, 1, 2) that seed the three debate
    agents, restores ``args.reasoning``, and returns the three log lists.
    """
    original_reasoning = args.reasoning
    args.reasoning = "DiP"

    # One DiP log per debate agent.
    dip_logs = []
    for agent_index in range(3):
        args.n = agent_index
        dip_logs.append(read_logs(args))
    logs_DiP_0, logs_DiP_1, logs_DiP_2 = dip_logs

    assert len(logs_DiP_0) == len(logs_DiP_1) == len(logs_DiP_2), "Logs of DiP for MAD reasoning must be of the same length."
    assert len(logs_DiP_0) > 0, "To use MAD reasoning, DiP logs must be available."
    args.reasoning = original_reasoning
    return logs_DiP_0, logs_DiP_1, logs_DiP_2
206 |
207 |
if __name__ == "__main__":
    # Model and dataset configurations
    # Short alias -> Hugging Face repo id (or API model name for Gemini/GPT).
    MODEL_CONFIGS = {
        "Qwen-2.5": "Qwen/Qwen2.5-7B-Instruct",
        "Llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
        "GLM-4": "THUDM/glm-4-9b-chat",
        "Gemini": "gemini-1.5-flash",
        "GPT4-mini": "gpt-4o-mini",
        "Phi-3.5": "microsoft/Phi-3.5-mini-instruct"
    }

    # Dataset names grouped by domain; the flattened values feed --dataset choices.
    DATASET_CONFIGS = {
        "math": ["GSM8K", "GSM-Hard", "MATH", "AIME_2024"],
        "science": [
            "MMLU-high_school_physics",
            "MMLU-high_school_chemistry",
            "MMLU-high_school_biology",
            "GPQA",
        ],
    }

    # Per-dataset answer-format instruction appended to the S-RF refine prompt
    # (GSM8K, MMLU, ... presumably re-exported by `from dataset import *` — confirm).
    PROMPT_FORMATS = {
        "GSM8K": GSM8K.prompt_format,
        "GPQA": GPQA.prompt_format,
        "GSM-Hard": GSM_Hard.prompt_format,
        "MATH": MATH.prompt_format,
        "MMLU-high_school_physics": MMLU.prompt_format,
        "MMLU-high_school_chemistry": MMLU.prompt_format,
        "MMLU-high_school_biology": MMLU.prompt_format,
        "AIME_2024": AIME.prompt_format
    }
239 |
    # Set up argument parser with improved descriptions
    parser = argparse.ArgumentParser(description="Large Language Model Evaluation Framework")

    # Model configuration
    parser.add_argument("--model_name", type=str, default=MODEL_CONFIGS["Qwen-2.5"],
                       choices=list(MODEL_CONFIGS.values()))
    parser.add_argument("--model_type", type=str, default="vllm",
                       choices=["vllm", "gemini", "openai"])

    # Dataset configuration
    parser.add_argument("--dataset", type=str, default="GSM8K",
                       choices=[ds for group in DATASET_CONFIGS.values() for ds in group])
    parser.add_argument("--split", type=str, default="test")

    # Reasoning strategy configuration
    parser.add_argument("--reasoning", type=str, default="DiP",
                       choices=["DiP", "CoT", "L2M", "SBP", "AnP", "S-RF", "ToT", "MAD"],
                       help="Reasoning prompting strategy")
    parser.add_argument("--shot", type=int, default=0,
                       help='''DiP, SBP: Shot is fixed at 0, while SBP using 1-shot on MMLU.
                       CoT, L2M: Number of examples in few-shot prompting.
                       AnP: Number of analogous problems to generate.
                       S-RF: Shot is fixed at 0.
                       ToT: Number of reasoning paths to generate.
                       MAD: Shot is fixed at 0, using 3 agents for debate.
                       ''')

    # Processing configuration
    parser.add_argument("--max_num_workers", type=int, default=1,
                       help="Maximum number of workers for parallel processing for API-based models")
    parser.add_argument("--batchsize", type=int, default=10)
    # Number of refinement (S-RF) / debate (MAD) rounds.
    parser.add_argument("--rounds", type=int, default=5)

    # Generation parameters
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.8)
    parser.add_argument("--max_new_tokens", type=int, default=4096)

    # Generation range configuration
    # Sample indices [range_begin, range_end) are generated per question,
    # each with its own log file.
    parser.add_argument("--range_begin", type=int, default=0)
    parser.add_argument("--range_end", type=int, default=16)

    # Hardware configuration
    parser.add_argument("--gpu", type=str, default="4,5", help="GPU IDs to use, e.g., '0,1'")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                       help="Data type for VLLM")

    # API configuration
    parser.add_argument("--google_api_key", type=str)
    parser.add_argument("--openai_api_key", type=str)
    parser.add_argument("--openai_base_url", type=str)

    # Other options
    parser.add_argument("--verbal", action="store_true",
                       help="Enable verbose output")
    parser.add_argument("--seed", type=int, default=0)

    args = parser.parse_args()
298 |
    # Set up basic configurations
    args.range = range(args.range_begin, args.range_end)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Initialize prompt format based on dataset
    # (read as a module-level global by get_model_outputs in the S-RF branch)
    PROMPT_FORMAT = PROMPT_FORMATS.get(args.dataset)

    # Reset messaging configuration
    args.system = None
    args.messages = None
    args.query = None

    # Special handling for MAD reasoning
    # (debate state is tracked per question, so force one question at a time)
    if args.reasoning == "MAD":
        args.batchsize = 1
        args.max_num_workers = 1

    # Load dataset and prepare model
    # Log paths use only the short model name (text after the last '/').
    dataset_list = read_dataset(args)
    model_name = args.model_name
    args.model_name = model_name.split("/")[-1]

    # Special handling for different reasoning strategies
    if args.reasoning in ["DiP", "SBP"]:
        args.shot = 0
    elif args.reasoning in ["CoT"]:
        if args.dataset in ["GSM8K", "GSM-Hard", "MATH"]:
            assert args.shot in [0, 1, 5], "CoT reasoning requires 0, 1, or 5 shots for GSM8K, GSM-Hard, and MATH datasets"
        else:
            assert args.shot == 0, f"CoT reasoning requires 0-shot prompting for {args.dataset} dataset"
    elif args.reasoning == "AnP":
        assert args.shot in [1, 3, 5], "AnP reasoning requires 1, 3, or 5 shots, i.e., generating 1, 3, or 5 analogous problems"
    elif args.reasoning in ["S-RF", "MAD"]:
        args.shot = 0
        assert len(args.range) == 1, "Range from the beginning to the end must be 1 for S-RF and MAD reasoning"
    elif args.reasoning == "ToT":
        assert args.shot in [3, 5, 10], "ToT reasoning requires 3, 5, or 10 shots, i.e., generating 3, 5, or 10 reasoning paths"
        # ToT reuses previously generated CoT logs as its candidate paths.
        logs_tot = handle_tot_reasoning(args)

    # Number of samples generated per question.
    args.num = len(args.range)

    # Handle MAD reasoning specific setup
    if args.reasoning == "MAD":
        logs_DiP_0, logs_DiP_1, logs_DiP_2 = setup_mad_reasoning(args)

    print(f"{'='*10}{args.dataset}===={args.reasoning}===={args.shot}{'='*32}")
345 |
    # Process logs and validate progress
    # One log file per sample index; resume from the number of examples
    # already present in the first sample's log.
    logs_all = []
    for j in args.range:
        args.n = j
        logs_all.append(read_logs(args))
    args.n = args.range[0]
    logs = logs_all[0]
    begin_num = len(logs)

    assert begin_num < len(dataset_list), "Processing already completed. Please check logs."

    # Initialize model
    # load_model needs the full repo id; logs keep using the short name.
    args.model_name = model_name
    load_model(args)
    args.model_name = model_name.split("/")[-1]


    # Index -> letter for multiple-choice answer keys (MMLU stores an int index).
    letters = ["A", "B", "C", "D", "E"]
    if args.reasoning == "MAD":
        # Decide whether to resume a debate that stopped short of args.rounds.
        if len(logs) == len(dataset_list):
            rounds = len(logs[0]["record"])
            if rounds < args.rounds:
                begin_num = 0
                args.continue_ = True
            else:
                args.continue_ = False
        else:
            args.continue_ = False
374 |
    if args.model_type == "vllm":
        # Local vLLM path: logs are strictly sequential, so resume from
        # begin_num and batch the remaining examples in order.
        for ii in tqdm(range(begin_num, len(dataset_list), args.batchsize)):
            examples = []
            args.questions = []
            args.choiceses = []
            args.subjects = []
            range_ = range(ii, min(len(dataset_list), ii + args.batchsize))
            args.records_tot = []
            for i in range_:
                if args.reasoning == "ToT":
                    # Attach the pre-generated CoT candidate paths per example.
                    records = []
                    for j in range(0, args.shot):
                        records.append(logs_tot[j][i]["record"])
                    args.records_tot.append(records)
                example = copy.deepcopy(dataset_list[i])
                # Normalize each dataset's fields to example["question"]/["key"].
                if args.dataset == "GSM8K":
                    args.questions.append(example["question"])
                    answer = example["answer"]
                    # Gold answer follows the "#### " marker in GSM8K.
                    key = answer.split("#### ")[-1]
                    example["key"] = key
                elif args.dataset == "GPQA":
                    args.questions.append(example['problem'])
                    args.choiceses.append(example['choices'])
                    args.subjects.append(example["subject"].lower())
                    example["question"] = example.pop("problem")
                    example["key"] = example["answer"]
                elif args.dataset == "GSM-Hard":
                    args.questions.append(example["input"])
                    example["question"] = example.pop("input")
                    example["key"] = example.pop("target")
                elif args.dataset == "MATH":
                    args.questions.append(example["problem"])
                    example["question"] = example.pop("problem")
                    example["key"] = example.pop("answer")
                elif "MMLU" in args.dataset:
                    args.questions.append(example["question"])
                    args.subjects.append(example["subject"])
                    args.choiceses.append(example["choices"])
                    example["key"] = letters[example.pop("answer")]
                elif args.dataset == "AIME_2024":
                    args.questions.append(example["Problem"])
                    example["question"] = example.pop("Problem")
                    example["key"] = example.pop("Answer")
                example["num"] = i
                examples.append(example)

            if args.reasoning == "MAD":
                # Uses loop variable `i` after the loop: safe only because MAD
                # forces batchsize to 1, so range_ holds exactly one index.
                args.previous_record = [logs_DiP_0[i]["record"], logs_DiP_1[i]["record"], logs_DiP_2[i]["record"]]
                if args.continue_:
                    args.records = logs[i]["record"]

            records_all = get_model_outputs(args)

            # Fan the per-question records out to the per-sample log lists.
            for i in range(0, len(range_)):
                records = records_all[i]
                for j, k in enumerate(args.range):
                    new_example = examples[i].copy()
                    new_example["record"] = records[j]
                    logs_all[j].append(new_example)
                del new_example
            del examples

        # NOTE(review): logs are written only after ALL batches finish here,
        # whereas the API branch below persists after every batch — a crash
        # mid-run loses this branch's progress; confirm if intentional.
        for j, k in enumerate(args.range):
            args.n = k
            record_logs(logs_all[j], args)
440 | else:
441 | nums = [log["num"] for log in logs]
442 | remain_data = [i for i in range(len(dataset_list)) if i not in nums]
443 | remain_ranges = [remain_data[i:i+args.batchsize] for i in range(0, len(remain_data), args.batchsize)]
444 | for range_ in tqdm(remain_ranges):
445 | examples = []
446 | args.questions = []
447 | args.choiceses = []
448 | args.subjects = []
449 | args.records_tot = []
450 | for i in range_:
451 | if "tot" in args.reasoning:
452 | records = []
453 | for j in range(0, args.shot):
454 | records.append(logs_tot[j][i]["record"])
455 | args.records_tot.append(records)
456 | example = copy.deepcopy(dataset_list[i])
457 | if args.dataset == "GSM8K":
458 | args.questions.append(example["question"])
459 | answer = example["answer"]
460 | key = answer.split("#### ")[-1]
461 | example["key"] = key
462 | elif args.dataset == "GPQA":
463 | args.questions.append(example['problem'])
464 | args.choiceses.append(example['choices'])
465 | args.subjects.append(example["subject"].lower())
466 | example["key"] = example["answer"]
467 | elif args.dataset == "GSM-Hard":
468 | args.questions.append(example["input"])
469 | example["question"] = example.pop("input")
470 | example["key"] = example.pop("target")
471 | elif args.dataset == "MATH":
472 | args.questions.append(example["problem"])
473 | example["question"] = example.pop("problem")
474 | example["key"] = example.pop("answer")
475 | elif "MMLU" in args.dataset:
476 | args.questions.append(example["question"])
477 | args.subjects.append(example["subject"])
478 | args.choiceses.append(example["choices"])
479 | example["key"] = letters[example.pop("answer")]
480 | elif args.dataset == "AIME_2024":
481 | args.questions.append(example["Problem"])
482 | example["question"] = example.pop("Problem")
483 | example["key"] = example.pop("Answer")
484 | example["num"] = i
485 | examples.append(example)
486 |
487 | if args.reasoning == "MAD":
488 | args.previous_record = [logs_DiP_0[i]["record"], logs_DiP_1[i]["record"], logs_DiP_2[i]["record"]]
489 | if args.continue_:
490 | args.records = logs[i]["record"]
491 |
492 | records_all = get_model_outputs(args)
493 |
494 | for i in range(0, len(range_)):
495 | records = records_all[i]
496 | for j, k in enumerate(args.range):
497 | new_example = examples[i].copy()
498 | new_example["record"] = records[j]
499 | logs_all[j].append(new_example)
500 | del new_example
501 | del examples
502 |
503 | for j, k in enumerate(args.range):
504 | args.n = k
505 | record_logs(logs_all[j], args)
506 |
--------------------------------------------------------------------------------
/eval_csv_cost.py:
--------------------------------------------------------------------------------
1 | from openpyxl import Workbook, load_workbook
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 |
5 | from dataset import *
6 | import argparse
7 | from collections import Counter
8 | import random
9 | import os
10 | import re
11 |
12 | from matplotlib.ticker import MaxNLocator, MultipleLocator, FuncFormatter
13 |
14 |
def all_equal(lst):
    """Return True when every element of *lst* equals the first one (vacuously True for an empty list)."""
    if not lst:
        return True
    head = lst[0]
    return all(item == head for item in lst)
17 |
18 |
def find_most_common_elements(input_list):
    """Return (elements tied for the highest frequency, that frequency).

    Ties preserve first-appearance order, since Counter keeps insertion order.
    """
    counts = Counter(input_list)
    top_count = max(counts.values())
    winners = [element for element, count in counts.items() if count == top_count]
    return winners, top_count
24 |
25 |
def get_most_common_answer(outputs):
    """Majority vote: return one of the most frequent non-None answers.

    ``None`` entries (failed parses) are dropped first; ties among the most
    frequent answers are broken uniformly at random. Returns ``None`` when
    no valid answer remains.
    """
    # Idiom fixes: `is not None` instead of `!= None`, truthiness instead of `== []`.
    valid_outputs = [output for output in outputs if output is not None]
    if not valid_outputs:
        return None
    most_common_elements, _max_count = find_most_common_elements(valid_outputs)
    return random.choice(most_common_elements)
33 |
34 |
def get_cost(model_name, prompt_tokens, completion_tokens):
    """Return the USD cost of one call given its token counts.

    Rates are USD per one million tokens (prompt rate, completion rate);
    models not listed fall back to the gpt-4o-mini pricing.
    """
    in_tokens = float(prompt_tokens)
    out_tokens = float(completion_tokens)
    if "gemini" in model_name:
        in_rate, out_rate = 0.075, 0.3
    elif model_name == "gpt-3.5-turbo-0613":
        in_rate, out_rate = 1.5, 2
    elif model_name == "gpt-4o-mini":
        in_rate, out_rate = 0.15, 0.6
    else:
        in_rate, out_rate = 0.15, 0.6
    return (in_tokens * in_rate + out_tokens * out_rate) / 10 ** 6
48 |
49 |
50 | if __name__ == "__main__":
51 | model_names = ["Qwen/Qwen2.5-7B-Instruct",
52 | "meta-llama/Meta-Llama-3-8B-Instruct",
53 | "THUDM/glm-4-9b-chat",
54 | "gemini-1.5-flash",
55 | "gpt-4o-mini",
56 | "microsoft/Phi-3.5-mini-instruct"]
57 |
58 | datasets = ["GSM8K",
59 | "GPQA",
60 | "GSM-Hard",
61 | "MATH",
62 | "MMLU-high_school_physics",
63 | "MMLU-high_school_chemistry",
64 | "MMLU-high_school_biology",
65 | "AIME_2024"]
66 |
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument("--model_name", type=str, default=model_names[0])
69 | parser.add_argument("--dataset", type=str, default=datasets[0])
70 | parser.add_argument("--split", type=str, default="test")
71 | parser.add_argument("--shuffle", type=bool, default=True)
72 | args = parser.parse_args()
73 |
74 | DiP_max = 100 ## maximum sampling time of DiP
75 | CoT_max = 90 ## maximum sampling time of CoT
76 | L2M_max = 75 ## maximum sampling time of L2M
77 | SBP_max = 25 ## maximum sampling time of SBP
78 | AnP_max = 40 ## maximum sampling time of AnP
79 | S_RF_max = 10 ## maximum round of S-RF
80 | MAD_max = 10 ## maximum round of MAD
81 |
82 | sampling_times = [1, 3, 5, 7, 10, 12, 15]
83 |
84 | model_names_formal = {
85 | "Qwen2.5-7B-Instruct": "Qwen2.5-7B-Instruct",
86 | "Llama-3-8B-Instruct": "LLaMA-3-8B-Instruct",
87 | "glm-4-9b-chat": "GLM-4-9B-Chat",
88 | "gemini-1.5-flash": "Gemini-1.5-Flash",
89 | "gpt-4o-mini": "GPT-4o-mini",
90 | "Phi-3.5-mini-instruct": "Phi-3.5-mini-Instruct"
91 | }
92 |
93 | marker_dict = {
94 | 'DiP': '^',
95 | 'CoT': '^',
96 | 'L2M': '^',
97 | 'ToT': 'o',
98 | 'S-RF':'o',
99 | 'SBP': '^',
100 | 'AnP': '^',
101 | 'MAD': 'o'
102 | }
103 |
104 |
105 | if "/" in args.model_name:
106 | args.model_name = args.model_name.split('/')[-1]
107 |
108 | path = os.path.join(log_path, args.dataset, f"{args.dataset}_cost.xlsx")
109 | if os.path.exists(path):
110 | wb = load_workbook(path)
111 | else:
112 | wb = Workbook()
113 |
114 | unique_labels = []
115 |
116 | headers = ["Method", "Subject", "x-shot", "num", "prompt_tokens", "completion_tokens", "cost", "tokens", "accuracy"]
117 |
118 | if args.model_name in wb.sheetnames:
119 | ws_new = wb[args.model_name]
120 | ws_new.delete_rows(1, ws_new.max_row)
121 | ws_new.delete_cols(1, ws_new.max_column)
122 | else:
123 | ws_new = wb.create_sheet(args.model_name)
124 | ws_new.append(headers)
125 |
126 | try:
127 | model_name = args.model_name
128 | sheet = wb.get_sheet_by_name(model_name)
129 |
130 | tokens = []
131 | accuracy = []
132 | cost = []
133 | labels = []
134 | flag = 0
135 | for row in sheet.iter_rows(values_only=True):
136 | if flag == 0:
137 | flag = 1
138 | continue
139 | # if row[0] != "tot-io":
140 | labels.append(row[0])
141 | tokens.append(int(row[-2]))
142 | accuracy.append(float(row[-1]))
143 | cost.append(float(row[-3]))
144 | counter = Counter(labels)
145 | assert counter["DiP"] >= DiP_max
146 | assert counter["CoT"] >= CoT_max
147 | assert counter["L2M"] >= L2M_max
148 | assert counter["SBP"] >= SBP_max
149 | assert counter["AnP"] >= AnP_max
150 | # assert counter["ToT"] >= 3
151 | # assert counter["S-RF"] >= S_RF_max
152 | # assert counter["MAD"] >= MAD_max
153 | except:
154 | ws_new.delete_rows(1, ws_new.max_row)
155 | ws_new.delete_cols(1, ws_new.max_column)
156 | reasonings = ["DiP", "CoT", "L2M", "SBP", "AnP", "S-RF", "ToT_3", "ToT_5", "ToT_10", "MAD"]
157 |
158 | for reasoning in reasonings:
159 | try:
160 | if reasoning == "DiP":
161 | args.reasoning = "DiP"
162 | args.shot = 0
163 | for N in range(0, DiP_max+1):
164 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
165 | break
166 |
167 | elif reasoning == "CoT":
168 | args.reasoning = "CoT"
169 | args.shot = 0
170 | for N in range(0, CoT_max+1):
171 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
172 | break
173 |
174 | elif "L2M" in reasoning:
175 | args.reasoning = "L2M"
176 | if args.dataset in ["GSM8K", "GSM-Hard"]:
177 | args.shot = 1
178 | else:
179 | args.shot = 0
180 | for N in range(0, L2M_max+1):
181 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
182 | break
183 |
184 | elif reasoning == "ToT_3":
185 | args.reasoning = "ToT"
186 | args.shot = 3
187 | N = 1
188 | elif reasoning == "ToT_5":
189 | args.reasoning = "ToT"
190 | args.shot = 5
191 | N = 1
192 | elif reasoning == "ToT_10":
193 | args.reasoning = "ToT"
194 | args.shot = 10
195 | N = 1
196 |
197 | elif reasoning == "S-RF":
198 | args.reasoning = reasoning
199 | args.shot = 0
200 | N = min((len(logs[-1]["record"]) - 1) // 2, S_RF_max)
201 |
202 | elif reasoning == "SBP":
203 | args.reasoning = "SBP"
204 | args.shot = 0
205 | for N in range(0, SBP_max+1):
206 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
207 | break
208 |
209 | elif reasoning == "AnP":
210 | args.reasoning = "AnP"
211 | args.shot = 1
212 | for N in range(0, AnP_max+1):
213 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
214 | break
215 |
216 | elif reasoning == "MAD":
217 | args.reasoning = "MAD"
218 | args.shot = 0
219 | args.n = 0
220 | logs = read_logs(args)
221 | N = min(len(logs[0]["record"]), MAD_max)
222 |
223 | logs_list = []
224 | assert N != 0
225 | for m in range(0, N):
226 | if args.reasoning in ["DiP", "CoT", "L2M", "AnP", "SBP"]:
227 | args.nums = range(0,m+1)
228 | args.n = m
229 | logs = read_logs(args)
230 | logs_list.append(logs)
231 | if (m+1) not in sampling_times:
232 | continue
233 |
234 | elif args.reasoning in ["S-RF", "MAD"]:
235 | args.n = 0
236 | args.nums = range(0, m+1)
237 | logs = read_logs(args)
238 | logs_list.append(logs)
239 |
240 | elif "ToT" in args.reasoning:
241 | args.nums = range(0, m+1)
242 | for n in range(0, 5):
243 | try:
244 | args.n = n
245 | logs = read_logs(args)
246 | logs_list.append(logs)
247 | except:
248 | break
249 |
250 | accs = []
251 | prompt_tokens_ = 0
252 | completion_tokens_ = 0
253 | l = len(logs_list[0])
254 |
255 | for count in range(0, 5):
256 | random.shuffle(logs_list)
257 | acc_num = 0
258 | if args.dataset in ["GSM8K", "GSM-Hard", "MATH", "AIME_2024"]:
259 | subject = "mathematic"
260 | for j in range(0, l):
261 | if args.dataset in ["GPQA"] or "MMLU" in args.dataset:
262 | subject = logs_list[0][j]["subject"]
263 | key = logs_list[0][j]['key']
264 | if args.reasoning in ["DiP", "CoT", "L2M", "AnP"]:
265 | output_keys = [parse_answer(args, log[j]['record']['output']) for log in logs_list]
266 | output_keys = [output for output in output_keys if output != None]
267 | if len(output_keys) == 0:
268 | output_keys = [None]
269 |
270 | if count == 0:
271 | prompt_tokens = [log[j]['record']['usage']['prompt_tokens'] for log in logs_list[:1]]
272 | completion_tokens = [log[j]['record']['usage']['completion_tokens'] for log in logs_list]
273 |
274 | elif args.reasoning == "S-RF":
275 | output_keys = [parse_answer(args, logs[j]["record"][f"output{m+1}"]["output"])]
276 | if count == 0:
277 | prompt_tokens = [logs[j]["record"]["output0"]['usage']["prompt_tokens"]]
278 | completion_tokens = [logs[j]["record"]["output0"]['usage']["completion_tokens"]]
279 | for k in range(0, m+1):
280 | prompt_tokens += [logs[j]["record"][f"problems{k+1}"]['usage']["prompt_tokens"], logs[j]["record"][f"output{k+1}"]['usage']["prompt_tokens"]]
281 | completion_tokens += [logs[j]["record"][f"problems{k+1}"]['usage']["completion_tokens"], logs[j]["record"][f"output{k+1}"]['usage']["completion_tokens"]]
282 |
283 | elif "tot" in args.reasoning:
284 | indexes = []
285 | for logs in logs_list:
286 | index = parse_best_solution(logs[j]["record"]["choose"]["output"])
287 | if index != None and "0" < index and index <= str(args.shot):
288 | index = int(index) - 1
289 | else:
290 | index = random.choice(range(0, args.shot))
291 | indexes.append(index)
292 | solutions = logs_list[0][j]["record"]["solutions"]
293 | index = get_most_common_answer(indexes)
294 | best_solution = solutions[index]
295 | output_keys = [parse_answer(args, best_solution["output"])]
296 | if count == 0:
297 | prompt_tokens = []
298 | completion_tokens = []
299 | for k in range(0, len(solutions)):
300 | solution = solutions[k]
301 | if k == 0:
302 | prompt_tokens.append(solution["usage"]["prompt_tokens"])
303 | completion_tokens.append(solution["usage"]["completion_tokens"])
304 | for ii, logs in enumerate(logs_list):
305 | if ii == 0:
306 | prompt_tokens.append(logs[j]["record"]["choose"]["usage"]["prompt_tokens"])
307 | completion_tokens.append(logs[j]["record"]["choose"]["usage"]["completion_tokens"])
308 |
309 | elif args.reasoning == "SBP":
310 | output_keys = [parse_answer(args, log[j]['record']['solution']["output"]) for log in logs_list]
311 | output_keys = [output for output in output_keys if output != None]
312 | if len(output_keys) == 0:
313 | output_keys = [None]
314 | if count == 0:
315 | prompt_tokens = []
316 | completion_tokens = []
317 | for k in range(0, len(logs_list)):
318 | log = logs_list[k][j]
319 | if k == 0:
320 | prompt_tokens.append(log["record"]["principles"]["usage"]["prompt_tokens"])
321 | completion_tokens.append(log["record"]["principles"]["usage"]["completion_tokens"])
322 | prompt_tokens.append(log["record"]["solution"]["usage"]["prompt_tokens"])
323 | completion_tokens.append(log["record"]["solution"]["usage"]["completion_tokens"])
324 |
325 | elif args.reasoning == "MAD":
326 | output_keys = [parse_answer(args, log["output"]) for log in logs[j]["record"][f"round{m+1}"]]
327 | output_keys = [output for output in output_keys if output != None]
328 | if len(output_keys) == 0:
329 | output_keys = [None]
330 | if count == 0:
331 | prompt_tokens = []
332 | completion_tokens = []
333 | for k in range(0, m+1):
334 | prompt_tokens += [log["usage"]["prompt_tokens"] for log in logs[j]["record"][f"round{k+1}"]]
335 | completion_tokens += [log["usage"]["completion_tokens"] for log in logs[j]["record"][f"round{k+1}"]]
336 |
337 | if count == 0:
338 | prompt_tokens_ += sum(prompt_tokens)
339 | completion_tokens_ += sum(completion_tokens)
340 |
341 | most_common_elements, max_count = find_most_common_elements(output_keys)
342 | output_key = random.choice(most_common_elements)
343 |
344 | if args.dataset in ["GSM8K"]:
345 | if output_key != None and abs(float(re.sub(r"[^0-9.-]", "", str(key))) - output_key) < 10**(-4):
346 | acc_num += 1
347 | elif args.dataset in ["GSM-Hard"]:
348 | if output_key != None and abs(float(key) - output_key) < 10**(-4):
349 | acc_num += 1
350 | elif args.dataset in ["MATH", "AIME_2024"]:
351 | if is_equiv(str(key), output_key):
352 | acc_num += 1
353 | elif args.dataset in ["GPQA"] or "MMLU" in args.dataset:
354 | if key == output_key:
355 | acc_num += 1
356 | acc = acc_num / l * 100
357 | accs.append(acc)
358 |
359 | acc = sum(accs)/len(accs)
360 | total_tokens = prompt_tokens_ + completion_tokens_
361 | cost = get_cost(args.model_name, prompt_tokens_, completion_tokens_)
362 |
363 | s = f"{args.reasoning.ljust(15)} " + "{}-shot".format(args.shot).ljust(7) + f" {str(args.nums).ljust(12)} prompt_tokens: {str(prompt_tokens_).ljust(10)} completion_tokens: {str(completion_tokens_).ljust(10)} cost:" + str("%.12f"%(cost)).ljust(20) + f"tokens: {str(total_tokens).ljust(11)}" " Acc: " + "%.4f"%acc
364 | print(s)
365 |
366 | ws_new.append([args.reasoning, subject, str(args.shot), str(args.nums[-1]+1), str(prompt_tokens_), str(completion_tokens_), str("%.15f"%(cost)), str(total_tokens), str("%.4f"%acc)])
367 |
368 | except Exception as e:
369 | print(args.reasoning, "error")
370 |
371 | wb.save(path)
372 |
373 | model_name = args.model_name
374 | sheet = wb.get_sheet_by_name(model_name)
375 |
376 | tokens = []
377 | accuracy = []
378 | cost = []
379 | labels = []
380 | flag = 0
381 | for row in sheet.iter_rows(values_only=True):
382 | if flag == 0:
383 | flag = 1
384 | continue
385 | labels.append(row[0])
386 | tokens.append(int(row[-2]))
387 | accuracy.append(float(row[-1]))
388 | cost.append(float(row[-3]))
389 |
390 | wb.close()
391 |
392 | unique_labels = ["DiP", "CoT", "L2M", "ToT", "S-RF", "SBP", "AnP", "MAD"]
393 | colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))
394 |
395 | plt.figure(figsize=(6, 3))
396 | ax = plt.gca()
397 |
398 | colors = np.array([[5.00000000e-01, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
399 | [2.17647059e-01, 4.29120609e-01, 9.75511968e-01, 1.00000000e+00],
400 | [17.25490196e-02, 8.82927610e-01, 1, 1.00000000e+00],
401 | [3.54901961e-01, 9.74138602e-01, 7.82927610e-01, 1.00000000e+00],
402 | [6.45098039e-01, 9.74138602e-01, 6.22112817e-01, 1.00000000e+00],
403 | [1, 7.82927610e-01, 5.34676422e-01, 1.00000000e+00],
404 | [1.00000000e+00, 4.29120609e-01, 2.19946358e-01, 1.00000000e+00],
405 | [1.00000000e+00, 1.22464680e-16, 6.12323400e-17, 1.00000000e+00],])
406 |
407 |
408 | for i, label in enumerate(unique_labels):
409 | mask = [label == lb for lb in labels]
410 | plt.scatter([cost[idx] for idx, val in enumerate(mask) if val],
411 | [accuracy[idx] for idx, val in enumerate(mask) if val],
412 | color=colors[i], label=label, s=36, marker=marker_dict[label], edgecolor='black', linewidths=0.75, zorder=5)
413 | plt.plot([cost[idx] for idx, val in enumerate(mask) if val],
414 | [accuracy[idx] for idx, val in enumerate(mask) if val],
415 | color=colors[i], linewidth=1.5, marker=marker_dict[label])
416 |
417 | max_cost = max(cost) // 0.1 * 0.1
418 | x_ticks = [max_cost / 10 * i for i in range(1, 11)]
419 | pos = - x_ticks[0] / 1
420 |
421 | y_max = max(accuracy)
422 | y_min = min(accuracy)
423 | ind_line = (y_max - y_min) / 7.5
424 | plt.ylim(bottom=None, top=y_max + ind_line * 1.5)
425 |
426 | def custom_formatter(y, pos):
427 | if y > max(accuracy) + 0.3:
428 | return ""
429 | else:
430 | return f"{int(y)}"
431 |
432 | for x in x_ticks:
433 | plt.axvline(x=x, color='gray', linestyle='--', alpha=0.5, zorder=0)
434 |
435 | mask = [cost[idx] <= x for idx in range(len(cost))]
436 | if any(mask):
437 | best_idx = np.argmax([accuracy[idx] for idx, val in enumerate(mask) if val])
438 | best_label = [labels[idx] for idx, val in enumerate(mask) if val][best_idx]
439 |
440 | plt.scatter(x, y_max + ind_line, color=colors[unique_labels.index(best_label)],
441 | marker=marker_dict[best_label], s=64, zorder=8, edgecolor='black', linewidths=1)
442 |
443 | plt.axhline(y=y_max + ind_line * 0.5, color='k', linestyle='-', alpha=0.5, zorder=10, linewidth=1.5)
444 |
445 | plt.rcParams['xtick.labelsize'] = 14
446 | plt.rcParams['ytick.labelsize'] = 14
447 | plt.rcParams['font.weight'] = 'bold'
448 | plt.rcParams['axes.labelweight'] = 'bold'
449 | plt.rcParams['axes.titleweight'] = 'bold'
450 | plt.title(f"{model_names_formal[args.model_name]}", fontsize=18)
451 | plt.xlabel('Cost ($)', fontsize=14, fontweight='bold')
452 | plt.ylabel('Accuracy', fontsize=14, fontweight='bold')
453 |
454 | plt.xticks(x_ticks)
455 |
456 | yticks = ax.get_yticks()
457 | yticks = yticks[yticks <= max(accuracy)]
458 |
459 | ax.set_yticks(yticks)
460 | ax.set_yticklabels([str(int(tick)) for tick in yticks])
461 | ax.yaxis.set_major_locator(MaxNLocator(integer=True))
462 | ax.yaxis.set_major_formatter(FuncFormatter(custom_formatter))
463 |
464 | plt.text(pos, y_max + ind_line, r'$\mathbf{P}_{O}^*$', fontsize=12, ha='center', va='center')
465 | plt.rcParams['font.weight'] = 'bold'
466 | plt.rcParams['axes.labelweight'] = 'bold'
467 | plt.rcParams['axes.titleweight'] = 'bold'
468 | plt.legend(loc='best', fontsize=12, framealpha=0.9, ncol=2, markerscale=1.5, handletextpad=0.5, columnspacing=1.0)
469 |
470 | plt.gca().yaxis.set_major_locator(MultipleLocator(5))
471 | plt.savefig(os.path.join(log_path, args.dataset, args.model_name, "pics", f"Performance_cost.png"), bbox_inches='tight', dpi = 600)
472 | plt.close()
473 | print(f"Performance_cost.png saved to {os.path.join(log_path, args.dataset, args.model_name, 'pics', 'Performance_cost.png')}")
474 |
475 |
--------------------------------------------------------------------------------
/eval_csv_N.py:
--------------------------------------------------------------------------------
1 | from openpyxl import Workbook, load_workbook
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 |
5 | from dataset import *
6 | import argparse
7 | from collections import Counter
8 | import random
9 | import os
10 | import re
11 |
12 | from matplotlib.ticker import MaxNLocator, MultipleLocator, FuncFormatter
13 |
14 |
def all_equal(lst):
    """Return True when every element of *lst* equals the first one.

    An empty sequence is vacuously all-equal and yields True.
    """
    if not lst:
        return True
    first = lst[0]
    return all(item == first for item in lst)
17 |
18 |
def find_most_common_elements(input_list):
    """Return (elements_with_highest_frequency, that_frequency).

    Ties are all returned, in first-seen order (Counter preserves
    insertion order).  Raises ValueError on an empty input.
    """
    tally = Counter(input_list)
    top = max(tally.values())
    winners = [item for item, freq in tally.items() if freq == top]
    return winners, top
24 |
25 |
def get_most_common_answer(outputs):
    """Majority vote over *outputs*, ignoring None entries.

    Ties between equally-frequent answers are broken uniformly at random;
    returns None when no non-None answer exists.
    """
    answers = [a for a in outputs if a is not None]
    if not answers:
        return None
    # Inlined frequency count (was delegated to find_most_common_elements).
    tally = Counter(answers)
    top = max(tally.values())
    winners = [a for a, freq in tally.items() if freq == top]
    return random.choice(winners)
33 |
34 |
def get_cost(model_name, prompt_tokens, completion_tokens):
    """Estimate API cost in USD from token counts.

    Rates are USD per million tokens as (prompt, completion); any model
    not listed falls back to the gpt-4o-mini pricing.
    """
    prompt_tokens = float(prompt_tokens)
    completion_tokens = float(completion_tokens)
    if "gemini" in model_name:
        rates = (0.075, 0.3)
    elif model_name == "gpt-3.5-turbo-0613":
        rates = (1.5, 2)
    else:
        # gpt-4o-mini and every other model share the same pricing.
        rates = (0.15, 0.6)
    return (prompt_tokens * rates[0] + completion_tokens * rates[1]) / 10 ** 6
48 |
49 |
if __name__ == "__main__":
    # Evaluated LLMs; Hugging Face-style ids are reduced to their basename below.
    model_names = ["Qwen/Qwen2.5-7B-Instruct",
                   "meta-llama/Meta-Llama-3-8B-Instruct",
                   "THUDM/glm-4-9b-chat",
                   "gemini-1.5-flash",
                   "gpt-4o-mini",
                   "microsoft/Phi-3.5-mini-instruct"]

    datasets = ["GSM8K",
                "GPQA",
                "GSM-Hard",
                "MATH",
                "MMLU-high_school_physics",
                "MMLU-high_school_chemistry",
                "MMLU-high_school_biology",
                "AIME_2024"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default=model_names[0])
    parser.add_argument("--dataset", type=str, default=datasets[0])
    parser.add_argument("--split", type=str, default="test")
    # NOTE(review): argparse type=bool treats any non-empty string (even
    # "False") as True; only the default is reliable here.
    parser.add_argument("--shuffle", type=bool, default=True)
    args = parser.parse_args()

    # Sampling budgets (N) at which accuracy is evaluated/plotted.
    sampling_times = [1, 3, 5, 7, 10, 15]
    # x-position used later for the P*_N marker label left of the axis.
    pos = - sampling_times[-1] / 30 * sampling_times[-1]/15

    # Display names for plot titles, keyed by the basename model id.
    model_names_formal = {
        "Qwen2.5-7B-Instruct": "Qwen2.5-7B-Instruct",
        "Llama-3-8B-Instruct": "LLaMA-3-8B-Instruct",
        "glm-4-9b-chat": "GLM-4-9B-Chat",
        "gemini-1.5-flash": "Gemini-1.5-Flash",
        "gpt-4o-mini": "GPT-4o-mini",
        "Phi-3.5-mini-instruct": "Phi-3.5-mini-Instruct"
    }

    # Maximum sampling time for each LLM, in order to avoid loading files that are recording the current running results.
    N_dict = {
        "Qwen2.5-7B-Instruct": 16,
        "Llama-3-8B-Instruct": 16,
        "glm-4-9b-chat": 16,
        "gemini-1.5-flash": 16,
        "gpt-4o-mini": 16,
        "Phi-3.5-mini-instruct": 16
    }

    # Marker shape per prompting method in the scatter/line plots.
    marker_dict = {
        'DiP': '^',
        'CoT': '^',
        'L2M': '^',
        'ToT': 'o',
        'S-RF':'o',
        'SBP': '^',
        'AnP': '^',
        'MAD': 'o'
    }


    # Strip any org prefix ("Qwen/..." -> "Qwen2.5-7B-Instruct").
    if "/" in args.model_name:
        args.model_name = args.model_name.split('/')[-1]
    N_max = N_dict[args.model_name]

    # NOTE(review): os.mkdir requires all parent directories to already exist.
    if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, "pics")):
        os.mkdir(os.path.join(log_path, args.dataset, args.model_name, "pics"))

    # NOTE(review): model_name was already stripped above; this repeat is dead code.
    if "/" in args.model_name:
        args.model_name = args.model_name.split('/')[-1]

    # Workbook accumulating per-method rows for this dataset (one sheet per model).
    path = os.path.join(log_path, args.dataset, f"{args.dataset}_N.xlsx")
    if os.path.exists(path):
        wb = load_workbook(path)
    else:
        wb = Workbook()

    unique_labels = []

    headers = ["Method", "Subject", "x-shot", "num", "prompt_tokens", "completion_tokens", "cost", "tokens", "accuracy"]

    # Reuse this model's worksheet when it exists, otherwise create it.
    if args.model_name in wb.sheetnames:
        ws_new = wb[args.model_name]
        # ws_new.delete_rows(1, ws_new.max_row)
        # ws_new.delete_cols(1, ws_new.max_column)
    else:
        ws_new = wb.create_sheet(args.model_name)
        ws_new.append(headers)

    # Prompting strategies to evaluate (ToT_k = Tree-of-Thoughts with k candidates).
    reasonings = ["DiP", "CoT", "L2M", "SBP", "AnP", "S-RF", "ToT_3", "ToT_5", "ToT_10", "MAD"]
137 |
    try:
        # Fast path: if the sheet already holds >= N_max rows for each base
        # method, assume results are complete and skip the recompute below.
        model_name = args.model_name
        sheet = wb.get_sheet_by_name(model_name)

        tokens = []
        accuracy = []
        cost = []
        labels = []
        flag = 0
        for row in sheet.iter_rows(values_only=True):
            if flag == 0:
                # Skip the header row.
                flag = 1
                continue
            # if row[0] != "tot-io":
            labels.append(row[0])
            tokens.append(int(row[-2]))
            accuracy.append(float(row[-1]))
            cost.append(float(row[-3]))
        counter = Counter(labels)
        assert counter["DiP"] >= N_max
        assert counter["CoT"] >= N_max
        assert counter["L2M"] >= N_max
        assert counter["SBP"] >= N_max
        assert counter["AnP"] >= N_max
        # assert counter["ToT"] >= 3
        # assert counter["S-RF"] > 3
        # assert counter["MAD"] > 3
    # NOTE(review): bare except — any failure (missing sheet, failed assert)
    # triggers the full recompute, but it also swallows KeyboardInterrupt.
    except:
        for reasoning in reasonings:
            try:
                # Determine N = number of sampled runs available on disk for
                # this method, and set args.reasoning / args.shot accordingly.
                if reasoning == "DiP":
                    args.reasoning = "DiP"
                    args.shot = 0
                    for N in range(0, N_max+1):
                        if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
                            break
                    # N = N - 1
                elif reasoning == "CoT":
                    args.reasoning = "CoT"
                    args.shot = 0
                    for N in range(0, N_max+1):
                        if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
                            break
                    # N = N - 1
                elif reasoning == "L2M":
                    args.reasoning = "L2M"
                    # L2M uses a 1-shot prompt on the GSM-style datasets only.
                    if args.dataset in ["GSM8K", "GSM-Hard"]:
                        args.shot = 1
                    else:
                        args.shot = 0
                    for N in range(0, N_max+1):
                        if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
                            break
                    # N = N - 1
                elif reasoning == "ToT_3":
                    args.reasoning = "ToT"
                    args.shot = 3
                    N = 1
                elif reasoning == "ToT_5":
                    args.reasoning = "ToT"
                    args.shot = 5
                    N = 1
                elif reasoning == "ToT_10":
                    args.reasoning = "ToT"
                    args.shot = 10
                    N = 1
                elif reasoning == "S-RF":
                    args.reasoning = reasoning
                    args.shot = 0
                    # NOTE(review): `logs` here is left over from the previous
                    # reasoning iteration — confirm this ordering dependency
                    # is intended (S-RF rounds come in problem/output pairs).
                    N = (len(logs[-1]["record"]) - 1) // 2

                elif reasoning == "SBP":
                    args.reasoning = "SBP"
                    args.shot = 0
                    for N in range(0, N_max+1):
                        if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
                            break
                    # N = N - 1
                elif reasoning == "AnP":
                    args.reasoning = "AnP"
                    args.shot = 1
                    for N in range(0, N_max+1):
                        if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
                            break
                elif reasoning == "MAD":
                    args.reasoning = "MAD"
                    args.shot = 0
                    args.n = 0
                    logs = read_logs(args)
                    N = len(logs[0]["record"])

                # Load the log files for every sampled run of this method,
                # then score majority-vote accuracy at each budget m+1.
                logs_list = []
                assert N != 0
                for m in range(0, N):
                    if args.reasoning in ["DiP", "CoT", "L2M", "AnP", "SBP"]:
                        args.nums = range(0,m+1)
                        args.n = m
                        logs = read_logs(args)
                        logs_list.append(logs)
                        # Only evaluate at the budgets listed in sampling_times.
                        if (m+1) not in sampling_times:
                            continue
                    elif args.reasoning in ["S-RF", "MAD"]:
                        args.n = 0
                        args.nums = range(0, m+1)
                        logs = read_logs(args)
                        logs_list.append(logs)
                    elif args.reasoning == "ToT":
                        args.nums = range(0, m+1)
                        for n in range(0, 5):
                            try:
                                args.n = n
                                logs = read_logs(args)
                                logs_list.append(logs)
                            except:
                                break

                    # Accuracy averaged over 5 random shuffles of the runs
                    # (tie-breaking in the majority vote is stochastic).
                    accs = []
                    prompt_tokens_ = 0
                    completion_tokens_ = 0
                    l = len(logs_list[0])
                    for count in range(0, 5):
                        if args.shuffle:
                            random.shuffle(logs_list)
                        acc_num = 0
                        if args.dataset in ["GSM8K", "GSM-Hard", "MATH", "AIME_2024"]:
                            subject = "mathematic"
                        for j in range(0, l):
                            if args.dataset in ["GPQA"] or "MMLU" in args.dataset:
                                subject = logs_list[0][j]["subject"]
                            key = logs_list[0][j]['key']
                            if args.reasoning in ["DiP", "CoT", "L2M", "AnP"]:
                                output_keys = [parse_answer(args, log[j]['record']['output']) for log in logs_list]
                                output_keys = [output for output in output_keys if output != None]
                                if len(output_keys) == 0:
                                    output_keys = [None]

                                if count == 0:
                                    # NOTE(review): prompt tokens are counted for
                                    # the first run only ([:1]) while completions
                                    # are summed over all runs — confirm intended.
                                    prompt_tokens = [log[j]['record']['usage']['prompt_tokens'] for log in logs_list[:1]]
                                    completion_tokens = [log[j]['record']['usage']['completion_tokens'] for log in logs_list]

                            elif args.reasoning == "S-RF":
                                # S-RF keeps only the final refined answer of round m+1.
                                output_keys = [parse_answer(args, logs[j]["record"][f"output{m+1}"]["output"])]
                                if count == 0:
                                    prompt_tokens = [logs[j]["record"]["output0"]['usage']["prompt_tokens"]]
                                    completion_tokens = [logs[j]["record"]["output0"]['usage']["completion_tokens"]]
                                    for k in range(0, m+1):
                                        prompt_tokens += [logs[j]["record"][f"problems{k+1}"]['usage']["prompt_tokens"], logs[j]["record"][f"output{k+1}"]['usage']["prompt_tokens"]]
                                        completion_tokens += [logs[j]["record"][f"problems{k+1}"]['usage']["completion_tokens"], logs[j]["record"][f"output{k+1}"]['usage']["completion_tokens"]]

                            elif args.reasoning == "ToT":
                                # Vote over the chooser's picked solution index per run.
                                indexes = []
                                for logs in logs_list:
                                    index = parse_best_solution(logs[j]["record"]["choose"]["output"])
                                    # NOTE(review): string comparison — with shot=10
                                    # an index of "2" fails "2" <= "10" and falls
                                    # back to a random pick; confirm intended.
                                    if index != None and "0" < index and index <= str(args.shot):
                                        index = int(index) - 1
                                    else:
                                        index = random.choice(range(0, args.shot))
                                    indexes.append(index)
                                solutions = logs_list[0][j]["record"]["solutions"]
                                index = get_most_common_answer(indexes)
                                best_solution = solutions[index]
                                output_keys = [parse_answer(args, best_solution["output"])]
                                if count == 0:
                                    prompt_tokens = []
                                    completion_tokens = []
                                    for k in range(0, len(solutions)):
                                        solution = solutions[k]
                                        if k == 0:
                                            prompt_tokens.append(solution["usage"]["prompt_tokens"])
                                            completion_tokens.append(solution["usage"]["completion_tokens"])
                                    for ii, logs in enumerate(logs_list):
                                        if ii == 0:
                                            prompt_tokens.append(logs[j]["record"]["choose"]["usage"]["prompt_tokens"])
                                            completion_tokens.append(logs[j]["record"]["choose"]["usage"]["completion_tokens"])

                            elif args.reasoning == "SBP":
                                output_keys = [parse_answer(args, log[j]['record']['solution']["output"]) for log in logs_list]
                                output_keys = [output for output in output_keys if output != None]
                                if len(output_keys) == 0:
                                    output_keys = [None]
                                if count == 0:
                                    prompt_tokens = []
                                    completion_tokens = []
                                    for k in range(0, len(logs_list)):
                                        log = logs_list[k][j]
                                        if k == 0:
                                            # The "principles" step is shared, so count it once.
                                            prompt_tokens.append(log["record"]["principles"]["usage"]["prompt_tokens"])
                                            completion_tokens.append(log["record"]["principles"]["usage"]["completion_tokens"])
                                        prompt_tokens.append(log["record"]["solution"]["usage"]["prompt_tokens"])
                                        completion_tokens.append(log["record"]["solution"]["usage"]["completion_tokens"])

                            elif args.reasoning == "MAD":
                                output_keys = [parse_answer(args, log["output"]) for log in logs[j]["record"][f"round{m+1}"]]
                                output_keys = [output for output in output_keys if output != None]
                                if len(output_keys) == 0:
                                    output_keys = [None]
                                if count == 0:
                                    prompt_tokens = []
                                    completion_tokens = []
                                    for k in range(0, m+1):
                                        prompt_tokens += [log["usage"]["prompt_tokens"] for log in logs[j]["record"][f"round{k+1}"]]
                                        completion_tokens += [log["usage"]["completion_tokens"] for log in logs[j]["record"][f"round{k+1}"]]
                                # NOTE(review): this accumulation sits one level
                                # deeper than in eval_csv_cost.py, so the totals
                                # only update inside the MAD branch — confirm.
                                if count == 0:
                                    prompt_tokens_ += sum(prompt_tokens)
                                    completion_tokens_ += sum(completion_tokens)

                            # Majority vote with random tie-break, then score.
                            most_common_elements, max_count = find_most_common_elements(output_keys)
                            output_key = random.choice(most_common_elements)

                            if args.dataset in ["GSM8K"]:
                                if output_key != None and abs(float(re.sub(r"[^0-9.-]", "", str(key))) - output_key) < 10**(-4):
                                    acc_num += 1
                            elif args.dataset in ["GSM-Hard"]:
                                if output_key != None and abs(float(key) - output_key) < 10**(-4):
                                    acc_num += 1
                            elif args.dataset in ["MATH", "AIME_2024"]:
                                if is_equiv(str(key), output_key):
                                    acc_num += 1
                            elif args.dataset in ["GPQA"] or "MMLU" in args.dataset:
                                if key == output_key:
                                    acc_num += 1
                        acc = acc_num / l * 100
                        accs.append(acc)

                    acc = sum(accs)/len(accs)
                    total_tokens = prompt_tokens_ + completion_tokens_
                    # For the N-plot, the "cost" column is the sampling budget
                    # (or candidate count for ToT), not dollars.
                    if args.reasoning in ["ToT"]:
                        cost = args.shot
                    else:
                        cost = m + 1


                    s = f"{args.reasoning.ljust(15)} " + "{}-shot".format(args.shot).ljust(7) + f" {str(args.nums).ljust(12)} prompt_tokens: {str(prompt_tokens_).ljust(10)} completion_tokens: {str(completion_tokens_).ljust(10)} cost:" + str("%.12f"%(cost)).ljust(20) + f"tokens: {str(total_tokens).ljust(11)}" " Acc: " + "%.4f"%acc
                    print(s)

                    ws_new.append([args.reasoning, subject, str(args.shot), str(args.nums[-1]+1), str(prompt_tokens_), str(completion_tokens_), str("%.15f"%(cost)), str(total_tokens), str("%.4f"%acc)])

            # NOTE(review): the caught exception `e` is never shown — consider
            # printing it; "error" alone hides the root cause.
            except Exception as e:
                print(args.reasoning, "error")

    wb.save(path)
379 |
    # Cap the rows consumed per method at N_max so stale extra rows in the
    # sheet (from earlier runs) are ignored.
    labels_num = {
        "DiP": 0,
        "CoT": 0,
        "L2M": 0,
        "ToT": 0,
        "S-RF": 0,
        "AnP": 0,
        "SBP": 0,
        "MAD": 0,
    }

    model_name = args.model_name
    sheet = wb.get_sheet_by_name(model_name)

    # Re-read the (possibly just recomputed) sheet for plotting.
    tokens = []
    accuracy = []
    cost = []
    labels = []
    flag = 0
    for row in sheet.iter_rows(values_only=True):
        if flag == 0:
            # Skip the header row.
            flag = 1
            continue
        name = row[0]
        if labels_num[row[0]] >= N_max:
            continue
        labels.append(name)
        labels_num[row[0]] += 1
        tokens.append(int(row[-2]))
        accuracy.append(float(row[-1]))
        cost.append(float(row[-3]))

    unique_labels = ["DiP", "CoT", "L2M", "ToT", "S-RF", "SBP", "AnP", "MAD"]
    wb.close()

    plt.figure(figsize=(6, 3))
    ax = plt.gca()

    # Fixed rainbow-like RGBA palette, one row per method in unique_labels.
    colors = np.array([[5.00000000e-01, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
                       [2.17647059e-01, 4.29120609e-01, 9.75511968e-01, 1.00000000e+00],
                       [17.25490196e-02, 8.82927610e-01, 1, 1.00000000e+00],
                       [3.54901961e-01, 9.74138602e-01, 7.82927610e-01, 1.00000000e+00],
                       [6.45098039e-01, 9.74138602e-01, 6.22112817e-01, 1.00000000e+00],
                       [1, 7.82927610e-01, 5.34676422e-01, 1.00000000e+00],
                       [1.00000000e+00, 4.29120609e-01, 2.19946358e-01, 1.00000000e+00],
                       [1.00000000e+00, 1.22464680e-16, 6.12323400e-17, 1.00000000e+00],])

    # One scatter+line series per method (x = sampling budget, y = accuracy).
    for i, label in enumerate(unique_labels):
        mask = [label == lb for lb in labels]
        plt.scatter([cost[idx] for idx, val in enumerate(mask) if val],
                    [accuracy[idx] for idx, val in enumerate(mask) if val],
                    color=colors[i], label=label, s=36, marker=marker_dict[label], edgecolor='black', linewidths=0.75, zorder=5)
        plt.plot([cost[idx] for idx, val in enumerate(mask) if val],
                 [accuracy[idx] for idx, val in enumerate(mask) if val],
                 color=colors[i], linewidth=1.5, marker=marker_dict[label])

    x_ticks = sampling_times
    y_max = max(accuracy)
    y_min = min(accuracy)
    # Vertical spacing unit for the "best method" marker strip above the data.
    ind_line = (y_max - y_min) / 7.5
    plt.ylim(bottom=None, top=y_max + ind_line * 1.5)

    def custom_formatter(y, pos):
        # Hide tick labels that fall inside the marker strip above the data.
        if y > max(accuracy) + 0.9:
            return ""
        else:
            return f"{int(y)}"

    # For each budget, mark the best-performing method above the plot.
    for x in x_ticks:
        plt.axvline(x=x, color='gray', linestyle='--', alpha=0.5, zorder=0)
        mask = [cost[idx] == x for idx in range(len(cost))]
        if any(mask):
            best_idx = np.argmax([accuracy[idx] for idx, val in enumerate(mask) if val])
            best_label = [labels[idx] for idx, val in enumerate(mask) if val][best_idx]
            plt.scatter(x, y_max + ind_line, color=colors[unique_labels.index(best_label)],
                        marker=marker_dict[best_label], s=64, zorder=8, edgecolor='black', linewidths=1)

    # Separator between the data area and the best-method marker strip.
    plt.axhline(y=y_max + ind_line / 2, color='k', linestyle='-', alpha=0.5, zorder=10, linewidth=1.5)
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14
    plt.rcParams['font.weight'] = 'bold'
    plt.rcParams['axes.labelweight'] = 'bold'
    plt.rcParams['axes.titleweight'] = 'bold'
    plt.title(f"{model_names_formal[args.model_name]}", fontsize=18)
    plt.xlabel('Sampling Time', fontsize=14, fontweight='bold')
    plt.ylabel('Accuracy', fontsize=14, fontweight='bold')

    plt.xticks(x_ticks)
    yticks = ax.get_yticks()
    yticks = yticks[yticks <= max(accuracy)]

    ax.set_yticks(yticks)
    ax.set_yticklabels([str(int(tick)) for tick in yticks])
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.yaxis.set_major_formatter(FuncFormatter(custom_formatter))

    # Label the best-method strip as P*_N.
    plt.text(pos, y_max + ind_line, r'$\mathbf{P}_{N}^*$', fontsize=12, ha='center', va='center')
    plt.rcParams['font.weight'] = 'bold'
    plt.rcParams['axes.labelweight'] = 'bold'
    plt.rcParams['axes.titleweight'] = 'bold'

    plt.legend(loc='best', fontsize=12, framealpha=0.9, ncol=2, markerscale=1.5, handletextpad=0.5, columnspacing=1.0)

    plt.savefig(os.path.join(log_path, args.dataset, args.model_name, "pics", f"Performance_N.png"), bbox_inches='tight', dpi = 600)
    plt.close()
    print(f"Performance_N.png saved to {os.path.join(log_path, args.dataset, args.model_name, 'pics', 'Performance_N.png')}")
486 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from prompts import GSM8K, GPQA, GSM_Hard, MATH, MMLU, AIME
3 |
4 | import os
5 | import json
6 | import re
7 | import random
8 | from collections import Counter
9 |
10 |
base_path = f"xxx/xxx/.../rethinking_prompting"  # placeholder: set to your local checkout path
log_path = os.path.join(base_path, "logs")  # all run logs and result workbooks live under here
13 |
14 |
def get_cost(model_name, prompt_tokens, completion_tokens):
    """Estimate API cost in USD for *model_name* given token counts.

    Rates are USD per million tokens; unrecognized models default to
    the gpt-4o-mini pricing.
    """
    p = float(prompt_tokens)
    c = float(completion_tokens)
    if "gemini" in model_name:
        prompt_rate, completion_rate = 0.075, 0.3
    elif model_name == "gpt-3.5-turbo-0613":
        prompt_rate, completion_rate = 1.5, 2
    elif model_name == "gpt-4o-mini":
        prompt_rate, completion_rate = 0.15, 0.6
    else:
        # Default to gpt-4o-mini pricing for anything unlisted.
        prompt_rate, completion_rate = 0.15, 0.6
    return (p * prompt_rate + c * completion_rate) / 10 ** 6
28 |
29 |
def find_most_common_elements(outputs):
    """Return (modal_elements, count) over the non-None entries of *outputs*.

    Ties are all returned, in first-seen order.  Returns a bare None (not a
    tuple) when every entry is None — callers must guard before unpacking.
    """
    answers = [a for a in outputs if a is not None]
    if not answers:
        return None
    tally = Counter(answers)
    top = max(tally.values())
    return [a for a, freq in tally.items() if freq == top], top
38 |
39 |
def get_unique_most_common_answer(outputs):
    """Pick one answer among the most frequent non-None entries.

    Ties are broken uniformly at random; returns None when no non-None
    answer exists.
    """
    answers = [a for a in outputs if a is not None]
    if not answers:
        return None
    # Inlined frequency count (was delegated to find_most_common_elements).
    tally = Counter(answers)
    top = max(tally.values())
    return random.choice([a for a, freq in tally.items() if freq == top])
47 |
48 |
def load_GPQA_examples(dataset_list, seed: int):
    """Convert raw GPQA rows into examples with shuffled answer choices.

    Seeds the global RNG so the choice order is reproducible for a given
    seed.  The gold letter ("A".."D") records where the correct answer
    landed after shuffling.
    """
    random.seed(seed)

    def build_example(row):
        choices = [
            row["Incorrect Answer 1"],
            row["Incorrect Answer 2"],
            row["Incorrect Answer 3"],
            row["Correct Answer"],
        ]
        random.shuffle(choices)
        return {
            "problem": row["Question"],
            "subject": row["High-level domain"],
            "choices": choices,
            "answer": chr(choices.index(row["Correct Answer"]) + 65),
        }

    return [build_example(row) for row in dataset_list]
64 |
65 |
def read_dataset(args):
    """Load the evaluation examples for ``args.dataset`` and return them as a list.

    Supported: GSM8K, GSM-Hard, MATH (MATH-500), GPQA (main), MMLU-<subject>,
    AIME_2024.  Uses ``args.split`` where the dataset has multiple splits and
    ``args.seed`` for the GPQA choice shuffle.

    Raises:
        ValueError: for an unsupported dataset name.  (Previously an unknown
        name fell through every branch and crashed with a NameError on the
        final return.)
    """
    dataset = args.dataset
    if dataset == "GSM8K":
        ds = load_dataset("openai/gsm8k", "main")[args.split]
        dataset_list = [d for d in ds]
    elif dataset == "GSM-Hard":
        # GSM-Hard only ships a "train" split.
        ds = load_dataset("reasoning-machines/gsm-hard")
        dataset_list = [d for d in ds["train"]]
    elif dataset == "MATH":
        ds = load_dataset("HuggingFaceH4/MATH-500")[args.split]
        dataset_list = [d for d in ds]
    elif dataset == "GPQA":
        ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]
        # Re-shuffle the answer choices deterministically with args.seed.
        dataset_list = load_GPQA_examples([d for d in ds], args.seed)
    elif "MMLU" in dataset:
        subject = dataset.split("-")[-1]
        dataset_list = load_dataset("cais/mmlu", subject)[args.split]
    elif dataset == "AIME_2024":
        # NOTE(review): mutates process-wide env state and imports modelscope
        # lazily so it is only required for this dataset.
        os.environ["HF_DATASETS_OFFLINE"] = "1"
        from modelscope.msdatasets import MsDataset
        ds = MsDataset.load("AI-ModelScope/AIME_2024", subset_name="default", split="train")
        dataset_list = [d for d in ds]
    else:
        raise ValueError(f"Unsupported dataset: {dataset!r}")
    return dataset_list
89 |
90 |
def examine_output(dataset, output, key):
    """Check a parsed model answer *output* against the gold *key*.

    Numeric datasets (GSM8K, GSM-Hard) compare within 1e-4; MATH-style
    datasets delegate to is_equiv; multiple-choice datasets use exact
    equality.  Returns True on a match, False otherwise (including when
    *output* is None for the numeric datasets).
    """
    tolerance = 10**(-4)
    if dataset in ["GSM8K"]:
        # Gold keys may carry formatting (e.g. "#### 5", commas) — strip to digits.
        if output is not None and abs(float(re.sub(r"[^0-9.-]", "", str(key))) - output) < tolerance:
            return True
    elif dataset in ["GSM-Hard"]:
        if output is not None and abs(float(key) - output) < tolerance:
            return True
    elif dataset in ["MATH", "AIME_2024"]:
        if is_equiv(str(key), output):
            return True
    elif dataset in ["GPQA", "CommonSenseQA"] or "MMLU" in dataset:
        if key == output:
            return True
    return False
105 |
106 |
def record_logs(logs, args):
    """Overwrite the JSON log file for this (dataset, model, reasoning, shot, n) run."""
    run_dir = os.path.join(log_path, args.dataset, args.model_name)
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)
    filename = f"{args.reasoning}_{args.shot}_{args.n}.json"
    with open(os.path.join(run_dir, filename), "w") as f:
        json.dump(logs, f, indent = 4)
113 |
114 |
def record_a_logs(logs, args):
    """Append a JSON dump of *logs* to this run's log file.

    Fixed: the previous two-step ``os.mkdir`` calls failed with
    FileNotFoundError when the logs root itself was missing;
    ``os.makedirs(..., exist_ok=True)`` matches record_logs, creates all
    intermediate directories, and is race-free.

    NOTE(review): appending a second JSON document to the same file makes
    it invalid for ``json.loads`` — read_logs only works on files written
    once; confirm whether a JSON-lines format was intended here.
    """
    run_dir = os.path.join(log_path, args.dataset, args.model_name)
    os.makedirs(run_dir, exist_ok=True)
    path = os.path.join(run_dir, f"{args.reasoning}_{args.shot}_{args.n}.json")
    with open(path, "a") as f:
        json.dump(logs, f, indent = 4)
123 |
124 |
def read_logs(args):
    """Load previously recorded logs for this run; [] when the file is absent."""
    path = os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{args.n}.json")
    if not os.path.exists(path):
        return []
    with open(path, "r") as f:
        return json.loads(f.read())
133 |
134 |
def construct_message(args, agents, question, idx):
    """Build the multi-agent-debate (MAD) prompt for one question.

    Concatenates every other agent's answer at position *idx*, then restates
    the original question with the dataset's answer-format instructions.

    Raises:
        ValueError: for datasets the MAD flow does not support.
    """
    parts = ["These are the answers to the question from other agents: "]
    for agent in agents:
        parts.append("\n\n One agent answer: ```{}```".format(agent[idx]))

    if args.dataset in ["GSM8K", "GSM-Hard"]:
        suffix = "\n\n Using the solutions from other agents as additional information, can you provide your answer to the math problem? \n The original math problem is {}.".format(question) + " Your final answer should be a single numerical number, in the form \\boxed{answer}, at the end of your response."
    elif args.dataset in ["GPQA"]:
        suffix = "\n\n Using the solutions from other agents as additional information, can you provide your answer to the problem? \n The original problem is {}. ".format(question) + GPQA.prompt_format
    elif "MMLU" in args.dataset:
        suffix = "\n\n Using the solutions from other agents as additional information, can you provide your answer to the problem? \n The original problem is {}. ".format(question) + MMLU.prompt_format
    elif args.dataset in ["MATH", "AIME_2024"]:
        suffix = "\n\n Using the solutions from other agents as additional information, can you provide your answer to the math problem? \n The original math problem is {}. ".format(question) + MATH.prompt_format
    else:
        raise ValueError(f"{args.dataset} is not included in MAD!")

    return "".join(parts) + suffix
157 |
158 |
def _tot_solution_kwargs(args, i, shot):
    """Keyword args solution1..solution<shot> drawn from the stored ToT drafts."""
    return {f"solution{k}": args.records_tot[i][k - 1]["output"] for k in range(1, shot + 1)}


def _choice_kwargs(choices):
    """Keyword args choice1..choice4 for multiple-choice templates."""
    return {f"choice{k}": choices[k - 1] for k in range(1, 5)}


def _tot_prompt(module, args, i, extra=None):
    """Build a Tree-of-Thoughts vote prompt from args.shot stored draft solutions.

    Returns None when args.shot is not a supported draft count (3, 5, 10) so the
    caller can report the unsupported combination.
    """
    templates = {3: module.tot_3_solutions, 5: module.tot_5_solutions, 10: module.tot_10_solutions}
    if args.shot not in templates:
        return None
    kwargs = {"question": args.question}
    if extra:
        kwargs.update(extra)
    kwargs.update(_tot_solution_kwargs(args, i, args.shot))
    return templates[args.shot].format(**kwargs) + module.tot_post


def create_prompt(args, index = None):
    """Build one prompt per question for the configured dataset/reasoning/shot.

    Parameters:
        args  : namespace holding dataset, reasoning, shot, questions, and the
                per-strategy extras (records_tot, choiceses, subjects, principles).
        index : when given, build the prompt for that single question only;
                otherwise build prompts for all of args.questions.
    Returns a list of prompt strings.
    Raises ValueError for an unknown dataset, or for a reasoning/shot
    combination that has no template (previously this crashed with
    UnboundLocalError, or could silently reuse a stale prompt).
    """
    prompts = []
    if index is not None:
        range_ = [index]
    else:
        range_ = range(0, len(args.questions))

    for i in range_:
        args.question = args.questions[i]
        prompt = None  # stays None when no template matches; checked below

        if args.dataset == "GSM8K":
            if args.reasoning == "DiP":
                prompt = GSM8K.io.replace("{question}", args.question)
            elif args.reasoning == "CoT":
                assert args.shot in [0, 1, 5]
                cot_templates = {0: GSM8K.cot_0_shot, 1: GSM8K.cot_1_shot, 5: GSM8K.cot_5_shot}
                prompt = cot_templates[args.shot].replace("{question}", args.question)
            elif args.reasoning == "L2M":
                if args.shot == 1:
                    prompt = GSM8K.Least_to_Most_1_shot.replace("{question}", args.question)
            elif args.reasoning == "ToT":
                prompt = _tot_prompt(GSM8K, args, i)
            elif args.reasoning == "AnP":
                anp_templates = {1: GSM8K.anologous_1_prompt, 3: GSM8K.anologous_3_prompt, 5: GSM8K.anologous_5_prompt}
                if args.shot in anp_templates:
                    prompt = anp_templates[args.shot].replace("{question}", args.question)
            elif args.reasoning == "SBP":
                # first call extracts principles; subsequent calls answer with them
                if "principles" not in args:
                    prompt = GSM8K.SBP_extract.replace("{question}", args.question)
                else:
                    prompt = GSM8K.SBP_answer.replace("{question}", args.question)
                    prompt = prompt.replace("{principles}", args.principles)

        elif args.dataset == "GSM-Hard":
            if args.reasoning == "DiP":
                prompt = GSM_Hard.io.replace("{question}", args.question)
            elif args.reasoning == "CoT":
                assert args.shot in [0, 1, 5]
                cot_templates = {0: GSM_Hard.cot_0_shot, 1: GSM_Hard.cot_1_shot, 5: GSM_Hard.cot_5_shot}
                prompt = cot_templates[args.shot].replace("{question}", args.question)
            elif args.reasoning == "L2M":
                if args.shot == 1:
                    prompt = GSM_Hard.Least_to_Most_1_shot.replace("{question}", args.question)
            elif args.reasoning == "ToT":
                prompt = _tot_prompt(GSM_Hard, args, i)
            elif args.reasoning == "AnP":
                anp_templates = {1: GSM_Hard.anologous_1_prompt, 3: GSM_Hard.anologous_3_prompt, 5: GSM_Hard.anologous_5_prompt}
                if args.shot in anp_templates:
                    prompt = anp_templates[args.shot].replace("{question}", args.question)
            elif args.reasoning == "SBP":
                if "principles" not in args:
                    prompt = GSM_Hard.SBP_extract.replace("{question}", args.question)
                else:
                    prompt = GSM_Hard.SBP_answer.replace("{question}", args.question)
                    prompt = prompt.replace("{principles}", args.principles)

        elif args.dataset == "GPQA":
            args.choices = args.choiceses[i]
            choices = _choice_kwargs(args.choices)
            if args.reasoning == "DiP":
                prompt = GPQA.io.format(question=args.question, **choices)
            elif args.reasoning == "CoT":
                assert args.shot in [0, 1, 5]
                cot_templates = {0: GPQA.cot_0_shot, 1: GPQA.cot_1_shot, 5: GPQA.cot_5_shot}
                prompt = cot_templates[args.shot].format(question=args.question, **choices)
            elif args.reasoning == "L2M":
                if args.shot == 0:
                    prompt = GPQA.Least_to_Most_0_shot.format(question=args.question, **choices)
            elif args.reasoning == "ToT":
                prompt = _tot_prompt(GPQA, args, i, choices)
            elif args.reasoning == "AnP":
                anp_templates = {1: GPQA.anologous_1_prompt, 3: GPQA.anologous_3_prompt, 5: GPQA.anologous_5_prompt}
                if args.shot in anp_templates:
                    prompt = anp_templates[args.shot].format(subject=args.subjects[i], question=args.question, **choices)
            elif args.reasoning == "SBP":
                if "principles" not in args:
                    prompt = GPQA.SBP_extract.format(subject=args.subjects[i], question=args.question, **choices)
                else:
                    prompt = GPQA.SBP_answer.format(subject=args.subjects[i], question=args.question, **choices, principles=args.principles)

        elif "MMLU" in args.dataset:
            args.choices = args.choiceses[i]
            choices = _choice_kwargs(args.choices)
            if args.reasoning == "DiP":
                prompt = MMLU.io.format(question=args.question, **choices)
            elif args.reasoning == "CoT":
                assert args.shot in [0, 1, 5]
                # only the 0-shot CoT template exists for MMLU
                if args.shot == 0:
                    prompt = MMLU.cot_0_shot.format(question=args.question, **choices)
            elif args.reasoning == "L2M":
                if args.shot == 0:
                    prompt = MMLU.Least_to_Most_0_shot.format(question=args.question, **choices)
            elif args.reasoning == "ToT":
                prompt = _tot_prompt(MMLU, args, i, choices)
            elif args.reasoning == "AnP":
                anp_templates = {1: MMLU.anologous_1_prompt, 3: MMLU.anologous_3_prompt, 5: MMLU.anologous_5_prompt}
                if args.shot in anp_templates:
                    prompt = anp_templates[args.shot].format(subject=args.subjects[i], question=args.question, **choices)
            elif args.reasoning == "SBP":
                # physics/chemistry subjects get specialised SBP templates
                subject = args.subjects[i]
                if "physics" in subject:
                    extract_t, answer_t = MMLU.SBP_extract_physics, MMLU.SBP_answer_physics
                elif "chemistry" in subject:
                    extract_t, answer_t = MMLU.SBP_extract_chemistry, MMLU.SBP_answer_chemistry
                else:
                    extract_t, answer_t = MMLU.SBP_extract, MMLU.SBP_answer
                if "principles" not in args:
                    prompt = extract_t.format(subject=subject, question=args.question, **choices)
                else:
                    prompt = answer_t.format(subject=subject, question=args.question, **choices, principles=args.principles)

        elif args.dataset == "MATH":
            if args.reasoning == "DiP":
                prompt = MATH.io.replace("{question}", args.question)
            elif args.reasoning == "CoT":
                assert args.shot in [0, 1, 5]
                cot_templates = {0: MATH.cot_0_shot, 1: MATH.cot_1_shot, 5: MATH.cot_5_shot}
                prompt = cot_templates[args.shot].replace("{question}", args.question)
            elif args.reasoning == "L2M":
                if args.shot == 0:
                    prompt = MATH.Least_to_Most_0_shot.replace("{question}", args.question)
            elif args.reasoning == "ToT":
                prompt = _tot_prompt(MATH, args, i)
            elif args.reasoning == "AnP":
                anp_templates = {1: MATH.anologous_1_prompt, 3: MATH.anologous_3_prompt, 5: MATH.anologous_5_prompt}
                if args.shot in anp_templates:
                    prompt = anp_templates[args.shot].replace("{question}", args.question)
            elif args.reasoning == "SBP":
                if "principles" not in args:
                    prompt = MATH.SBP_extract.replace("{question}", args.question)
                else:
                    prompt = MATH.SBP_answer.replace("{question}", args.question)
                    prompt = prompt.replace("{principles}", args.principles)

        elif args.dataset == "AIME_2024":
            if args.reasoning == "DiP":
                prompt = AIME.io.replace("{question}", args.question)
            elif args.reasoning == "CoT":
                assert args.shot in [0, 1, 5]
                # only the 0-shot CoT template exists for AIME
                if args.shot == 0:
                    prompt = AIME.cot_0_shot.replace("{question}", args.question)
            elif args.reasoning == "L2M":
                if args.shot == 0:
                    prompt = AIME.Least_to_Most_0_shot.replace("{question}", args.question)
            elif args.reasoning == "ToT":
                prompt = _tot_prompt(AIME, args, i)
            elif args.reasoning == "AnP":
                anp_templates = {1: AIME.anologous_1_prompt, 3: AIME.anologous_3_prompt, 5: AIME.anologous_5_prompt}
                if args.shot in anp_templates:
                    prompt = anp_templates[args.shot].replace("{question}", args.question)
            elif args.reasoning == "SBP":
                if "principles" not in args:
                    prompt = AIME.SBP_extract.replace("{question}", args.question)
                else:
                    prompt = AIME.SBP_answer.replace("{question}", args.question)
                    prompt = prompt.replace("{principles}", args.principles)
        else:
            raise ValueError(f"{args.dataset} is not included in def create_prompt!")

        if prompt is None:
            raise ValueError(
                f"Unsupported combination dataset={args.dataset}, reasoning={args.reasoning}, shot={args.shot} in create_prompt!"
            )
        prompts.append(prompt)

    return prompts
384 |
385 |
def _strip_string(string):
    """Normalize a LaTeX answer string for equivalence comparison.

    Applies (in order): literal cleanups, unit stripping, decimal/fraction
    normalization, and the sqrt/frac fixups.  Mirrors the MATH-benchmark
    answer normalization.
    """
    literal_subs = (
        ("\n", ""),          # linebreaks
        ("\\!", ""),         # inverse spaces
        ("\\\\", "\\"),      # collapse \\ to \
        ("tfrac", "frac"),   # tfrac/dfrac -> frac
        ("dfrac", "frac"),
        ("\\left", ""),      # sizing commands
        ("\\right", ""),
        ("^{\\circ}", ""),   # degree markers
        ("^\\circ", ""),
        ("\\$", ""),         # dollar signs
    )
    for old, new in literal_subs:
        string = string.replace(old, new)

    # drop a trailing units annotation such as "\text{ cm}"
    string = _remove_right_units(string)

    # percentage signs ("\%" and "\\%" are the same runtime string)
    string = string.replace("\\%", "")

    # " ." -> " 0." and "{." -> "{0." for bare decimals
    string = string.replace(" .", " 0.").replace("{.", "{0.")

    # imaginary unit: 6+9j -> 6+9i
    string = string.replace("j", "i")

    if not string:
        return string
    # add a leading zero when the string starts with "."
    if string.startswith("."):
        string = "0" + string

    # strip a short left-hand side like "k = " when exactly one "=" is present
    eq_parts = string.split("=")
    if len(eq_parts) == 2 and len(eq_parts[0]) <= 2:
        string = eq_parts[1]

    # sqrt3 -> sqrt{3}
    string = _fix_sqrt(string)

    # remove all remaining spaces
    string = string.replace(" ", "")

    # \frac1b -> \frac{1}{b} etc.
    string = _fix_fracs(string)

    # special-case 0.5 -> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # a/b -> \frac{a}{b} for simple integer fractions
    return _fix_a_slash_b(string)
458 |
459 |
def replace_sqrt_with_power(s):
    """Rewrite every LaTeX ``\\sqrt{expr}`` as the Python expression ``(expr)**0.5``."""
    return re.sub(r"\\sqrt\{([^}]+)\}", lambda m: "(" + m.group(1) + ")**0.5", s)
463 |
464 |
465 | def replace_pi(s):
466 | # result = re.sub(r"(? 1:
487 | substrs = substrs[1:]
488 | for substr in substrs:
489 | new_str += "\\frac"
490 | if substr[0] == "{":
491 | new_str += substr
492 | else:
493 | try:
494 | assert len(substr) >= 2
495 | except:
496 | return string
497 | a = substr[0]
498 | b = substr[1]
499 | if b != "{":
500 | if len(substr) > 2:
501 | post_substr = substr[2:]
502 | new_str += "{" + a + "}{" + b + "}" + post_substr
503 | else:
504 | new_str += "{" + a + "}{" + b + "}"
505 | else:
506 | if len(substr) > 2:
507 | post_substr = substr[2:]
508 | new_str += "{" + a + "}" + b + post_substr
509 | else:
510 | new_str += "{" + a + "}" + b
511 | string = new_str
512 | return string
513 |
514 |
515 | def _fix_a_slash_b(string):
516 | if len(string.split("/")) != 2:
517 | return string
518 | a = string.split("/")[0]
519 | b = string.split("/")[1]
520 | try:
521 | a = int(a)
522 | b = int(b)
523 | assert string == "{}/{}".format(a, b)
524 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
525 | return new_string
526 | except:
527 | return string
528 |
529 |
530 | def _remove_right_units(string):
531 | # "\\text{ " only ever occurs (at least in the val set) when describing units
532 | if "\\text{ " in string:
533 | splits = string.split("\\text{ ")
534 | assert len(splits) == 2
535 | return splits[0]
536 | else:
537 | return string
538 |
539 |
540 | def _fix_sqrt(string):
541 | if "\\sqrt" not in string:
542 | return string
543 | splits = string.split("\\sqrt")
544 | new_string = splits[0]
545 | for split in splits[1:]:
546 | if split[0] != "{":
547 | a = split[0]
548 | new_substr = "\\sqrt{" + a + "}" + split[1:]
549 | else:
550 | new_substr = "\\sqrt" + split
551 | new_string += new_substr
552 | return new_string
553 |
554 |
def is_equiv(str1, str2, verbose=False):
    """Compare two answer strings for equivalence after LaTeX normalization.

    Both-None counts as equivalent (with a warning); one-sided None is not.
    If normalization fails for any reason, falls back to raw string equality.
    """
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        norm1 = _strip_string(str1)
        norm2 = _strip_string(str2)
    except:  # mirror the original's best-effort tolerance of any failure
        return str1 == str2
    if verbose:
        print(norm1, norm2)
    return norm1 == norm2
570 |
571 |
def last_boxed_only(sample):
    """
    Given a (q, a) sample, reduce the answer to just its final
    \\boxed{...} or \\fbox{...} element; returns None when no such
    element exists.
    """
    question, answer = sample
    boxed = last_boxed_only_string(answer)
    if boxed is None:
        return None
    return (question, boxed)
582 |
583 |
def last_boxed_only_string(string):
    """Return the last ``\\boxed{...}`` (or ``\\fbox{...}``) substring, braces
    balanced, or None when absent/unclosed."""
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    # walk forward tracking brace depth; the command's own text has no braces,
    # so depth first becomes positive at the opening "{"
    end = None
    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        if ch == "}":
            depth -= 1
            if depth == 0:
                end = pos
                break

    if end is None:
        return None
    return string[start:end + 1]
610 |
611 |
def remove_boxed(s):
    """Strip the ``\\boxed{`` prefix and closing ``}`` from *s*.

    Returns the inner content, or None when *s* is not a well-formed boxed
    expression (including non-string inputs such as None).
    """
    prefix = "\\boxed{"
    try:
        if s.startswith(prefix) and s.endswith("}"):
            return s[len(prefix):-1]
        return None
    except:  # non-string input (e.g. None) -> None, matching the original
        return None
620 |
621 |
def parse_answer(args, input_str):
    """Extract the model's final answer from its raw output text, per dataset.

    GSM8K / GSM-Hard : returns a float (or None) scraped from \\boxed{...},
                       falling back through progressively looser patterns.
    GPQA / *MMLU*    : returns a choice letter (expected "A"-"D"; may return
                       the sentinel "M" or None when unparseable).
    MATH / AIME_2024 : returns the raw contents of the last \\boxed{...}.
    Raises ValueError for any other dataset name.
    """
    solution = None
    if args.dataset in ["GSM8K", "GSM-Hard"]:
        # Pass 1: non-greedy boxed{...} contents, scanned from the last match
        # backwards; nested "boxed" recurses, otherwise keep only digits/./-
        pattern = r"boxed\{(.*?)\}"
        matches = re.findall(pattern, input_str)

        for match_str in matches[::-1]:
            # keep only the right-hand side of any "=" inside the box
            match_str = match_str.split("=")[-1]
            if "boxed" not in match_str:
                solution = re.sub(r"[^0-9.-]", "", match_str)
            else:
                solution = parse_answer(args, match_str)
            if solution:
                break

        if solution == None or solution == "":
            # Pass 2: greedy boxed{...} -- captures across inner braces
            pattern = r"boxed\{(.*)\}"
            matches = re.findall(pattern, input_str)

            for match_str in matches[::-1]:
                if "boxed" not in match_str:
                    solution = re.sub(r"[^0-9.-]", "", match_str)
                else:
                    solution = parse_answer(args, match_str)
                if solution:
                    break

        if solution == None or solution == "":
            # Pass 3: any brace group containing only number-like characters
            pattern = r"\{([0-9 \-.,$]*)\}"
            matches = re.findall(pattern, input_str)

            for match_str in matches[::-1]:
                solution = re.sub(r"[^0-9.-]", "", match_str)
                if solution:
                    break

        if solution == None or solution == "":
            # Pass 4: bold (**...**) spans
            pattern = r"\*\*(.*)\*\*"
            matches = re.findall(pattern, input_str)

            for match_str in matches[::-1]:
                solution = re.sub(r"[^0-9.-]", "", match_str)
                if solution:
                    break

        if solution == None or solution == "":
            # Pass 5: last number-like token anywhere in the text.
            # The indexing of solution[-1] is safe because the token is only
            # accepted when it contains at least one digit.
            matches = re.findall(r"[0-9\-.,$]+", input_str)
            for match_str in matches[::-1]:
                if re.findall(r"\d+", match_str) != []:
                    solution = re.sub(r"[^0-9.-]", "", match_str)
                    if solution[-1] == ".":
                        solution = solution[:-1]
                    break
        # normalize to float; anything unconvertible (None, "--", etc.) -> None
        try:
            solution = float(solution)
        except:
            solution = None
    elif args.dataset in ["GPQA"] or "MMLU" in args.dataset:
        # Multiple-choice: look for "correct answer is **X**" first.
        answers = re.findall(r"correct answer is \*\*(.*)\*\*", input_str)
        # NOTE(review): this inner condition duplicates the enclosing elif and
        # is therefore always true here.
        if args.dataset in ["GPQA"] or "MMLU" in args.dataset:
            letters = ["A", "B", "C", "D"]
            for answer in answers[::-1]:
                if answer[0] not in letters:
                    try:
                        # answer like "**(B) ...**" -> take the parenthesized letter
                        solution = re.search(r"\((.)\)", answer).group(1)[0]
                        if solution in letters:
                            return solution
                        else:
                            return None
                    except:
                        # "M" marks an unparseable "correct answer is" phrase
                        solution = "M"
                else:
                    solution = answer[0]
                    return solution

        # Fallback: "correct answer is X" without bold markers.
        # NOTE(review): "(.?)" can match an empty string, so answer[0] may
        # raise IndexError -- presumably "(.)" was intended; kept as-is.
        answers = re.findall(r"correct answer is (.?)", input_str)
        for answer in answers[::-1]:
            if answer[0] not in letters:
                try:
                    solution = re.search(r"correct answer is \((.?)\)", input_str).group(1)[0]
                    return solution
                except:
                    # NOTE(review): this rebinds the loop variable, not
                    # `solution` -- looks like a typo; kept as-is.
                    answer = "M"
            else:
                solution = answer[0]
                return solution
        # Fallback: any parenthesized single character, last occurrence first
        answers = re.findall(r"\((.)\)", input_str)
        for answer in answers[::-1]:
            if answer[0] in letters:
                solution = answer[0]
                return solution
        # Fallback: any braced single character.
        # NOTE(review): no break/return here, so the scan continues and the
        # letter closest to the start of the string wins; possibly the last
        # match was intended.
        answers = re.findall(r"\{(.)\}", input_str)
        for answer in answers[::-1]:
            if answer[0] in letters:
                solution = answer[0]
    elif args.dataset in ["MATH", "AIME_2024"]:
        # MATH-style: answer is exactly the content of the last \boxed{...}
        return remove_boxed(last_boxed_only_string(input_str))
    else:
        raise ValueError(f"{args.dataset} is not in def parse_answer!")
    return solution
722 |
723 |
def parse_best_solution(input_str):
    """Parse the index chosen by a ToT vote response.

    Prefers the last "index of the best solution is N" phrase; otherwise falls
    back to the last bold **N** number.  Returns the index as a string, or
    None when nothing matches.
    """
    primary = re.findall(r"index of the best solution is (\d+)", input_str)
    if primary:
        return primary[-1]

    for candidate in reversed(re.findall(r"\*\*(\d+)\*\*", input_str)):
        if candidate:
            return candidate
    return None
738 |
739 |
def parse_best_method(s):
    """Extract the method name following "most suitable method is ".

    Reads up to (but excluding) the first punctuation/newline terminator and
    strips surrounding whitespace.  Returns "" when the marker is absent.
    """
    marker = "most suitable method is "
    _, found, tail = s.partition(marker)
    if not found:
        return ""
    for offset, ch in enumerate(tail):
        if ch in ".,!?;:\n":
            return tail[:offset].strip()
    return tail.strip()
750 |
751 |
def check_solution_verdict(output_str):
    """Read a verifier response for a "solution is right/wrong" verdict.

    Returns "right" or "wrong" when the phrase is found; otherwise returns a
    coin-flip verdict prefixed with "random_" so callers can tell it apart.
    """
    found = re.search(r"solution is (right|wrong)", output_str)
    if found:
        return found.group(1)
    return "random_" + random.choice(["right", "wrong"])
761 |
--------------------------------------------------------------------------------