├── image.png ├── requirements.txt ├── LICENSE ├── scripts ├── Qwen_MATH.sh ├── Qwen_GSM8K.sh ├── GPT_GSM8K.sh └── Gemini_GSM8K.sh ├── prompts ├── AIME.py ├── GSM8K.py ├── GSM_Hard.py ├── MATH.py ├── GPQA.py └── MMLU.py ├── readme.md ├── model.py ├── main.py ├── eval_csv_cost.py ├── eval_csv_N.py └── dataset.py /image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MraDonkey/rethinking_prompting/HEAD/image.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opentelemetry-proto==1.26.0 2 | pyarrow==18.0.0 3 | multiprocess==0.70.16 4 | numpy==1.26.4 5 | datasets==3.2.0 6 | openai==1.53.0 7 | vllm==0.8.4 8 | protobuf==4.25.8 9 | google-generativeai==0.7.2 10 | openpyxl 11 | matplotlib 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 donkey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/Qwen_MATH.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --model_name Qwen/Qwen2.5-7B-Instruct \ 3 | --model_type vllm \ 4 | --split test \ 5 | --dataset MATH \ 6 | --reasoning DiP \ 7 | --shot 0 \ 8 | --batchsize 10 \ 9 | --range_begin 0 \ 10 | --range_end 32 \ 11 | --gpu 0,1 12 | 13 | python main.py \ 14 | --model_name Qwen/Qwen2.5-7B-Instruct \ 15 | --model_type vllm \ 16 | --split test \ 17 | --dataset MATH \ 18 | --reasoning CoT \ 19 | --shot 0 \ 20 | --batchsize 10 \ 21 | --range_begin 0 \ 22 | --range_end 32 \ 23 | --gpu 0,1 24 | 25 | python main.py \ 26 | --model_name Qwen/Qwen2.5-7B-Instruct \ 27 | --model_type vllm \ 28 | --split test \ 29 | --dataset MATH \ 30 | --reasoning L2M \ 31 | --shot 1 \ 32 | --batchsize 10 \ 33 | --range_begin 0 \ 34 | --range_end 32 \ 35 | --gpu 0,1 36 | 37 | python main.py \ 38 | --model_name Qwen/Qwen2.5-7B-Instruct \ 39 | --model_type vllm \ 40 | --split test \ 41 | --dataset MATH \ 42 | --reasoning SBP \ 43 | --shot 0 \ 44 | --batchsize 10 \ 45 | --range_begin 0 \ 46 | --range_end 32 \ 47 | --gpu 0,1 48 | 49 | python main.py \ 50 | --model_name Qwen/Qwen2.5-7B-Instruct \ 51 | --model_type vllm \ 52 | --split test \ 53 | --dataset MATH \ 54 | --reasoning AnP \ 55 | --shot 1 \ 56 | --batchsize 10 \ 57 | --range_begin 0 \ 58 | --range_end 32 \ 59 | --gpu 0,1 60 | 61 | python main.py \ 62 | --model_name Qwen/Qwen2.5-7B-Instruct \ 63 | --model_type vllm \ 64 | --split test \ 65 | --dataset MATH \ 66 | --reasoning ToT \ 67 | --shot 3 \ 68 | --batchsize 10 \ 69 | --range_begin 0 \ 70 | --range_end 5 \ 
71 | --gpu 0,1 72 | 73 | python main.py \ 74 | --model_name Qwen/Qwen2.5-7B-Instruct \ 75 | --model_type vllm \ 76 | --split test \ 77 | --dataset MATH \ 78 | --reasoning ToT \ 79 | --shot 5 \ 80 | --batchsize 10 \ 81 | --range_begin 0 \ 82 | --range_end 5 \ 83 | --gpu 0,1 84 | 85 | python main.py \ 86 | --model_name Qwen/Qwen2.5-7B-Instruct \ 87 | --model_type vllm \ 88 | --split test \ 89 | --dataset MATH \ 90 | --reasoning ToT \ 91 | --shot 10 \ 92 | --batchsize 10 \ 93 | --range_begin 0 \ 94 | --range_end 5 \ 95 | --gpu 0,1 96 | 97 | python main.py \ 98 | --model_name Qwen/Qwen2.5-7B-Instruct \ 99 | --model_type vllm \ 100 | --split test \ 101 | --dataset MATH \ 102 | --reasoning S-RF \ 103 | --shot 0 \ 104 | --batchsize 5 \ 105 | --range_begin 0 \ 106 | --range_end 1 \ 107 | --gpu 0,1 108 | 109 | python main.py \ 110 | --model_name Qwen/Qwen2.5-7B-Instruct \ 111 | --model_type vllm \ 112 | --split test \ 113 | --dataset MATH \ 114 | --reasoning MAD \ 115 | --shot 0 \ 116 | --batchsize 1 \ 117 | --range_begin 0 \ 118 | --range_end 1 \ 119 | --gpu 0,1 120 | -------------------------------------------------------------------------------- /scripts/Qwen_GSM8K.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --model_name Qwen/Qwen2.5-7B-Instruct \ 3 | --model_type vllm \ 4 | --split test \ 5 | --dataset GSM8K \ 6 | --reasoning DiP \ 7 | --shot 0 \ 8 | --batchsize 10 \ 9 | --range_begin 0 \ 10 | --range_end 32 \ 11 | --gpu 0,1 12 | 13 | python main.py \ 14 | --model_name Qwen/Qwen2.5-7B-Instruct \ 15 | --model_type vllm \ 16 | --split test \ 17 | --dataset GSM8K \ 18 | --reasoning CoT \ 19 | --shot 0 \ 20 | --batchsize 10 \ 21 | --range_begin 0 \ 22 | --range_end 32 \ 23 | --gpu 0,1 24 | 25 | python main.py \ 26 | --model_name Qwen/Qwen2.5-7B-Instruct \ 27 | --model_type vllm \ 28 | --split test \ 29 | --dataset GSM8K \ 30 | --reasoning L2M \ 31 | --shot 1 \ 32 | --batchsize 10 \ 33 | --range_begin 0 \ 34 | 
--range_end 32 \ 35 | --gpu 0,1 36 | 37 | python main.py \ 38 | --model_name Qwen/Qwen2.5-7B-Instruct \ 39 | --model_type vllm \ 40 | --split test \ 41 | --dataset GSM8K \ 42 | --reasoning SBP \ 43 | --shot 0 \ 44 | --batchsize 10 \ 45 | --range_begin 0 \ 46 | --range_end 32 \ 47 | --gpu 0,1 48 | 49 | python main.py \ 50 | --model_name Qwen/Qwen2.5-7B-Instruct \ 51 | --model_type vllm \ 52 | --split test \ 53 | --dataset GSM8K \ 54 | --reasoning AnP \ 55 | --shot 1 \ 56 | --batchsize 10 \ 57 | --range_begin 0 \ 58 | --range_end 32 \ 59 | --gpu 0,1 60 | 61 | python main.py \ 62 | --model_name Qwen/Qwen2.5-7B-Instruct \ 63 | --model_type vllm \ 64 | --split test \ 65 | --dataset GSM8K \ 66 | --reasoning ToT \ 67 | --shot 3 \ 68 | --batchsize 10 \ 69 | --range_begin 0 \ 70 | --range_end 5 \ 71 | --gpu 0,1 72 | 73 | python main.py \ 74 | --model_name Qwen/Qwen2.5-7B-Instruct \ 75 | --model_type vllm \ 76 | --split test \ 77 | --dataset GSM8K \ 78 | --reasoning ToT \ 79 | --shot 5 \ 80 | --batchsize 10 \ 81 | --range_begin 0 \ 82 | --range_end 5 \ 83 | --gpu 0,1 84 | 85 | python main.py \ 86 | --model_name Qwen/Qwen2.5-7B-Instruct \ 87 | --model_type vllm \ 88 | --split test \ 89 | --dataset GSM8K \ 90 | --reasoning ToT \ 91 | --shot 10 \ 92 | --batchsize 10 \ 93 | --range_begin 0 \ 94 | --range_end 5 \ 95 | --gpu 0,1 96 | 97 | python main.py \ 98 | --model_name Qwen/Qwen2.5-7B-Instruct \ 99 | --model_type vllm \ 100 | --split test \ 101 | --dataset GSM8K \ 102 | --reasoning S-RF \ 103 | --shot 0 \ 104 | --batchsize 5 \ 105 | --range_begin 0 \ 106 | --range_end 1 \ 107 | --gpu 0,1 108 | 109 | python main.py \ 110 | --model_name Qwen/Qwen2.5-7B-Instruct \ 111 | --model_type vllm \ 112 | --split test \ 113 | --dataset GSM8K \ 114 | --reasoning MAD \ 115 | --shot 0 \ 116 | --batchsize 1 \ 117 | --range_begin 0 \ 118 | --range_end 1 \ 119 | --gpu 0,1 120 | -------------------------------------------------------------------------------- /scripts/GPT_GSM8K.sh: 
-------------------------------------------------------------------------------- 1 | ## You can control the maximum concurrency level by adjusting the "max_num_workers" parameter. 2 | 3 | python main.py \ 4 | --model_name gpt-4o-mini \ 5 | --model_type openai \ 6 | --split test \ 7 | --dataset GSM8K \ 8 | --reasoning DiP \ 9 | --shot 0 \ 10 | --batchsize 10 \ 11 | --range_begin 0 \ 12 | --range_end 16 \ 13 | --max_num_workers 10 \ 14 | --openai_api_key your_api_key \ 15 | --openai_base_url your_base_url 16 | 17 | python main.py \ 18 | --model_name gpt-4o-mini \ 19 | --model_type openai \ 20 | --split test \ 21 | --dataset GSM8K \ 22 | --reasoning CoT \ 23 | --shot 0 \ 24 | --batchsize 10 \ 25 | --range_begin 0 \ 26 | --range_end 16 \ 27 | --max_num_workers 10 \ 28 | --openai_api_key your_api_key \ 29 | --openai_base_url your_base_url 30 | 31 | python main.py \ 32 | --model_name gpt-4o-mini \ 33 | --model_type openai \ 34 | --split test \ 35 | --dataset GSM8K \ 36 | --reasoning L2M \ 37 | --shot 1 \ 38 | --batchsize 10 \ 39 | --range_begin 0 \ 40 | --range_end 16 \ 41 | --max_num_workers 10 \ 42 | --openai_api_key your_api_key \ 43 | --openai_base_url your_base_url 44 | 45 | python main.py \ 46 | --model_name gpt-4o-mini \ 47 | --model_type openai \ 48 | --split test \ 49 | --dataset GSM8K \ 50 | --reasoning SBP \ 51 | --shot 0 \ 52 | --batchsize 10 \ 53 | --range_begin 0 \ 54 | --range_end 16 \ 55 | --max_num_workers 10 \ 56 | --openai_api_key your_api_key \ 57 | --openai_base_url your_base_url 58 | 59 | python main.py \ 60 | --model_name gpt-4o-mini \ 61 | --model_type openai \ 62 | --split test \ 63 | --dataset GSM8K \ 64 | --reasoning AnP \ 65 | --shot 1 \ 66 | --batchsize 10 \ 67 | --range_begin 0 \ 68 | --range_end 16 \ 69 | --max_num_workers 10 \ 70 | --openai_api_key your_api_key \ 71 | --openai_base_url your_base_url 72 | 73 | python main.py \ 74 | --model_name gpt-4o-mini \ 75 | --model_type openai \ 76 | --split test \ 77 | --dataset GSM8K \ 78 | 
--reasoning ToT \ 79 | --shot 3 \ 80 | --batchsize 10 \ 81 | --range_begin 0 \ 82 | --range_end 5 \ 83 | --max_num_workers 10 \ 84 | --openai_api_key your_api_key \ 85 | --openai_base_url your_base_url 86 | 87 | python main.py \ 88 | --model_name gpt-4o-mini \ 89 | --model_type openai \ 90 | --split test \ 91 | --dataset GSM8K \ 92 | --reasoning ToT \ 93 | --shot 5 \ 94 | --batchsize 10 \ 95 | --range_begin 0 \ 96 | --range_end 5 \ 97 | --max_num_workers 10 \ 98 | --openai_api_key your_api_key \ 99 | --openai_base_url your_base_url 100 | 101 | python main.py \ 102 | --model_name gpt-4o-mini \ 103 | --model_type openai \ 104 | --split test \ 105 | --dataset GSM8K \ 106 | --reasoning ToT \ 107 | --shot 10 \ 108 | --batchsize 10 \ 109 | --range_begin 0 \ 110 | --range_end 5 \ 111 | --max_num_workers 10 \ 112 | --openai_api_key your_api_key \ 113 | --openai_base_url your_base_url 114 | 115 | python main.py \ 116 | --model_name gpt-4o-mini \ 117 | --model_type openai \ 118 | --split test \ 119 | --dataset GSM8K \ 120 | --reasoning S-RF \ 121 | --shot 0 \ 122 | --batchsize 5 \ 123 | --range_begin 0 \ 124 | --range_end 1 \ 125 | --max_num_workers 10 \ 126 | --openai_api_key your_api_key \ 127 | --openai_base_url your_base_url 128 | 129 | python main.py \ 130 | --model_name gpt-4o-mini \ 131 | --model_type openai \ 132 | --split test \ 133 | --dataset GSM8K \ 134 | --reasoning MAD \ 135 | --shot 0 \ 136 | --batchsize 1 \ 137 | --range_begin 0 \ 138 | --range_end 1 \ 139 | --max_num_workers 1 \ 140 | --openai_api_key your_api_key \ 141 | --openai_base_url your_base_url 142 | -------------------------------------------------------------------------------- /prompts/AIME.py: -------------------------------------------------------------------------------- 1 | prompt_format = " Your final result should be in the form \\boxed{answer}, at the end of your response." 2 | 3 | directly_answer = "{question} You should only answer the final result with a single numerical number. 
Do not say other words." 4 | 5 | io = "{question}" + prompt_format 6 | 7 | io_briefly = io + " You should answer with no more than 200 words." 8 | 9 | cot_pre = "Please answer the given question." + prompt_format + '\n\n' 10 | 11 | cot_0_shot = cot_pre + '''Question: {question} 12 | Answer: Let's think step by step.''' 13 | 14 | # cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 15 | # Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}. 16 | 17 | # Question: {question} 18 | # Answer: 19 | # ''' 20 | 21 | Least_to_Most_0_shot = cot_pre + ''' In order to solve the question more conveniently and efficiently, break down the question into progressive sub-questions. Answer the sub-questions and get the final result according to sub-questions and their answers. 22 | 23 | Question: {question} 24 | Answer: 25 | ''' 26 | 27 | tot_post = ''' 28 | Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.''' 29 | 30 | tot_3_solutions = ''' 31 | Question: {question} 32 | 33 | Solution 1: {solution1} 34 | Solution 2: {solution2} 35 | Solution 3: {solution3}''' 36 | 37 | tot_5_solutions = tot_3_solutions + ''' 38 | Solution 4: {solution4} 39 | Solution 5: {solution5}''' 40 | 41 | tot_10_solutions = tot_5_solutions + ''' 42 | Solution 6: {solution6} 43 | Solution 7: {solution7} 44 | Solution 8: {solution8} 45 | Solution 9: {solution9} 46 | Solution 10: {solution10}''' 47 | 48 | anologous_1_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 
49 | 50 | # Initial Problem: 51 | {question} 52 | 53 | # Instructions: 54 | ## Relevant Problems: 55 | Recall an example of the math problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem: 56 | - After "Q: ", describe the problem. 57 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 58 | 59 | ## Solve the Initial Problem: 60 | Q: Copy and paste the initial problem here. 61 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 62 | ''' 63 | 64 | anologous_3_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 65 | 66 | # Initial Problem: 67 | {question} 68 | 69 | # Instructions: 70 | ## Relevant Problems: 71 | Recall three examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem: 72 | - After "Q: ", describe the problem. 73 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 74 | 75 | ## Solve the Initial Problem: 76 | Q: Copy and paste the initial problem here. 77 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 78 | ''' 79 | 80 | anologous_5_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 81 | 82 | # Initial Problem: 83 | {question} 84 | 85 | # Instructions: 86 | ## Relevant Problems: 87 | Recall five examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). 
For each problem: 88 | - After "Q: ", describe the problem. 89 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 90 | 91 | ## Solve the Initial Problem: 92 | Q: Copy and paste the initial problem here. 93 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 94 | ''' 95 | 96 | SBP_extract = '''You are an expert at mathematics. Your task is to extract the mathematics concepts and principles involved in solving the problem. 97 | Question: 98 | {question} 99 | 100 | Principles involved: 101 | ''' 102 | 103 | SBP_answer = "You are an expert at mathematics. You are given a mathematics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + prompt_format + ''' 104 | Question: 105 | {question} 106 | 107 | Principles: 108 | {principles} 109 | 110 | Answer: 111 | ''' 112 | 113 | -------------------------------------------------------------------------------- /scripts/Gemini_GSM8K.sh: -------------------------------------------------------------------------------- 1 | ## You can control the maximum concurrency level by adjusting the "max_num_workers" parameter. 
2 | 3 | python main.py \ 4 | --model_name gemini-1.5-flash \ 5 | --model_type gemini \ 6 | --split test \ 7 | --dataset GSM8K \ 8 | --reasoning DiP \ 9 | --shot 0 \ 10 | --batchsize 10 \ 11 | --range_begin 0 \ 12 | --range_end 8 \ 13 | --max_num_workers 10 \ 14 | --gemini_api_key your_api_key 15 | 16 | python main.py \ 17 | --model_name gemini-1.5-flash \ 18 | --model_type gemini \ 19 | --split test \ 20 | --dataset GSM8K \ 21 | --reasoning CoT \ 22 | --shot 0 \ 23 | --batchsize 10 \ 24 | --range_begin 0 \ 25 | --range_end 8 \ 26 | --max_num_workers 10 \ 27 | --gemini_api_key your_api_key 28 | 29 | python main.py \ 30 | --model_name gemini-1.5-flash \ 31 | --model_type gemini \ 32 | --split test \ 33 | --dataset GSM8K \ 34 | --reasoning L2M \ 35 | --shot 1 \ 36 | --batchsize 10 \ 37 | --range_begin 0 \ 38 | --range_end 8 \ 39 | --max_num_workers 10 \ 40 | --gemini_api_key your_api_key 41 | 42 | python main.py \ 43 | --model_name gemini-1.5-flash \ 44 | --model_type gemini \ 45 | --split test \ 46 | --dataset GSM8K \ 47 | --reasoning SBP \ 48 | --shot 0 \ 49 | --batchsize 10 \ 50 | --range_begin 0 \ 51 | --range_end 8 \ 52 | --max_num_workers 10 \ 53 | --gemini_api_key your_api_key 54 | 55 | python main.py \ 56 | --model_name gemini-1.5-flash \ 57 | --model_type gemini \ 58 | --split test \ 59 | --dataset GSM8K \ 60 | --reasoning AnP \ 61 | --shot 1 \ 62 | --batchsize 10 \ 63 | --range_begin 0 \ 64 | --range_end 8 \ 65 | --max_num_workers 10 \ 66 | --gemini_api_key your_api_key 67 | 68 | #### 69 | 70 | python main.py \ 71 | --model_name gemini-1.5-flash \ 72 | --model_type gemini \ 73 | --split test \ 74 | --dataset GSM8K \ 75 | --reasoning DiP \ 76 | --shot 0 \ 77 | --batchsize 10 \ 78 | --range_begin 8 \ 79 | --range_end 16 \ 80 | --max_num_workers 10 \ 81 | --gemini_api_key your_api_key 82 | 83 | python main.py \ 84 | --model_name gemini-1.5-flash \ 85 | --model_type gemini \ 86 | --split test \ 87 | --dataset GSM8K \ 88 | --reasoning CoT \ 89 | --shot 0 \ 90 | 
--batchsize 10 \ 91 | --range_begin 8 \ 92 | --range_end 16 \ 93 | --max_num_workers 10 \ 94 | --gemini_api_key your_api_key 95 | 96 | python main.py \ 97 | --model_name gemini-1.5-flash \ 98 | --model_type gemini \ 99 | --split test \ 100 | --dataset GSM8K \ 101 | --reasoning L2M \ 102 | --shot 1 \ 103 | --batchsize 10 \ 104 | --range_begin 8 \ 105 | --range_end 16 \ 106 | --max_num_workers 10 \ 107 | --gemini_api_key your_api_key 108 | 109 | python main.py \ 110 | --model_name gemini-1.5-flash \ 111 | --model_type gemini \ 112 | --split test \ 113 | --dataset GSM8K \ 114 | --reasoning SBP \ 115 | --shot 0 \ 116 | --batchsize 10 \ 117 | --range_begin 8 \ 118 | --range_end 16 \ 119 | --max_num_workers 10 \ 120 | --gemini_api_key your_api_key 121 | 122 | python main.py \ 123 | --model_name gemini-1.5-flash \ 124 | --model_type gemini \ 125 | --split test \ 126 | --dataset GSM8K \ 127 | --reasoning AnP \ 128 | --shot 1 \ 129 | --batchsize 10 \ 130 | --range_begin 8 \ 131 | --range_end 16 \ 132 | --max_num_workers 10 \ 133 | --gemini_api_key your_api_key 134 | 135 | 136 | 137 | python main.py \ 138 | --model_name gemini-1.5-flash \ 139 | --model_type gemini \ 140 | --split test \ 141 | --dataset GSM8K \ 142 | --reasoning ToT \ 143 | --shot 3 \ 144 | --batchsize 10 \ 145 | --range_begin 0 \ 146 | --range_end 5 \ 147 | --max_num_workers 10 \ 148 | --gemini_api_key your_api_key 149 | 150 | python main.py \ 151 | --model_name gemini-1.5-flash \ 152 | --model_type gemini \ 153 | --split test \ 154 | --dataset GSM8K \ 155 | --reasoning ToT \ 156 | --shot 5 \ 157 | --batchsize 10 \ 158 | --range_begin 0 \ 159 | --range_end 5 \ 160 | --max_num_workers 10 \ 161 | --gemini_api_key your_api_key 162 | 163 | python main.py \ 164 | --model_name gemini-1.5-flash \ 165 | --model_type gemini \ 166 | --split test \ 167 | --dataset GSM8K \ 168 | --reasoning ToT \ 169 | --shot 10 \ 170 | --batchsize 10 \ 171 | --range_begin 0 \ 172 | --range_end 5 \ 173 | --max_num_workers 10 \ 174 | 
--gemini_api_key your_api_key 175 | 176 | python main.py \ 177 | --model_name gemini-1.5-flash \ 178 | --model_type gemini \ 179 | --split test \ 180 | --dataset GSM8K \ 181 | --reasoning S-RF \ 182 | --shot 0 \ 183 | --batchsize 5 \ 184 | --range_begin 0 \ 185 | --range_end 1 \ 186 | --max_num_workers 10 \ 187 | --gemini_api_key your_api_key 188 | 189 | python main.py \ 190 | --model_name gemini-1.5-flash \ 191 | --model_type gemini \ 192 | --split test \ 193 | --dataset GSM8K \ 194 | --reasoning MAD \ 195 | --shot 0 \ 196 | --batchsize 1 \ 197 | --range_begin 0 \ 198 | --range_end 1 \ 199 | --max_num_workers 1 \ 200 | --gemini_api_key your_api_key 201 | -------------------------------------------------------------------------------- /prompts/GSM8K.py: -------------------------------------------------------------------------------- 1 | prompt_format = " Your final answer should be a single numerical number, in the form \\boxed{answer}, at the end of your response." 2 | 3 | directly_answer = "{question} You should only answer the final result with a single numerical number. Do not say other words." 4 | 5 | io = "{question}" + prompt_format 6 | 7 | io_briefly = io + " You should answer with no more than 200 words." 8 | 9 | cot_pre = "Please answer the given question." + prompt_format + '\n\n' 10 | 11 | cot_0_shot = cot_pre + '''Question: {question} Let's think step by step. 12 | Answer:''' 13 | 14 | # cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 15 | # Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}. 16 | 17 | # Question: {question} 18 | # Answer: 19 | # ''' 20 | 21 | cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. 
How many clips did Natalia sell altogether in April and May? 22 | Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}. 23 | 24 | Question: {question} 25 | Answer: 26 | ''' 27 | 28 | cot_5_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 29 | Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}. 30 | 31 | Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn? 32 | Answer: Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute. Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10. The answer is \\boxed{10}. 33 | 34 | Question: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?" 35 | Answer: In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50. Betty's grandparents gave her 15 * 2 = $<<15*2=30>>30. This means, Betty needs 100 - 50 - 30 - 15 = $<<100-50-30-15=5>>5 more. The answer is \\boxed{5}. 36 | 37 | Question: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read? 38 | Answer: Maila read 12 x 2 = <<12*2=24>>24 pages today. So she was able to read a total of 12 + 24 = <<12+24=36>>36 pages since yesterday. There are 120 - 36 = <<120-36=84>>84 pages left to be read. Since she wants to read half of the remaining pages tomorrow, then she should read 84/2 = <<84/2=42>>42 pages. 
The answer is \\boxed{42}. 39 | 40 | Question: James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year? 41 | Answer: He writes each friend 3*2=<<3*2=6>>6 pages a week. So he writes 6*2=<<6*2=12>>12 pages every week. That means he writes 12*52=<<12*52=624>>624 pages a year. The answer is \\boxed{624}. 42 | 43 | Question: {question} 44 | Answer: 45 | ''' 46 | 47 | Least_to_Most_1_shot = cot_pre + '''Question: Elsa has 5 apples. Anna has 2 more apples than Elsa. How many apples do they have together? 48 | Answer: Let's break down this problem: 1. How many apples does Anna have? 2. How many apples do they have together? 49 | 1. Anna has 2 more apples than Elsa. So Anna has 2 + 5 = 7 apples. 50 | 2. Elsa and Anna have 5 + 7 = 12 apples together. 51 | The answer is: \\boxed{12}. 52 | 53 | Question: {question} 54 | Answer: 55 | ''' 56 | 57 | tot_post = ''' 58 | Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.''' 59 | 60 | tot_3_solutions = ''' 61 | Question: {question} 62 | 63 | Solution 1: {solution1} 64 | Solution 2: {solution2} 65 | Solution 3: {solution3}''' 66 | 67 | tot_5_solutions = tot_3_solutions + ''' 68 | Solution 4: {solution4} 69 | Solution 5: {solution5}''' 70 | 71 | tot_10_solutions = tot_5_solutions + ''' 72 | Solution 6: {solution6} 73 | Solution 7: {solution7} 74 | Solution 8: {solution8} 75 | Solution 9: {solution9} 76 | Solution 10: {solution10}''' 77 | 78 | anologous_1_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 
79 | 80 | # Initial Problem: 81 | {question} 82 | 83 | # Instructions: 84 | ## Relevant Problems: 85 | Recall an example of the math problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem: 86 | - After "Q: ", describe the problem. 87 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 88 | 89 | ## Solve the Initial Problem: 90 | Q: Copy and paste the initial problem here. 91 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 92 | ''' 93 | 94 | anologous_3_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 95 | 96 | # Initial Problem: 97 | {question} 98 | 99 | # Instructions: 100 | ## Relevant Problems: 101 | Recall three examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem: 102 | - After "Q: ", describe the problem. 103 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 104 | 105 | ## Solve the Initial Problem: 106 | Q: Copy and paste the initial problem here. 107 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 108 | ''' 109 | 110 | anologous_5_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 111 | 112 | # Initial Problem: 113 | {question} 114 | 115 | # Instructions: 116 | ## Relevant Problems: 117 | Recall five examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). 
For each problem: 118 | - After "Q: ", describe the problem. 119 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 120 | 121 | ## Solve the Initial Problem: 122 | Q: Copy and paste the initial problem here. 123 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 124 | ''' 125 | 126 | SBP_extract = '''You are an expert at mathematics. Your task is to extract the mathematics concepts and principles involved in solving the problem. 127 | Question: 128 | {question} 129 | 130 | Principles involved: 131 | ''' 132 | 133 | SBP_answer = "You are an expert at mathematics. You are given a mathematics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + prompt_format + ''' 134 | Question: 135 | {question} 136 | 137 | Principles: 138 | {principles} 139 | 140 | Answer: 141 | ''' 142 | 143 | -------------------------------------------------------------------------------- /prompts/GSM_Hard.py: -------------------------------------------------------------------------------- 1 | prompt_format = " The given information may not conform to common sense and the result may be a nonsense decimal or negative number, it's okay, output it instead of considering it is unreasonable. Your final answer should be a single numerical number, in the form \\boxed{answer}, at the end of your response." 2 | 3 | directly_answer = "{question} You should only answer the final result with a single numerical number. Do not say other words." 4 | 5 | io = "{question}" + prompt_format 6 | 7 | io_briefly = io + " You should answer with no more than 200 words." 8 | 9 | cot_pre = "Please answer the given question." + prompt_format + '\n\n' 10 | 11 | cot_0_shot = cot_pre + '''Question: {question} Let's think step by step. 12 | Answer:''' 13 | 14 | # cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. 
How many clips did Natalia sell altogether in April and May? 15 | # Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}. 16 | 17 | # Question: {question} 18 | # Answer: 19 | # ''' 20 | 21 | cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48564 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 22 | Answer: Natalia sold 48564/2 = <<48564/2=24282>>24282 clips in May. Natalia sold 48564+24282 = <<48564+24282=72846>>72846 clips altogether in April and May. The answer is \\boxed{72846}. 23 | 24 | Question: {question} 25 | Answer: 26 | ''' 27 | 28 | cot_5_shot = cot_pre + '''Question: Natalia sold clips to 48564 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 29 | Answer: Natalia sold 48564/2 = <<48564/2=24282>>24282 clips in May. Natalia sold 48564+24282 = <<48564+24282=72846>>72846 clips altogether in April and May. The answer is \\boxed{72846}. 30 | 31 | Question: Weng earns $1293 an hour for babysitting. Yesterday, she just did 612 minutes of babysitting. How much did she earn? 32 | Answer: Weng earns 1293/60 = $<<1293/60=21.55>>21.55 per minute. Working 612 minutes, she earned 21.55 x 612 = $<<21.55*612=13188.6>>13188.6. The answer is \\boxed{13188.6}. 33 | 34 | Question: Betty is saving money for a new wallet which costs $8200. Betty has only half of the money she needs. Her parents decided to give her $1525 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?" 35 | Answer: In the beginning, Betty has only 8200 / 2 = $<<8200/2=4100>>4100. Betty's grandparents gave her 1525 * 2 = $<<1525*2=3050>>3050. This means, Betty needs 8200 - 4100 - 3050 - 1525 = $<<8200-4100-3050-1525=-475>>-475 more. The answer is \\boxed{-475}. 
36 | 37 | Question: Julie is reading a 12602-page book. Yesterday, she was able to read 3127 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read? 38 | Answer: Maila read 3127 x 2 = <<3127*2=6254>>6254 pages today. So she was able to read a total of 3127 + 6254 = <<3127+6254=9381>>9381 pages since yesterday. There are 12602 - 9381 = <<12602-9381=3221>>3221 pages left to be read. Since she wants to read half of the remaining pages tomorrow, then she should read 3221/2 = <<3221/2=1610.5>>1610.5 pages. The answer is \\boxed{1610.5}. 39 | 40 | Question: James writes a 312996-page letter to 2143 different friends twice a week. How many pages does he write a year? 41 | Answer: He writes each friend 312996*2143=<<312996*2143=670750428>>670750428 pages a week. So he writes 670750428*2=<<670750428*2=1341500856>>1341500856 pages every week. That means he writes 1341500856*52=<<1341500856*52=69758044512>>69758044512 pages a year. The answer is \\boxed{69758044512}. 42 | 43 | Question: {question} 44 | Answer: 45 | ''' 46 | 47 | Least_to_Most_1_shot = cot_pre + '''Question: Elsa has 524866 apples. Anna has 432343 more apples than Elsa. How many apples do they have together? 48 | Answer: Let's break down this problem: 1. How many apples does Anna have? 2. How many apples do they have together? 49 | 1. Anna has 432343 more apples than Elsa. So Anna has 524866 + 432343 = 957209 apples. 50 | 2. Elsa and Anna have 524866 + 957209 = 1482075 apples together. 51 | The answer is: \\boxed{1482075}. 52 | 53 | Question: {question} 54 | Answer: 55 | ''' 56 | 57 | tot_post = ''' 58 | Given the question and several solutions, decide which solution is the most promising. 
Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.''' 59 | 60 | tot_3_solutions = ''' 61 | Question: {question} 62 | 63 | Solution 1: {solution1} 64 | Solution 2: {solution2} 65 | Solution 3: {solution3}''' 66 | 67 | tot_5_solutions = tot_3_solutions + ''' 68 | Solution 4: {solution4} 69 | Solution 5: {solution5}''' 70 | 71 | tot_10_solutions = tot_5_solutions + ''' 72 | Solution 6: {solution6} 73 | Solution 7: {solution7} 74 | Solution 8: {solution8} 75 | Solution 9: {solution9} 76 | Solution 10: {solution10}''' 77 | 78 | anologous_1_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 79 | 80 | # Initial Problem: 81 | {question} 82 | 83 | # Instructions: 84 | ## Relevant Problems: 85 | Recall an example of the math problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem: 86 | - After "Q: ", describe the problem. 87 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 88 | 89 | ## Solve the Initial Problem: 90 | Q: Copy and paste the initial problem here. 91 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 92 | ''' 93 | 94 | anologous_3_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 95 | 96 | # Initial Problem: 97 | {question} 98 | 99 | # Instructions: 100 | ## Relevant Problems: 101 | Recall three examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). 
For each problem: 102 | - After "Q: ", describe the problem. 103 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 104 | 105 | ## Solve the Initial Problem: 106 | Q: Copy and paste the initial problem here. 107 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 108 | ''' 109 | 110 | anologous_5_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 111 | 112 | # Initial Problem: 113 | {question} 114 | 115 | # Instructions: 116 | ## Relevant Problems: 117 | Recall five examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem: 118 | - After "Q: ", describe the problem. 119 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 120 | 121 | ## Solve the Initial Problem: 122 | Q: Copy and paste the initial problem here. 123 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 124 | ''' 125 | 126 | SBP_extract = '''You are an expert at mathematics. Your task is to extract the mathematics concepts and principles involved in solving the problem. 127 | Question: 128 | {question} 129 | 130 | Principles involved: 131 | ''' 132 | 133 | SBP_answer = "You are an expert at mathematics. You are given a mathematics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." 
+ prompt_format + ''' 134 | Question: 135 | {question} 136 | 137 | Principles: 138 | {principles} 139 | 140 | Answer: 141 | ''' 142 | 143 | -------------------------------------------------------------------------------- /prompts/MATH.py: -------------------------------------------------------------------------------- 1 | prompt_format = " Your final result should be in the form \\boxed{answer}, at the end of your response." 2 | 3 | directly_answer = "{question} You should only answer the final result with a single numerical number. Do not say other words." 4 | 5 | io = "{question}" + prompt_format 6 | 7 | io_briefly = io + " You should answer with no more than 200 words." 8 | 9 | cot_pre = "Please answer the given question." + prompt_format + '\n\n' 10 | 11 | cot_0_shot = cot_pre + '''Question: {question} 12 | Answer: Let's think step by step.''' 13 | 14 | # cot_1_shot = cot_pre + '''Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 15 | # Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. The answer is \\boxed{72}. 16 | 17 | # Question: {question} 18 | # Answer: 19 | # ''' 20 | 21 | cot_1_shot = cot_pre + '''Question: Ryan has 3 red lava lamps and 3 blue lava lamps. He arranges them in a row on a shelf randomly, then turns 3 random lamps on. What is the probability that the leftmost lamp on the shelf is red, and the leftmost lamp which is turned on is also red? 22 | Answer: There are $\\binom{6}{3}=20$ ways for Ryan to arrange the lamps, and $\\binom{6}{3}=20$ ways for him to choose which lamps are on, giving $20\\cdot20=400$ total possible outcomes. There are two cases for the desired outcomes: either the left lamp is on, or it isn't. 
If the left lamp is on, there are $\\binom{5}{2}=10$ ways to choose which other lamps are on, and $\\binom{5}{2}=10$ ways to choose which other lamps are red. This gives $10\\cdot10=100$ possibilities. If the first lamp isn't on, there are $\\binom{5}{3}=10$ ways to choose which lamps are on, and since both the leftmost lamp and the leftmost lit lamp must be red, there are $\\binom{4}{1}=4$ ways to choose which other lamp is red. This case gives 40 valid possibilities, for a total of 140 valid arrangements out of 400. Therefore, the probability is $\\dfrac{140}{400}=\\boxed{\\dfrac{7}{20}}$. 23 | 24 | Question: {question} 25 | Answer: 26 | ''' 27 | 28 | cot_5_shot = cot_pre + '''Question: Ryan has 3 red lava lamps and 3 blue lava lamps. He arranges them in a row on a shelf randomly, then turns 3 random lamps on. What is the probability that the leftmost lamp on the shelf is red, and the leftmost lamp which is turned on is also red? 29 | Answer: There are $\\binom{6}{3}=20$ ways for Ryan to arrange the lamps, and $\\binom{6}{3}=20$ ways for him to choose which lamps are on, giving $20\\cdot20=400$ total possible outcomes. There are two cases for the desired outcomes: either the left lamp is on, or it isn't. If the left lamp is on, there are $\\binom{5}{2}=10$ ways to choose which other lamps are on, and $\\binom{5}{2}=10$ ways to choose which other lamps are red. This gives $10\\cdot10=100$ possibilities. If the first lamp isn't on, there are $\\binom{5}{3}=10$ ways to choose which lamps are on, and since both the leftmost lamp and the leftmost lit lamp must be red, there are $\\binom{4}{1}=4$ ways to choose which other lamp is red. This case gives 40 valid possibilities, for a total of 140 valid arrangements out of 400. Therefore, the probability is $\\dfrac{140}{400}=\\boxed{\\dfrac{7}{20}}$. 30 | 31 | Question: On the $xy$-plane, the origin is labeled with an $M$. The points $(1,0)$, $(-1,0)$, $(0,1)$, and $(0,-1)$ are labeled with $A$'s. 
The points $(2,0)$, $(1,1)$, $(0,2)$, $(-1, 1)$, $(-2, 0)$, $(-1, -1)$, $(0, -2)$, and $(1, -1)$ are labeled with $T$'s. The points $(3,0)$, $(2,1)$, $(1,2)$, $(0, 3)$, $(-1, 2)$, $(-2, 1)$, $(-3, 0)$, $(-2,-1)$, $(-1,-2)$, $(0, -3)$, $(1, -2)$, and $(2, -1)$ are labeled with $H$'s. If you are only allowed to move up, down, left, and right, starting from the origin, how many distinct paths can be followed to spell the word MATH? 32 | Answer: From the M, we can proceed to four different As. Note that the letters are all symmetric, so we can simply count one case (say, that of moving from M to the bottom A) and then multiply by four.\n\nFrom the bottom A, we can proceed to any one of three Ts. From the two Ts to the sides of the A, we can proceed to one of two Hs. From the T that is below the A, we can proceed to one of three Hs. Thus, this case yields $2 \\cdot 2 + 3 = 7$ paths.\n\nThus, there are $4 \\cdot 7 = \\boxed{28}$ distinct paths. 33 | 34 | Question: Factor the following expression: $55z^{17}+121z^{34}$. 35 | Answer: The greatest common factor of the two coefficients is $11$, and the greatest power of $z$ that divides both terms is $z^{17}$. So, we factor $11z^{17}$ out of both terms:\n\n\\begin{align*}\n55z^{17}+121z^{34} &= 11z^{17}\\cdot 5 +11z^{17}\\cdot 11z^{17}\\\\\n&= \\boxed{11z^{17}(5+11z^{17})}\n\\end{align*}. 36 | 37 | Question: Allen and Ben are painting a fence. The ratio of the amount of work Allen does to the amount of work Ben does is $3:5$. If the fence requires a total of $240$ square feet to be painted, how many square feet does Ben paint? 38 | Answer: Between them, Allen and Ben are dividing the work into $8$ equal parts, $3$ of which Allen does and $5$ of which Ben does. Each part of the work requires $\\frac{240}{8} = 30$ square feet to be painted. Since Ben does $5$ parts of the work, he will paint $30 \\cdot 5 = \\boxed{150}$ square feet of the fence. 
39 | 40 | Question: Suppose $z$ and $w$ are complex numbers such that\n\\[|z| = |w| = z \\overline{w} + \\overline{z} w= 1.\\]Find the largest possible value of the real part of $z + w.$ 41 | Answer: Let $z = a + bi$ and $w = c + di,$ where $a,$ $b,$ $c,$ and $d$ are complex numbers. Then from $|z| = 1,$ $a^2 + b^2 = 1,$ and from $|w| = 1,$ $c^2 + d^2 = 1.$ Also, from $z \\overline{w} + \\overline{z} w = 1,$\n\\[(a + bi)(c - di) + (a - bi)(c + di) = 1,\\]so $2ac + 2bd = 1.$\n\nThen\n\\begin{align*}\n(a + c)^2 + (b + d)^2 &= a^2 + 2ac + c^2 + b^2 + 2bd + d^2 \\\\\n&= (a^2 + b^2) + (c^2 + d^2) + (2ac + 2bd) \\\\\n&= 3.\n\\end{align*}The real part of $z + w$ is $a + c,$ which can be at most $\\sqrt{3}.$ Equality occurs when $z = \\frac{\\sqrt{3}}{2} + \\frac{1}{2} i$ and $w = \\frac{\\sqrt{3}}{2} - \\frac{1}{2} i,$ so the largest possible value of $a + c$ is $\\boxed{\\sqrt{3}}.$ 42 | 43 | Question: {question} 44 | Answer: 45 | ''' 46 | 47 | Least_to_Most_0_shot = cot_pre + ''' In order to solve the question more conveniently and efficiently, break down the question into progressive sub-questions. Answer the sub-questions and get the final result according to sub-questions and their answers. 48 | 49 | Question: {question} 50 | Answer: 51 | ''' 52 | 53 | tot_post = ''' 54 | Given the question and several solutions, decide which solution is the most promising. 
Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.''' 55 | 56 | tot_3_solutions = ''' 57 | Question: {question} 58 | 59 | Solution 1: {solution1} 60 | Solution 2: {solution2} 61 | Solution 3: {solution3}''' 62 | 63 | tot_5_solutions = tot_3_solutions + ''' 64 | Solution 4: {solution4} 65 | Solution 5: {solution5}''' 66 | 67 | tot_10_solutions = tot_5_solutions + ''' 68 | Solution 6: {solution6} 69 | Solution 7: {solution7} 70 | Solution 8: {solution8} 71 | Solution 9: {solution9} 72 | Solution 10: {solution10}''' 73 | 74 | anologous_1_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 75 | 76 | # Initial Problem: 77 | {question} 78 | 79 | # Instructions: 80 | ## Relevant Problems: 81 | Recall an example of the math problem that is relevant to the initial problem. Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem: 82 | - After "Q: ", describe the problem. 83 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 84 | 85 | ## Solve the Initial Problem: 86 | Q: Copy and paste the initial problem here. 87 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 88 | ''' 89 | 90 | anologous_3_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 91 | 92 | # Initial Problem: 93 | {question} 94 | 95 | # Instructions: 96 | ## Relevant Problems: 97 | Recall three examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). 
For each problem: 98 | - After "Q: ", describe the problem. 99 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 100 | 101 | ## Solve the Initial Problem: 102 | Q: Copy and paste the initial problem here. 103 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 104 | ''' 105 | 106 | anologous_5_prompt = '''Your task is to tackle mathematical problems. When presented with a math problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 107 | 108 | # Initial Problem: 109 | {question} 110 | 111 | # Instructions: 112 | ## Relevant Problems: 113 | Recall five examples of math problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem: 114 | - After "Q: ", describe the problem. 115 | - After "A: ", explain the solution and enclose the ultimate answer in \\boxed{}. 116 | 117 | ## Solve the Initial Problem: 118 | Q: Copy and paste the initial problem here. 119 | A: Explain the solution and enclose the ultimate answer in \\boxed{} here. 120 | ''' 121 | 122 | SBP_extract = '''You are an expert at mathematics. Your task is to extract the mathematics concepts and principles involved in solving the problem. 123 | Question: 124 | {question} 125 | 126 | Principles involved: 127 | ''' 128 | 129 | SBP_answer = "You are an expert at mathematics. You are given a mathematics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + prompt_format + ''' 130 | Question: 131 | {question} 132 | 133 | Principles: 134 | {principles} 135 | 136 | Answer: 137 | ''' 138 | 139 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |
2 |

Rethinking the Role of Prompting Strategies in LLM Test-Time Scaling: A Perspective of Probability Theory

3 | 4 |
5 | Yexiang Liu1,2  6 | Zekun Li3  7 | Zhi Fang1,2  8 | Nan Xu1,4  9 | Ran He1,2*  10 | Tieniu Tan1,2,5  11 |
12 |
13 | 1MAIS, Institute of Automation, Chinese Academy of Sciences 
14 | 2School of Artificial Intelligence, University of Chinese Academy of Sciences 
15 | 3University of California, Santa Barbara 
16 | 4Beijing Wenge Technology Co., Ltd  5Nanjing University
17 | *Corresponding Author 18 |
19 |
20 |
21 |
22 | ACL 2025 Main 🏆 Outstanding Paper Award 23 |
24 |
25 | 26 | [![Conference](https://img.shields.io/badge/ACL_2025-Outstanding%20Paper%20Award-ed1c24?style=flat&logo=data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiIHN0YW5kYWxvbmU9Im5vIj8+CjwhLS0gQ3JlYXRlZCB3aXRoIElua3NjYXBlIChodHRwOi8vd3d3Lmlua3NjYXBlLm9yZy8pIC0tPgo8c3ZnCiAgIHhtbG5zOnN2Zz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciCiAgIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIKICAgdmVyc2lvbj0iMS4wIgogICB3aWR0aD0iNjgiCiAgIGhlaWdodD0iNjgiCiAgIGlkPSJzdmcyIj4KICA8ZGVmcwogICAgIGlkPSJkZWZzNCIgLz4KICA8cGF0aAogICAgIGQ9Ik0gNDEuOTc3NTUzLC0yLjg0MjE3MDllLTAxNCBDIDQxLjk3NzU1MywxLjc2MTc4IDQxLjk3NzU1MywxLjQ0MjExIDQxLjk3NzU1MywzLjAxNTggTCA3LjQ4NjkwNTQsMy4wMTU4IEwgMCwzLjAxNTggTCAwLDEwLjUwMDc5IEwgMCwzOC40Nzg2NyBMIDAsNDYgTCA3LjQ4NjkwNTQsNDYgTCA0OS41MDA4MDIsNDYgTCA1Ni45ODc3MDgsNDYgTCA2OCw0NiBMIDY4LDMwLjk5MzY4IEwgNTYuOTg3NzA4LDMwLjk5MzY4IEwgNTYuOTg3NzA4LDEwLjUwMDc5IEwgNTYuOTg3NzA4LDMuMDE1OCBDIDU2Ljk4NzcwOCwxLjQ0MjExIDU2Ljk4NzcwOCwxLjc2MTc4IDU2Ljk4NzcwOCwtMi44NDIxNzA5ZS0wMTQgTCA0MS45Nzc1NTMsLTIuODQyMTcwOWUtMDE0IHogTSAxNS4wMTAxNTUsMTcuOTg1NzggTCA0MS45Nzc1NTMsMTcuOTg1NzggTCA0MS45Nzc1NTMsMzAuOTkzNjggTCAxNS4wMTAxNTUsMzAuOTkzNjggTCAxNS4wMTAxNTUsMTcuOTg1NzggeiAiCiAgICAgc3R5bGU9ImZpbGw6I2VkMWMyNDtmaWxsLW9wYWNpdHk6MTtmaWxsLXJ1bGU6ZXZlbm9kZDtzdHJva2U6bm9uZTtzdHJva2Utd2lkdGg6MTIuODk1NDExNDk7c3Ryb2tlLWxpbmVjYXA6YnV0dDtzdHJva2UtbGluZWpvaW46bWl0ZXI7c3Ryb2tlLW1pdGVybGltaXQ6NDtzdHJva2UtZGFzaGFycmF5Om5vbmU7c3Ryb2tlLWRhc2hvZmZzZXQ6MDtzdHJva2Utb3BhY2l0eToxIgogICAgIHRyYW5zZm9ybT0idHJhbnNsYXRlKDAsIDExKSIKICAgICBpZD0icmVjdDIxNzgiIC8+Cjwvc3ZnPgo=)](https://aclanthology.org/2025.acl-long.1356/) [![arXiv](https://img.shields.io/badge/arXiv%20paper-2505.10981-b31b1b.svg)](https://arxiv.org/abs/2505.10981) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) 27 |
28 | 29 | ## 📑 Brief Introduction 30 | 31 | ### Abstract 32 | Recently, scaling test-time compute on Large Language Models (LLM) has garnered wide attention. However, there has been limited investigation of how various reasoning prompting strategies perform as scaling. In this paper, we focus on a standard and realistic scaling setting: majority voting. We systematically conduct experiments on 6 LLMs $\times$ 8 prompting strategies $\times$ 6 benchmarks. Experiment results consistently show that as the sampling time and computational overhead increase, complicated prompting strategies with superior initial performance gradually fall behind simple Chain-of-Thought. We analyze this phenomenon and provide theoretical proofs. Additionally, we propose a probabilistic method to efficiently predict scaling performance and identify the best prompting strategy under large sampling times, eliminating the need for resource-intensive inference processes in practical applications. Furthermore, we introduce two ways derived from our theoretical analysis to significantly improve the scaling performance. We hope that our research can promote to re-examine the role of complicated prompting, unleash the potential of simple prompting strategies, and provide new insights for enhancing test-time scaling performance. 33 | 34 | ### Contributions 35 | 36 | 1. **Comprehensive experiments.** Our study covers a wide range - 6 LLMs $\times$ 8 prompting strategies $\times$ 6 benchmarks, providing sufficient evidence and context to fully support the claim. 37 | 2. **Valuable findings breaking the conventional wisdom.** Our extensive experiments consistently demonstrate that a complex prompting strategy with higher pass@1 accuracy may not always be better as test-time scaling, while simple CoT/DiP gradually dominates even if with an initial inferior performance. 38 | 3. 
**Rigorous theoretical analysis.** We provide an in-depth, probability-theory-backed explanation of what leads to more rapid improvements with scale. 39 | - 3.1 **Definition of easy and hard questions by answer distribution.** The difficulty of the question is not only related to pass@1 accuracy, but determined by the probability distribution of all possible answer outputs. The accuracy on easy questions increases with scaling, while the accuracy on hard questions decreases. 40 | - 3.2 **Disturbed peaks of wrong answer distribution.** Scaling performance is affected by the enormous answer distribution, and we quantify this with our theory. 41 | 4. **Practical $O(1)$ approach to predict scaling performance without resource-intensive inference**. 42 | 5. **Two effective and general methods to significantly improve scaling performance verified on multiple models and datasets.** Combining the two methods leads to much larger improvements, e.g., improving Majority@10 accuracy from 15.2% to 61.0% with LLaMA-3-8B-Instruct on MATH-500. 43 | - 5.1 **Adaptively scaling based on the question difficulty.** 44 | - 5.2 **Dynamically selecting the optimal prompting strategy based on our theory.** 45 | 46 | 47 | ## 🔍 Features 48 | 49 | - Support for multiple LLM backends (VLLM, Gemini, OpenAI and other API-based models) 50 | - You can specify **any model** according to your needs.
51 | - Various reasoning prompting strategies: 52 | - Non-Iterative: 53 | - **DiP**: Direct Prompting 54 | - **CoT**: [Chain of Thought Prompting](https://proceedings.neurips.cc/paper_files/paper/2022/hash/9d5609613524ecf4f15af0f7b31abca4-Abstract-Conference.html?ref=https://githubhelp.com) 55 | - **L2M**: [Least-to-Most Prompting](https://arxiv.org/abs/2205.10625) 56 | - **SBP**: [Step-Back Prompting](https://arxiv.org/abs/2310.06117) 57 | - **AnP**: [Analogous Prompting](https://arxiv.org/abs/2310.01714) 58 | - Iterative: 59 | - **ToT**: [Tree of Thoughts](https://proceedings.neurips.cc/paper_files/paper/2023/hash/271db9922b8d1f4dd7aaef84ed5ac703-Abstract-Conference.html) 60 | - **S-RF**: [Self-Refine](https://proceedings.neurips.cc/paper_files/paper/2023/hash/91edff07232fb1b55a505a9e9f6c0ff3-Abstract-Conference.html) 61 | - **MAD**: [Multi-Agent Debate](https://dl.acm.org/doi/abs/10.5555/3692070.3692537) 62 | - Extensive dataset support: 63 | - Mathematical reasoning: 64 | - [GSM8K](https://arxiv.org/abs/2110.14168) 65 | - [GSM-Hard](https://proceedings.mlr.press/v202/gao23f) 66 | - [MATH](https://arxiv.org/abs/2103.03874) 67 | - [AIME_2024](https://modelscope.cn/datasets/AI-ModelScope/AIME_2024) 68 | - Scientific reasoning: 69 | - [GPQA](https://arxiv.org/abs/2311.12022) 70 | - [MMLU](https://arxiv.org/abs/2009.03300) 71 | - MMLU-high_school_physics 72 | - MMLU-high_school_chemistry 73 | - MMLU-high_school_biology 74 | - Two different budgets for evaluation: 75 | - Sampling time 76 | - Computation overhead (Cost) 77 | 78 | ## 🛠️ Installation 79 | 80 | 1. Clone the repository: 81 | ```bash 82 | git clone https://github.com/MraDonkey/rethinking_prompting.git 83 | cd rethinking_prompting 84 | ``` 85 | 86 | 2. 
Create conda environment and install dependencies: 87 | ```bash 88 | conda create -n rethinking_prompting python=3.11 89 | conda activate rethinking_prompting 90 | pip install -r requirements.txt 91 | ``` 92 | 93 | ## ⚙️ Configuration 94 | 95 | Before running the framework, you need to set up your API keys for different LLM providers: 96 | 97 | - For vllm models: 98 | - You may need to login huggingface to get access to some LLMs. 99 | - For OpenAI models or other OpenAI-like API-based models: 100 | - Set api_key `openai_api_key` 101 | - Set base_url `openai_base_url` 102 | - For Google Gemini: 103 | - Set `google_api_key` 104 | 105 | Complete the variables `hf_token` in `main.py` and `base_path` in `dataset.py`. 106 | 107 | ## 🪛 Usage 108 | 109 | For example, to get the inference results of all prompting strategies with Qwen2.5-7B-Instruct on GSM8K, you can run this script. 110 | 111 | ```bash 112 | bash scripts/Qwen_GSM8K.sh 113 | ``` 114 | 115 | You can further customize hyperparameters to suit your specific requirements. 116 | 117 | ## 🔬 Evaluation 118 | 119 | To evaluate the performance of all tested prompting strategies: 120 | 121 | ```bash 122 | python eval_csv_N.py --model_name "your_model" --dataset "your_dataset" 123 | python eval_csv_cost.py --model_name "your_model" --dataset "your_dataset" 124 | ``` 125 | 126 | You can customize the variable `sampling_times` to adjust the points in the figure, in the style of Figure 1 and 2 in our [paper](https://arxiv.org/abs/2505.10981). 127 | 128 | ![alt text](image.png) 129 | 130 | ## ✒️ Citation 131 | 132 | Should you find our work beneficial to your research, we would appreciate citations to our paper and GitHub stars to support ongoing development. 
⭐ 133 | 134 | ```bibtex 135 | @inproceedings{liu-etal-2025-rethinking, 136 | title = "Rethinking the Role of Prompting Strategies in {LLM} Test-Time Scaling: A Perspective of Probability Theory", 137 | author = "Liu, Yexiang and 138 | Li, Zekun and 139 | Fang, Zhi and 140 | Xu, Nan and 141 | He, Ran and 142 | Tan, Tieniu", 143 | editor = "Che, Wanxiang and 144 | Nabende, Joyce and 145 | Shutova, Ekaterina and 146 | Pilehvar, Mohammad Taher", 147 | booktitle = "Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 148 | month = jul, 149 | year = "2025", 150 | address = "Vienna, Austria", 151 | publisher = "Association for Computational Linguistics", 152 | url = "https://aclanthology.org/2025.acl-long.1356/", 153 | doi = "10.18653/v1/2025.acl-long.1356", 154 | pages = "27962--27994", 155 | ISBN = "979-8-89176-251-0" 156 | } 157 | ``` 158 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import time 2 | from dataset import parse_best_solution, parse_answer 3 | 4 | import google.generativeai as genai # pip install -q -U google-generativeai 5 | from google.generativeai.types import HarmCategory, HarmBlockThreshold 6 | from openai import OpenAI 7 | 8 | from concurrent.futures import ThreadPoolExecutor, as_completed 9 | import pdb 10 | from tqdm import tqdm 11 | 12 | 13 | def get_messages(args): 14 | messages = [] 15 | assert args.messages != None or args.query != None 16 | if "gemini" not in args.model_name: 17 | if args.messages != None: 18 | roles = ['user', 'assistant'] 19 | for i in range(0, len(args.messages)): 20 | messages_i = [] 21 | assert len(args.messages[i]) % 2 == 1 22 | if args.system != None: 23 | messages_i.append({"role": "system", "content": args.system}) 24 | for j, message in enumerate(args.messages[i]): 25 | messages_i.append({'role': roles[j%2], 'content': 
message}) 26 | messages.append(messages_i) 27 | else: 28 | for query in args.query: 29 | if args.system == None: 30 | messages.append([{"role": "user", "content": query}]) 31 | else: 32 | messages.append([{"role": "system", "content": args.system}, 33 | {"role": "user", "content": query}]) 34 | else: 35 | if args.messages != None: 36 | roles = ['user', 'model'] 37 | for i in range(len(args.messages)): 38 | messages_i = [] 39 | assert len(args.messages[i]) % 2 == 1 40 | for j, message in enumerate(args.messages[i]): 41 | messages_i.append({'role': roles[j%2], 'parts': [message]}) 42 | messages.append(messages_i) 43 | else: 44 | for query in args.query: 45 | messages.append([{"role": "user", "parts": query}]) 46 | return messages 47 | 48 | 49 | class Gemini: 50 | def __init__(self, model="gemini-1.5-flash", N=3): 51 | self.model = genai.GenerativeModel(model) 52 | self.count_limit = N 53 | self.safety_settings = { 54 | HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, 55 | HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, 56 | HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, 57 | HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, 58 | } 59 | 60 | def get_response(self, messages, system_instruction = None, n = 1) -> str: 61 | # Query the model 62 | records = [] 63 | counts = 0 64 | while len(records) < n and counts < self.count_limit: 65 | if system_instruction != None: 66 | self.system_instruction = system_instruction 67 | try: 68 | num = n 69 | records = [] 70 | while (len(records) < n): 71 | time_begin = time.time() 72 | response = self.model.generate_content(messages, 73 | safety_settings=self.safety_settings, 74 | generation_config=genai.types.GenerationConfig(candidate_count = num)) 75 | time_completion = time.time() 76 | for i in range(0, num): 77 | if response.candidates[i].finish_reason == 1: 78 | output = response.candidates[i].content.parts[0].text.strip() 79 | 
completion_tokens = self.model.count_tokens(output).total_tokens 80 | usage = {"prompt_tokens": response.usage_metadata.prompt_token_count, 81 | "completion_tokens": completion_tokens, 82 | "time_prompt": 0, 83 | "time_completion": time_completion - time_begin} 84 | record = {"output": output, "usage": usage} 85 | records.append(record) 86 | num = n - len(record) 87 | 88 | except Exception as error: 89 | response = "" 90 | print(error) 91 | print('Sleeping for 10 seconds') 92 | time.sleep(10) 93 | counts += 1 94 | if counts == self.count_limit: 95 | raise EOFError("") 96 | return records 97 | 98 | 99 | def load_model(args): 100 | if args.model_type == "vllm": 101 | from vllm import LLM 102 | args.model = LLM(model=args.model_name, trust_remote_code=True, tensor_parallel_size=len(args.gpu.split(",")), dtype = args.dtype) 103 | elif args.model_type == "gemini": 104 | genai.configure(api_key=args.google_api_key) 105 | args.model = Gemini(model=args.model_name) 106 | elif args.model_type == "openai": 107 | args.client = OpenAI(api_key=args.openai_api_key, base_url=args.openai_base_url) 108 | 109 | 110 | def gpt_parallel_generate(args, message): 111 | records = [] 112 | res = args.client.chat.completions.create( 113 | model=args.model_name, 114 | messages = message, 115 | n = args.num, 116 | logprobs = True, 117 | max_tokens = args.max_new_tokens 118 | ) 119 | for i in range(0, args.num): 120 | output = res.choices[i].message.content 121 | completion_tokens = len(res.choices[i].logprobs.content) 122 | usage = { 123 | "prompt_tokens": res.usage.prompt_tokens, 124 | "completion_tokens": completion_tokens 125 | } 126 | record = { 127 | "output": output, 128 | "usage": usage 129 | } 130 | records.append(record) 131 | return records 132 | 133 | 134 | def LLM_generate(args): 135 | """ 136 | Generate responses using different LLM backends with parallel processing support. 
137 | 138 | Args: 139 | args: Arguments containing model configuration and generation parameters 140 | 141 | Returns: 142 | List of generated records with outputs and usage statistics 143 | """ 144 | 145 | assert args.messages != None or args.query != None  # need either prebuilt chat messages or a raw query to build them from 146 | messages = [] 147 | messages = get_messages(args) 148 | outputs = [] 149 | records = [] 150 | prompt_tokens = [] 151 | completion_tokens = [] 152 | if args.model_type == "gemini": 153 | if args.max_num_workers == 1: 154 | for i in tqdm(range(0, len(messages))): 155 | records_ = args.model.get_response(messages[i], args.system, args.num) 156 | records.append(records_) 157 | else: 158 | records = [None] * len(messages) 159 | with ThreadPoolExecutor(max_workers=args.max_num_workers) as executor: 160 | future_to_index = {}  # maps each future back to its message index so results land in order 161 | for index, message in enumerate(messages): 162 | future = executor.submit(args.model.get_response, message, args.system, args.num) 163 | future_to_index[future] = index 164 | with tqdm(total=len(messages)) as pbar: 165 | for future in as_completed(future_to_index): 166 | idx = future_to_index[future] 167 | try: 168 | data = future.result() 169 | records[idx] = data 170 | pbar.update(1) 171 | except Exception as exc: 172 | print(f'Index {idx} generated exception: {exc}') 173 | pdb.set_trace()  # NOTE(review): interactive debugger left in the worker error path — presumably a development leftover; confirm before unattended runs 174 | 175 | for i in range(0, len(messages)): 176 | for j in range(0, args.num): 177 | if "tot" in args.reasoning: 178 | output_key = parse_best_solution(records[i][j]["output"]) 179 | else: 180 | output_key = parse_answer(args, records[i][j]["output"]) 181 | records[i][j]["output_key"] = output_key  # NOTE(review): this answer-extraction loop is duplicated verbatim in the openai branch — candidate for a shared helper 182 | 183 | elif args.model_type == "openai": 184 | if args.max_num_workers == 1: 185 | for i in tqdm(range(0, len(messages))): 186 | records_i = gpt_parallel_generate(args, messages[i]) 187 | records.append(records_i) 188 | else: 189 | records = [None] * len(messages) 190 | with ThreadPoolExecutor(max_workers=args.max_num_workers) as executor: 191 | future_to_index = {} 192 | for index, message in 
enumerate(messages): 193 | future = executor.submit(gpt_parallel_generate, args, message) 194 | future_to_index[future] = index 195 | with tqdm(total=len(messages)) as pbar: 196 | for future in as_completed(future_to_index): 197 | idx = future_to_index[future] 198 | try: 199 | data = future.result() 200 | records[idx] = data 201 | pbar.update(1) 202 | except Exception as exc: 203 | print(f'Index {idx} generated exception: {exc}') 204 | pdb.set_trace()  # NOTE(review): debugger leftover, same as the gemini branch above 205 | 206 | for i in range(0, len(messages)): 207 | for j in range(0, args.num): 208 | if "tot" in args.reasoning: 209 | output_key = parse_best_solution(records[i][j]["output"]) 210 | else: 211 | output_key = parse_answer(args, records[i][j]["output"]) 212 | records[i][j]["output_key"] = output_key 213 | 214 | elif args.model_type == "vllm": 215 | from vllm import SamplingParams 216 | sampling_params = SamplingParams(temperature=args.temperature, n = args.num, max_tokens = args.max_new_tokens, top_p = 0.9) 217 | res = args.model.chat(messages=messages, sampling_params=sampling_params) 218 | for i in range(0, len(res)): 219 | outputs.append([output.text for output in res[i].outputs]) 220 | prompt_tokens.append(len(res[i].prompt_token_ids)) 221 | completion_tokens.append([len(output.token_ids) for output in res[i].outputs]) 222 | 223 | for i in range(0, len(outputs)): 224 | outputs_i = outputs[i] 225 | completion_tokens_i = completion_tokens[i] 226 | prompt_tokens_i = prompt_tokens[i] 227 | records_i = [] 228 | for j in range(0, args.num): 229 | usage = {"prompt_tokens": prompt_tokens_i, 230 | "completion_tokens": completion_tokens_i[j] 231 | } 232 | output = outputs_i[j] 233 | if "tot" in args.reasoning: 234 | output_key = parse_best_solution(output) 235 | else: 236 | output_key = parse_answer(args, output) 237 | record = {"output": output, "output_key": output_key, "usage": usage} 238 | records_i.append(record) 239 | records.append(records_i) 240 | else: 241 | raise NotImplementedError(f"Model type 
\"{args.model_type}\" not supported (should be \"vllm\", \"gemini\" or \"openai\").") 242 | return records 243 | -------------------------------------------------------------------------------- /prompts/GPQA.py: -------------------------------------------------------------------------------- 1 | prompt_format = '''Please choose the correct choice. Your last sentence should be \"The correct answer is (insert answer here, which is only the letter of the choice)\".''' 2 | 3 | GPQA_prompt = '''Question: {question} 4 | 5 | Choices: 6 | (A) {choice1} 7 | (B) {choice2} 8 | (C) {choice3} 9 | (D) {choice4} 10 | 11 | ''' 12 | 13 | io = GPQA_prompt + prompt_format 14 | 15 | cot_pre = prompt_format + '\n\n' 16 | 17 | cot_0_shot = GPQA_prompt + prompt_format + " Let's think step by step:" 18 | 19 | cot_1_shot = cot_pre + '''Question: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer? 20 | 21 | Choices: 22 | (A) 1/400 23 | (B) 19/400 24 | (C) 20/400 25 | (D) 38/400 26 | 27 | Let's think step by step: 28 | The expected proportion of individuals who carry the b allele but are not expected to develop the cancer equals to the frequency of heterozygous allele in the given population. 29 | According to the Hardy-Weinberg equation p∧2 + 2pq + q∧2 = 1, where p is the frequency of dominant allele frequency, q is the frequency of recessive allele frequency, p∧2 is the frequency of the homozygous dominant allele, q∧2 is the frequency of the recessive allele, and 2pq is the frequency of the heterozygous allele. 30 | Given that q∧2=1/400, hence, q=0.05 and p=1-q=0.95. 31 | The frequency of the heterozygous allele is 2pq=2*0.05*0.95=38/400. 32 | The correct answer is (D). 
33 | 34 | ''' + GPQA_prompt + "Let's think step by step:" 35 | 36 | cot_5_shot = cot_pre + '''Question: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer? 37 | 38 | Choices: 39 | (A) 1/400 40 | (B) 19/400 41 | (C) 20/400 42 | (D) 38/400 43 | 44 | Let's think step by step: 45 | The expected proportion of individuals who carry the b allele but are not expected to develop the cancer equals to the frequency of heterozygous allele in the given population. 46 | According to the Hardy-Weinberg equation p∧2 + 2pq + q∧2 = 1, where p is the frequency of dominant allele frequency, q is the frequency of recessive allele frequency, p∧2 is the frequency of the homozygous dominant allele, q∧2 is the frequency of the recessive allele, and 2pq is the frequency of the heterozygous allele. 47 | Given that q∧2=1/400, hence, q=0.05 and p=1-q=0.95. 48 | The frequency of the heterozygous allele is 2pq=2*0.05*0.95=38/400. 49 | The correct answer is (D). 50 | 51 | Question: A Fe pellet of 0.056 g is first dissolved in 10 mL of hydrobromic acid HBr (0.1 M). The resulting solution is then titrated by KMnO4 (0.02 M). How many equivalence points are there? 52 | 53 | Choices: 54 | (A) Two points, 25 ml and 35 ml 55 | (B) One point, 25 mL 56 | (C) One point, 10 ml 57 | (D) Two points, 25 ml and 30 ml 58 | 59 | Let's think step by step: 60 | HBr will react with Fe to produce Fe2+. MnO4- will first react with Fe2+ then Br-. 61 | Two equivalence points will exist 25 ml and 35 ml. 62 | HBr will react with Fe to produce Fe2+. MnO4- will first react with Fe2+ then Br-. 63 | Two equivalence points will exist 25 ml and 35 ml. 64 | In the beaker there is Fe2+ and Br-. 
65 | When considering titration with two analytes one will have to consider which reaction will occur first. 66 | Since it is a redox titration consider the reduction potential of: 67 | E0 (Br2 /Br- ) = 1.09 V E0 (MnO4-/ Mn2+) = 1.49 V E0 (Fe3+/Fe2+) =0.77 V 68 | [Fe2+]=m/MV=0.1M. 69 | Reaction 1: MnO4- + 5Fe2+ + 8H+ → Mn2+ + 5Fe3+ + 4H2O 70 | Reaction 2: 2MnO4- + 10Br- + 16H+ → 2Mn2+ + 5Br2 + 8H2O 71 | So MnO4- will first react with Fe2+ with a stoichiometry of 1:5 so Veq1 will be 10 ml. 72 | Then when Fe2+ is used up, MnO4- will react with Br- with a stoichiometry of 2:10 then V added will be 25 ml so Veq2=25+10=35 ml. 73 | The correct answer is (A). 74 | 75 | Question: Consider a quantum mechanical system containing a particle of mass $m$ moving in an istropic three dimensional potential of the form $V(r) = 1/2 m \omega^2 r^2$ corresponding to the acted force obeying Hooke’s law. Here, $\omega$ is the angular frequency of oscillation and $r$ is the radial distance of the particle from the origin in spherical polar coordinate. What is the value of energy of the third excited state, and how many linearly independent eigenfunctions are possible for the same energy eigenvalue? 76 | 77 | Choices: 78 | (A) 11 \pi^2 \hbar^2 / (2m r^2), 3 79 | (B) (9/2) \hbar \omega , 10 80 | (C) 11 \pi^2 \hbar^2 / (2m r^2), 10 81 | (D) (9/2) \hbar \omega, 3 82 | 83 | Let's think step by step: 84 | This problem is nothing but the three dimensional simple harmonic oscillator (SHO) problem. 85 | The energy spectrum of three dimensional SHO is $E_n= (n+3/2)\hbar \omega$ where $n=0,1,2,3….$. 86 | For third excited state n=3. 87 | 3+3/2=6/2+3/2=9/2. 88 | Thus the corresponding energy is $(9/2)\hbar \omega$. 89 | The degeneracy of the state is $g_n= (n+1)(n+2)/2$. 90 | For n=3, degeneracy is (3+1)*(3+2)/2=4*5/2=10. 91 | The correct answer is (B). 92 | 93 | Question: "Your overhear two chemists talking to each other as they leave a synthetic organic chemistry lab. 
One asks the other "So, how did it go?" The second chemist replies, "Not well - my compounds are on top of each other." What is the second chemist most likely referring to?" 94 | 95 | Choices: 96 | (A) The compounds they are working with have similar polarities. 97 | (B) The compounds they are working with have similar boiling points. 98 | (C) The compounds they are working with are bonding to each other through non-covalent/van der Waals interactions. 99 | (D) The compounds they are working with have similar optical rotations. 100 | 101 | Let's think step by step: 102 | "On top of each other" commonly refers to two compounds that have similar Rf values on chromatography (a common operation in synthetic chemistry). 103 | Similar Rf values arise for compounds with similar polarities. 104 | The correct answer is (A). 105 | 106 | Question: Two people are playing the following game. A fair coin is tossed into the air. Person A says that in a single toss of the coin, the tail will come. So it's like the first shot or the third shot or the fifth shot. Person B says that the coin will come with a double toss. So like the second, fourth, sixth or eighth shot. Imagine this game played forever. What is the probability that person A wins this game? 107 | 108 | Choices: 109 | (A) 1/2 110 | (B) 1/4 111 | (C) 2/3 112 | (D) 1/8 113 | 114 | Let's think step by step: 115 | When finding the correct answer, the probability of playing forever and the coin's single-point toss will be calculated. 116 | For example, a tail may appear on the first shot. 117 | This probability is 1/2. if the first toss doesn't come up, it shouldn't come to the second roll either, because the second throw is an even number. 118 | So it can come in the third shot. 119 | This is (1/2)(1/2)(1/2). 120 | So (1/2)^3=1/8. 121 | Or it could come on the fifth shot. 122 | This is (1/2)^5=1/32. 123 | This is actually a geometric series that goes on forever. 124 | We can write this series as follows. 
125 | (1/2) + (1/2)^3 + (1/2)^5 + (1/2)^7 + ………. 126 | The solution for this series is as follows : a1/(1-r) where a1 is the first number and r is the sequence or r= a2/a1 or a3/a2 etc. 127 | a1=1/2 128 | r=(1/2)^2=1/4 129 | So a1/(1-r)=(1/2)/(1-1/4)=(1/2)/(3/4)=2/3. 130 | The correct answer is (C). 131 | 132 | ''' + GPQA_prompt + "Let's think step by step:" 133 | 134 | Least_to_Most_0_shot = GPQA_prompt + '''Please choose the correct choice. In order to solve the question more conveniently and efficiently, break down the question into progressive sub-questions. Answer the sub-questions and get the final result according to sub-questions and their answers. 135 | ''' + "Your last sentence should be \"The correct answer is (insert answer here, which is only the letter of the choice)\"." 136 | 137 | tot_post = ''' 138 | Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.''' 139 | 140 | tot_3_solutions = GPQA_prompt + ''' 141 | Solution 1: {solution1} 142 | Solution 2: {solution2} 143 | Solution 3: {solution3}''' 144 | 145 | tot_5_solutions = tot_3_solutions + ''' 146 | Solution 4: {solution4} 147 | Solution 5: {solution5}''' 148 | 149 | tot_10_solutions = tot_5_solutions + ''' 150 | Solution 6: {solution6} 151 | Solution 7: {solution7} 152 | Solution 8: {solution8} 153 | Solution 9: {solution9} 154 | Solution 10: {solution10}''' 155 | 156 | anologous_1_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 157 | ''' + GPQA_prompt + ''' 158 | # Instructions: 159 | ## Relevant Problems: 160 | Recall an example of the {subject} problem that is relevant to the initial problem. 
Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem: 161 | - After "Q: ", describe the problem. 162 | - After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 163 | 164 | ## Solve the Initial Problem: 165 | Q: Copy and paste the initial problem here. 166 | A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 167 | ''' 168 | 169 | anologous_3_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 170 | ''' + GPQA_prompt + ''' 171 | # Instructions: 172 | ## Relevant Problems: 173 | Recall three examples of {subject} problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem: 174 | - After "Q: ", describe the problem. 175 | - After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 176 | 177 | ## Solve the Initial Problem: 178 | Q: Copy and paste the initial problem here. 179 | A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 180 | ''' 181 | 182 | anologous_5_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 183 | ''' + GPQA_prompt + ''' 184 | # Instructions: 185 | ## Relevant Problems: 186 | Recall five examples of {subject} problems that are relevant to the initial problem. 
Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem: 187 | - After "Q: ", describe the problem. 188 | - After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 189 | 190 | ## Solve the Initial Problem: 191 | Q: Copy and paste the initial problem here. 192 | A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 193 | ''' 194 | 195 | SBP_extract = '''You are an expert at {subject}. Your task is to extract the {subject} concepts and principles involved in solving the problem. 196 | ''' + GPQA_prompt + ''' 197 | Principles involved: 198 | ''' 199 | 200 | SBP_answer = "You are an expert at {subject}. You are given a {subject} problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + ''' 201 | ''' + GPQA_prompt + "\nInstruction:\n"+ prompt_format + ''' 202 | Principles: 203 | {principles} 204 | ''' + ''' 205 | Answer: 206 | ''' 207 | 208 | -------------------------------------------------------------------------------- /prompts/MMLU.py: -------------------------------------------------------------------------------- 1 | prompt_format = '''Please choose the correct choice. 
Your last sentence should be \"The correct answer is (insert answer here, which is only the letter of the choice)\".''' 2 | 3 | MMLU_prompt = '''Question: 4 | {question} 5 | 6 | Choices: 7 | (A) {choice1} 8 | (B) {choice2} 9 | (C) {choice3} 10 | (D) {choice4} 11 | 12 | ''' 13 | 14 | io = MMLU_prompt + prompt_format 15 | 16 | cot_pre = prompt_format + '\n\n' 17 | 18 | cot_0_shot = MMLU_prompt + prompt_format + " Let's think step by step:" 19 | 20 | # cot_1_shot = cot_pre + '''Question: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer? 21 | 22 | # Choices: 23 | # (A) 1/400 24 | # (B) 19/400 25 | # (C) 20/400 26 | # (D) 38/400 27 | 28 | # Let's think step by step: 29 | # The expected proportion of individuals who carry the b allele but are not expected to develop the cancer equals to the frequency of heterozygous allele in the given population. 30 | # According to the Hardy-Weinberg equation p∧2 + 2pq + q∧2 = 1, where p is the frequency of dominant allele frequency, q is the frequency of recessive allele frequency, p∧2 is the frequency of the homozygous dominant allele, q∧2 is the frequency of the recessive allele, and 2pq is the frequency of the heterozygous allele. 31 | # Given that q∧2=1/400, hence, q=0.05 and p=1-q=0.95. 32 | # The frequency of the heterozygous allele is 2pq=2*0.05*0.95=38/400. 33 | # The correct answer is (D). 34 | 35 | # ''' + MMLU_prompt + "Let's think step by step:" 36 | 37 | # cot_5_shot = cot_pre + '''Question: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. 
Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer? 38 | 39 | # Choices: 40 | # (A) 1/400 41 | # (B) 19/400 42 | # (C) 20/400 43 | # (D) 38/400 44 | 45 | # Let's think step by step: 46 | # The expected proportion of individuals who carry the b allele but are not expected to develop the cancer equals to the frequency of heterozygous allele in the given population. 47 | # According to the Hardy-Weinberg equation p∧2 + 2pq + q∧2 = 1, where p is the frequency of dominant allele frequency, q is the frequency of recessive allele frequency, p∧2 is the frequency of the homozygous dominant allele, q∧2 is the frequency of the recessive allele, and 2pq is the frequency of the heterozygous allele. 48 | # Given that q∧2=1/400, hence, q=0.05 and p=1-q=0.95. 49 | # The frequency of the heterozygous allele is 2pq=2*0.05*0.95=38/400. 50 | # The correct answer is (D). 51 | 52 | # Question: A Fe pellet of 0.056 g is first dissolved in 10 mL of hydrobromic acid HBr (0.1 M). The resulting solution is then titrated by KMnO4 (0.02 M). How many equivalence points are there? 53 | 54 | # Choices: 55 | # (A) Two points, 25 ml and 35 ml 56 | # (B) One point, 25 mL 57 | # (C) One point, 10 ml 58 | # (D) Two points, 25 ml and 30 ml 59 | 60 | # Let's think step by step: 61 | # HBr will react with Fe to produce Fe2+. MnO4- will first react with Fe2+ then Br-. 62 | # Two equivalence points will exist 25 ml and 35 ml. 63 | # HBr will react with Fe to produce Fe2+. MnO4- will first react with Fe2+ then Br-. 64 | # Two equivalence points will exist 25 ml and 35 ml. 65 | # In the beaker there is Fe2+ and Br-. 66 | # When considering titration with two analytes one will have to consider which reaction will occur first. 
67 | # Since it is a redox titration consider the reduction potential of: 68 | # E0 (Br2 /Br- ) = 1.09 V E0 (MnO4-/ Mn2+) = 1.49 V E0 (Fe3+/Fe2+) =0.77 V 69 | # [Fe2+]=m/MV=0.1M. 70 | # Reaction 1: MnO4- + 5Fe2+ + 8H+ → Mn2+ + 5Fe3+ + 4H2O 71 | # Reaction 2: 2MnO4- + 10Br- + 16H+ → 2Mn2+ + 5Br2 + 8H2O 72 | # So MnO4- will first react with Fe2+ with a stoichiometry of 1:5 so Veq1 will be 10 ml. 73 | # Then when Fe2+ is used up, MnO4- will react with Br- with a stoichiometry of 2:10 then V added will be 25 ml so Veq2=25+10=35 ml. 74 | # The correct answer is (A). 75 | 76 | # Question: Consider a quantum mechanical system containing a particle of mass $m$ moving in an istropic three dimensional potential of the form $V(r) = 1/2 m \omega^2 r^2$ corresponding to the acted force obeying Hooke’s law. Here, $\omega$ is the angular frequency of oscillation and $r$ is the radial distance of the particle from the origin in spherical polar coordinate. What is the value of energy of the third excited state, and how many linearly independent eigenfunctions are possible for the same energy eigenvalue? 77 | 78 | # Choices: 79 | # (A) 11 \pi^2 \hbar^2 / (2m r^2), 3 80 | # (B) (9/2) \hbar \omega , 10 81 | # (C) 11 \pi^2 \hbar^2 / (2m r^2), 10 82 | # (D) (9/2) \hbar \omega, 3 83 | 84 | # Let's think step by step: 85 | # This problem is nothing but the three dimensional simple harmonic oscillator (SHO) problem. 86 | # The energy spectrum of three dimensional SHO is $E_n= (n+3/2)\hbar \omega$ where $n=0,1,2,3….$. 87 | # For third excited state n=3. 88 | # 3+3/2=6/2+3/2=9/2. 89 | # Thus the corresponding energy is $(9/2)\hbar \omega$. 90 | # The degeneracy of the state is $g_n= (n+1)(n+2)/2$. 91 | # For n=3, degeneracy is (3+1)*(3+2)/2=4*5/2=10. 92 | # The correct answer is (B). 93 | 94 | # Question: "Your overhear two chemists talking to each other as they leave a synthetic organic chemistry lab. One asks the other "So, how did it go?" 
The second chemist replies, "Not well - my compounds are on top of each other." What is the second chemist most likely referring to?" 95 | 96 | # Choices: 97 | # (A) The compounds they are working with have similar polarities. 98 | # (B) The compounds they are working with have similar boiling points. 99 | # (C) The compounds they are working with are bonding to each other through non-covalent/van der Waals interactions. 100 | # (D) The compounds they are working with have similar optical rotations. 101 | 102 | # Let's think step by step: 103 | # "On top of each other" commonly refers to two compounds that have similar Rf values on chromatography (a common operation in synthetic chemistry). 104 | # Similar Rf values arise for compounds with similar polarities. 105 | # The correct answer is (A). 106 | 107 | # Question: Two people are playing the following game. A fair coin is tossed into the air. Person A says that in a single toss of the coin, the tail will come. So it's like the first shot or the third shot or the fifth shot. Person B says that the coin will come with a double toss. So like the second, fourth, sixth or eighth shot. Imagine this game played forever. What is the probability that person A wins this game? 108 | 109 | # Choices: 110 | # (A) 1/2 111 | # (B) 1/4 112 | # (C) 2/3 113 | # (D) 1/8 114 | 115 | # Let's think step by step: 116 | # When finding the correct answer, the probability of playing forever and the coin's single-point toss will be calculated. 117 | # For example, a tail may appear on the first shot. 118 | # This probability is 1/2. if the first toss doesn't come up, it shouldn't come to the second roll either, because the second throw is an even number. 119 | # So it can come in the third shot. 120 | # This is (1/2)(1/2)(1/2). 121 | # So (1/2)^3=1/8. 122 | # Or it could come on the fifth shot. 123 | # This is (1/2)^5=1/32. 124 | # This is actually a geometric series that goes on forever. 125 | # We can write this series as follows. 
126 | # (1/2) + (1/2)^3 + (1/2)^5 + (1/2)^7 + ………. 127 | # The solution for this series is as follows : a1/(1-r) where a1 is the first number and r is the sequence or r= a2/a1 or a3/a2 etc. 128 | # a1=1/2 129 | # r=(1/2)^2=1/4 130 | # So a1/(1-r)=(1/2)/(1-1/4)=(1/2)/(3/4)=2/3. 131 | # The correct answer is (C). 132 | 133 | # ''' + MMLU_prompt + "Let's think step by step:" 134 | 135 | Least_to_Most_0_shot = MMLU_prompt + '''Please choose the correct choice. In order to solve the question more conveniently and efficiently, break down the question into progressive sub-questions. Answer the sub-questions and get the final result according to sub-questions and their answers. 136 | ''' + "Your last sentence should be \"The correct answer is (insert answer here, which is only the letter of the choice)\"." 137 | 138 | tot_post = ''' 139 | Given the question and several solutions, decide which solution is the most promising. Analyze each solution in detail, then conclude in the last line "The index of the best solution is x", where x is the index number of the solution.''' 140 | 141 | tot_3_solutions = MMLU_prompt + ''' 142 | Solution 1: {solution1} 143 | Solution 2: {solution2} 144 | Solution 3: {solution3}''' 145 | 146 | tot_5_solutions = tot_3_solutions + ''' 147 | Solution 4: {solution4} 148 | Solution 5: {solution5}''' 149 | 150 | tot_10_solutions = tot_5_solutions + ''' 151 | Solution 6: {solution6} 152 | Solution 7: {solution7} 153 | Solution 8: {solution8} 154 | Solution 9: {solution9} 155 | Solution 10: {solution10}''' 156 | 157 | anologous_1_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 158 | ''' + MMLU_prompt.replace("Question: {question}", "Initial problem: {question}") + ''' 159 | # Instructions: 160 | ## Relevant Problems: 161 | Recall an example of the {subject} problem that is relevant to the initial problem. 
Your problem should be distinct from the initial problem (e.g., involving different numbers and names). For the example problem: 162 | - After "Q: ", describe the problem. 163 | - After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 164 | 165 | ## Solve the Initial Problem: 166 | Q: Copy and paste the initial problem here. 167 | A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 168 | ''' 169 | 170 | anologous_3_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 171 | ''' + MMLU_prompt.replace("Question: {question}", "Initial problem: {question}") + ''' 172 | # Instructions: 173 | ## Relevant Problems: 174 | Recall three examples of {subject} problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem: 175 | - After "Q: ", describe the problem. 176 | - After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 177 | 178 | ## Solve the Initial Problem: 179 | Q: Copy and paste the initial problem here. 180 | A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 181 | ''' 182 | 183 | anologous_5_prompt = '''Your task is to tackle {subject} problems. When presented with a {subject} problem, recall relevant problems as examples. Afterward, proceed to solve the initial problem. 
184 | ''' + MMLU_prompt.replace("Question: {question}", "Initial problem: {question}") + ''' 185 | # Instructions: 186 | ## Relevant Problems: 187 | Recall five examples of {subject} problems that are relevant to the initial problem. Your problems should be distinct from each other and from the initial problem (e.g., involving different numbers and names). For each problem: 188 | - After "Q: ", describe the problem. 189 | - After "A: ", explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 190 | 191 | ## Solve the Initial Problem: 192 | Q: Copy and paste the initial problem here. 193 | A: Explain the solution and conclude with the sentence \"The correct answer is (insert answer here, which is only the letter of the choice)\". 194 | ''' 195 | 196 | SBP_extract = '''You are an expert at {subject}. Your task is to extract the {subject} concepts and principles involved in solving the problem. 197 | ''' + MMLU_prompt + ''' 198 | Principles involved: 199 | ''' 200 | 201 | SBP_answer = "You are an expert at {subject}. You are given a {subject} problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles." + ''' 202 | ''' + MMLU_prompt + "\nInstruction:\n"+ prompt_format + ''' 203 | Principles: 204 | {principles} 205 | ''' + ''' 206 | Answer: 207 | ''' 208 | 209 | SBP_extract_physics = '''You are an expert at Physics. Your task is to extract the Physics concepts and principles involved in solving the problem. Here is an example. 210 | 211 | Question: 212 | A spherical conductor carries a net charge. How is this charge distributed on the sphere? 213 | 214 | Choices: 215 | (A) The charge is evenly distributed on the surface. 216 | (B) The charge resides on the surface only; the distribution of charge on the surface depends on what other charged objects are near the sphere. 
217 | (C) The charge moves continually within the sphere. 218 | (D) The charge is distributed uniformly throughout the sphere. 219 | 220 | Principles Involved: 221 | Coulomb's Law: the force between two charged particles is proportional to the product of their charges and inversely proportional to the square of the distance between them, F = k * q1 * q2 / r, where F is the electric force, k is a constant, q1 and q2 are the charges the particles carry, and r is the distance between them. 222 | 223 | ''' + MMLU_prompt + '''Principles Involved: 224 | ''' 225 | 226 | SBP_answer_physics = "You are an expert at Physics. You are given a Physics problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles. " + MMLU_prompt + ''' Here is an example. 227 | 228 | Question: 229 | A spherical conductor carries a net charge. How is this charge distributed on the sphere? 230 | 231 | Choices: 232 | (A) The charge is evenly distributed on the surface. 233 | (B) The charge resides on the surface only; the distribution of charge on the surface depends on what other charged objects are near the sphere. 234 | (C) The charge moves continually within the sphere. 235 | (D) The charge is distributed uniformly throughout the sphere. 236 | 237 | Principles: 238 | Coulomb's Law: the force between two charged particles is proportional to the product of their charges and inversely proportional to the square of the distance between them, F = k * q1 * q2 / r, where F is the electric force, k is a constant, q1 and q2 are the charges the particles carry, and r is the distance between them. 239 | 240 | Answer: 241 | Using the Principles of Coulomb's Law, we can solve the problem as following: 242 | Step 1: Apply Coulomb's Law to find out how charges are distributed on the surface. 243 | In the case of a spherical conductor, the charges on the surface will repel each other. 
The further apart the charges are, the less force they will exert on each other. Therefore, the charges will distribute themselves evenly on the surface of the sphere, as this is the configuration that minimizes the repulsive force between them. 244 | 245 | Step 2: Apply Coulomb's Law to find out what happens if there are other charges present. 246 | The distribution of charge on the surface may also be affected by the presence of other charged objects near the sphere. For example, if a negatively charged object is brought near a positively charged sphere, the negative charges on the sphere will be repelled and will move to the opposite side of the sphere. This will result in a non-uniform distribution of charge on the surface of the sphere. 247 | 248 | Therefore, the correct answer is (B) The charge resides on the surface only; the distribution of charge on the surface depends on what other charged objects are near the sphere. 249 | 250 | ''' + MMLU_prompt + '''Principles: 251 | {principles} 252 | 253 | Answer: 254 | ''' 255 | 256 | SBP_extract_chemistry = '''You are an expert at Chemistry. Your task is to extract the Chemistry concepts and principles involved in solving the problem. Here is an example. 257 | 258 | Question: 259 | A sample of an unknown chloride compound was dissolved in water, and then titrated with excess Pb(NO3)2 to create a precipitate. After drying, it is determined there are 0.0050 mol of precipitate present. What mass of chloride is present in the original sample? 260 | 261 | Choices: 262 | (A) 0.177 g 263 | (B) 0.355 g 264 | (C) 0.522 g 265 | (D) 0.710 g 266 | 267 | Principles Involved: 268 | Precipitation reactions: Precipitation reactions occur when two soluble salts are mixed and form an insoluble product, called a precipitate. The precipitate can be separated from the solution by filtration or centrifugation. 269 | Molar mass: The molar mass of a substance is the mass of one mole of that substance. 
The molar mass is expressed in grams per mole (g/mol). 270 | Limiting reactant: The limiting reactant is the reactant that is completely consumed in a chemical reaction. The amount of product formed is determined by the amount of limiting reactant. 271 | 272 | ''' + MMLU_prompt + '''Principles Involved: 273 | ''' 274 | 275 | SBP_answer_chemistry = "You are an expert at Chemistry. You are given a Chemistry problem and a set of principles involved in solving the problem. Solve the problem step by step by following the principles. " + MMLU_prompt + ''' Here is an example. 276 | 277 | Question: 278 | A sample of an unknown chloride compound was dissolved in water, and then titrated with excess Pb(NO3)2 to create a precipitate. After drying, it is determined there are 0.0050 mol of precipitate present. What mass of chloride is present in the original sample? 279 | 280 | Choices: 281 | (A) 0.177 g 282 | (B) 0.355 g 283 | (C) 0.522 g 284 | (D) 0.710 g 285 | 286 | Principles: 287 | Precipitation reactions: Precipitation reactions occur when two soluble salts are mixed and form an insoluble product, called a precipitate. The precipitate can be separated from the solution by filtration or centrifugation. 288 | Molar mass: The molar mass of a substance is the mass of one mole of that substance. The molar mass is expressed in grams per mole (g/mol). 289 | Limiting reactant: The limiting reactant is the reactant that is completely consumed in a chemical reaction. The amount of product formed is determined by the amount of limiting reactant. 290 | 291 | Answer: 292 | Assuming the unknown chloride compound is MCl, where M represents the metal cation, the balanced chemical equation for the precipitation reaction is: 293 | Pb(NO3)2(aq) + 2MCl(aq) −→ PbCl2(s) + 2MNO3(aq) 294 | 295 | Since Pb(NO3)2 is in excess, MCl is the limiting reactant. The stoichiometry of the reaction indicates that 2 moles of MCl produce 1 mole of PbCl2 precipitate. 
Therefore, 0.0050 mol of PbCl2 corresponds to 0.010 mol of MCl. 296 | 297 | The mass of chloride in the original sample can be calculated using the molar mass of chloride (35.45 g/mol): 298 | 0.010 mol Cl x 35.45 g/mol = 0.355 g Cl 299 | 300 | The correct answer is (B) 0.355 g. 301 | 302 | ''' + MMLU_prompt + '''Principles: 303 | {principles} 304 | 305 | Answer: 306 | ''' 307 | 308 | 309 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from dataset import * 2 | from model import load_model, LLM_generate 3 | from dataset import create_prompt 4 | 5 | import os 6 | import argparse 7 | import copy 8 | from concurrent.futures import ThreadPoolExecutor 9 | from tqdm import tqdm 10 | 11 | from huggingface_hub import login 12 | 13 | hf_token = "hf_YourTokenHere" # Replace with your token 14 | login(token=hf_token) 15 | 16 | 17 | refine1_feeback_prompt = "Review your previous answer and find problems with your answer." 18 | refine1_refine_prompt = "Based on the problems you found, improve your answer." 
def concat_refine_records(records, output, string):
    """Store each generation in ``output`` into the matching record under key ``string``.

    ``records`` and ``output`` are parallel 2-D lists indexed [question][sample].
    """
    for i in range(len(output)):
        for j in range(len(output[i])):
            records[i][j][string] = output[i][j]


def concat_refine_messages(messages, output):
    """Append the next conversation turn to every message list in ``messages``.

    ``output`` is either a plain prompt string (appended verbatim to every
    conversation) or a per-conversation list of generations, in which case the
    first sample's text is appended.

    Bug fix: the original iterated ``range(len(output))`` even when ``output``
    was a string, walking over its characters — an IndexError whenever the
    string was longer than ``messages`` (silently swallowed by the executor),
    and missing appends whenever it was shorter.
    """
    for i in range(len(messages)):
        if isinstance(output, str):
            messages[i].append(output)
        else:
            messages[i].append(output[i][0]["output"])


def get_model_outputs(args):
    """Run one batch through the selected reasoning strategy and return its records.

    Returns a 2-D list of record dicts indexed [question][sample] (MAD wraps a
    single dict as ``[[records]]``).  Mutates several fields on ``args``
    (query/messages/num) as the underlying prompt builders expect.

    The original wrapped every bookkeeping call in a throwaway
    ``ThreadPoolExecutor.submit`` whose Future was never inspected, so
    exceptions vanished and concurrent appends to the same lists could
    interleave; the calls are cheap and order-sensitive, so they are now
    plain sequential calls.
    """
    if args.reasoning in ["DiP", "CoT", "AnP", "L2M"]:
        args.query = create_prompt(args)
        if args.verbal:
            print(args.query)
        records = LLM_generate(args)

    elif args.reasoning == "S-RF":
        # Seed with a plain DiP answer, then alternate critique/refine rounds.
        args.reasoning = "DiP"
        queries = create_prompt(args)
        args.reasoning = "S-RF"
        args.messages = [[query] for query in queries]
        origin_output = LLM_generate(args)
        records = []
        for i in range(len(origin_output)):
            records.append([{"output0": origin_output[i][j]}
                            for j in range(len(origin_output[i]))])
        concat_refine_messages(args.messages, origin_output)
        for j in tqdm(range(args.rounds)):
            concat_refine_messages(args.messages, refine1_feeback_prompt)
            output = LLM_generate(args)
            if args.verbal:
                print(output)
            concat_refine_records(records, output, f"problems{j+1}")
            concat_refine_messages(args.messages, output)
            concat_refine_messages(args.messages, refine1_refine_prompt + " " + PROMPT_FORMAT)
            output = LLM_generate(args)
            if args.verbal:
                print(output)
            concat_refine_records(records, output, f"output{j+1}")
            concat_refine_messages(args.messages, output)

    elif args.reasoning == "ToT":
        records = []
        args.query = create_prompt(args)
        if args.verbal:
            print(args.query)
        output_choices = LLM_generate(args)
        for i in range(len(args.questions)):
            # One fresh dict per sample; the original reused a single dict, so
            # every entry aliased the last sample's "choose" output.
            records.append([{"solutions": args.records_tot[i],
                             "choose": output_choices[i][j]}
                            for j in range(args.num)])

    elif args.reasoning == "SBP":
        # Stage 1: extract principles; stage 2: answer conditioned on them.
        records = []
        args.query = create_prompt(args)
        if args.verbal:
            print(args.query)
        principles = LLM_generate(args)
        num = args.num
        l = len(args.questions)
        args.num = 1
        args.query = []
        for i in range(l):
            for j in range(num):
                args.principles = principles[i][j]["output"]
                args.query.append(create_prompt(args, i)[0])
        del args.principles
        solutions = LLM_generate(args)
        for i in range(l):
            record_ = []
            for j in range(num):
                record_.append({"principles": principles[i][j],
                                "solution": solutions[num * i + j][0]})
            records.append(record_)
        args.num = num  # restore the caller's sampling count

    elif args.reasoning == "MAD":
        records = {}
        args.reasoning = "DiP"
        agent_contexts = [create_prompt(args) for agent in range(0, 3)]
        args.reasoning = "MAD"
        if args.verbal:
            print(args.query)
        # NOTE(review): control flow below transcribed from a flattened dump;
        # when resuming (continue_), round-0 entries are assumed to already
        # exist in args.records — verify against saved logs.
        if args.continue_:
            for round in range(args.rounds):
                if round < len(args.records):
                    outputs = [output["output"] for output in args.records[f"round{round+1}"]]
                    records[f"round{round+1}"] = args.records[f"round{round+1}"]
                else:
                    records[f"round{round+1}"] = []
                for i, agent_context in enumerate(agent_contexts):
                    if round != 0:
                        agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:]
                        message = construct_message(args, agent_contexts_other, args.question, 2*round - 1)
                        agent_context.append(message)
                        args.messages = [agent_context]
                        if round < len(args.records):
                            assistant_message = outputs[i]
                        else:
                            record = LLM_generate(args)[0][0]
                            records[f"round{round+1}"].append(record)
                            assistant_message = record["output"]
                    else:
                        assistant_message = outputs[i]
                    agent_context.append(assistant_message)
        else:
            for round in range(args.rounds):
                records[f"round{round+1}"] = []
                for i, agent_context in enumerate(agent_contexts):
                    if round != 0:
                        agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:]
                        message = construct_message(args, agent_contexts_other, args.question, 2*round - 1)
                        agent_context.append(message)
                    args.messages = [agent_context]
                    record = LLM_generate(args)[0][0]
                    records[f"round{round+1}"].append(record)
                    assistant_message = record["output"]
                    agent_context.append(assistant_message)
        records = [[records]]

    return records


def handle_tot_reasoning(args):
    """Handle Tree of Thoughts (ToT) reasoning setup.

    Loads the ``shot`` previously generated 0-shot CoT runs that ToT will
    choose between, then restores args.reasoning / args.shot.
    """
    logs_tot = []
    shot = args.shot
    args.shot = 0
    args.reasoning = "CoT"
    for j in range(shot):
        args.n = j
        logs_tot.append(read_logs(args))
    args.reasoning = "ToT"
    args.shot = shot
    return logs_tot
def setup_mad_reasoning(args):
    """Setup for MAD (Multi-Agent Debate) reasoning, reusing the results of DiP.

    Loads the three independent DiP runs (n = 0, 1, 2) that seed the three
    debate agents, then restores args.reasoning.
    """
    reasoning = args.reasoning
    args.reasoning = "DiP"
    dip_logs = []
    for n in range(3):
        args.n = n
        dip_logs.append(read_logs(args))
    logs_DiP_0, logs_DiP_1, logs_DiP_2 = dip_logs
    assert len(logs_DiP_0) == len(logs_DiP_1) == len(logs_DiP_2), \
        "Logs of DiP for MAD reasoning must be of the same length."
    assert len(logs_DiP_0) > 0, "To use MAD reasoning, DiP logs must be available."
    args.reasoning = reasoning
    return logs_DiP_0, logs_DiP_1, logs_DiP_2


# MMLU answer indices map to option letters.
ANSWER_LETTERS = ["A", "B", "C", "D", "E"]


def prepare_batch(args, dataset_list, indices, logs_tot):
    """Normalize the dataset items at ``indices`` into a common schema.

    Fills the per-batch fields on ``args`` (questions / choiceses / subjects /
    records_tot) and returns the list of examples, each carrying a "question",
    a gold "key", and its dataset index as "num".  Dataset-specific field
    names are translated here so the generation loop is dataset-agnostic.
    The original duplicated this block verbatim in the vllm and API branches
    (with the API copy testing the never-matching lowercase '"tot" in
    args.reasoning' and skipping the GPQA problem->question rename).
    """
    examples = []
    args.questions = []
    args.choiceses = []
    args.subjects = []
    args.records_tot = []
    for i in indices:
        if args.reasoning == "ToT":
            args.records_tot.append([logs_tot[j][i]["record"] for j in range(args.shot)])
        example = copy.deepcopy(dataset_list[i])
        if args.dataset == "GSM8K":
            args.questions.append(example["question"])
            # Gold answer follows the "#### " marker in the solution text.
            example["key"] = example["answer"].split("#### ")[-1]
        elif args.dataset == "GPQA":
            args.questions.append(example["problem"])
            args.choiceses.append(example["choices"])
            args.subjects.append(example["subject"].lower())
            example["question"] = example.pop("problem")
            example["key"] = example["answer"]
        elif args.dataset == "GSM-Hard":
            args.questions.append(example["input"])
            example["question"] = example.pop("input")
            example["key"] = example.pop("target")
        elif args.dataset == "MATH":
            args.questions.append(example["problem"])
            example["question"] = example.pop("problem")
            example["key"] = example.pop("answer")
        elif "MMLU" in args.dataset:
            args.questions.append(example["question"])
            args.subjects.append(example["subject"])
            args.choiceses.append(example["choices"])
            example["key"] = ANSWER_LETTERS[example.pop("answer")]
        elif args.dataset == "AIME_2024":
            args.questions.append(example["Problem"])
            example["question"] = example.pop("Problem")
            example["key"] = example.pop("Answer")
        example["num"] = i
        examples.append(example)
    return examples


if __name__ == "__main__":
    # Model and dataset configurations
    MODEL_CONFIGS = {
        "Qwen-2.5": "Qwen/Qwen2.5-7B-Instruct",
        "Llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
        "GLM-4": "THUDM/glm-4-9b-chat",
        "Gemini": "gemini-1.5-flash",
        "GPT4-mini": "gpt-4o-mini",
        "Phi-3.5": "microsoft/Phi-3.5-mini-instruct",
    }

    DATASET_CONFIGS = {
        "math": ["GSM8K", "GSM-Hard", "MATH", "AIME_2024"],
        "science": [
            "MMLU-high_school_physics",
            "MMLU-high_school_chemistry",
            "MMLU-high_school_biology",
            "GPQA",
        ],
    }

    PROMPT_FORMATS = {
        "GSM8K": GSM8K.prompt_format,
        "GPQA": GPQA.prompt_format,
        "GSM-Hard": GSM_Hard.prompt_format,
        "MATH": MATH.prompt_format,
        "MMLU-high_school_physics": MMLU.prompt_format,
        "MMLU-high_school_chemistry": MMLU.prompt_format,
        "MMLU-high_school_biology": MMLU.prompt_format,
        "AIME_2024": AIME.prompt_format,
    }

    parser = argparse.ArgumentParser(description="Large Language Model Evaluation Framework")

    # Model configuration
    parser.add_argument("--model_name", type=str, default=MODEL_CONFIGS["Qwen-2.5"],
                        choices=list(MODEL_CONFIGS.values()))
    parser.add_argument("--model_type", type=str, default="vllm",
                        choices=["vllm", "gemini", "openai"])

    # Dataset configuration
    parser.add_argument("--dataset", type=str, default="GSM8K",
                        choices=[ds for group in DATASET_CONFIGS.values() for ds in group])
    parser.add_argument("--split", type=str, default="test")

    # Reasoning strategy configuration
    parser.add_argument("--reasoning", type=str, default="DiP",
                        choices=["DiP", "CoT", "L2M", "SBP", "AnP", "S-RF", "ToT", "MAD"],
                        help="Reasoning prompting strategy")
    parser.add_argument("--shot", type=int, default=0,
                        help='''DiP, SBP: Shot is fixed at 0, while SBP using 1-shot on MMLU.
                        CoT, L2M: Number of examples in few-shot prompting.
                        AnP: Number of analogous problems to generate.
                        S-RF: Shot is fixed at 0.
                        ToT: Number of reasoning paths to generate.
                        MAD: Shot is fixed at 0, using 3 agents for debate.
                        ''')

    # Processing configuration
    parser.add_argument("--max_num_workers", type=int, default=1,
                        help="Maximum number of workers for parallel processing for API-based models")
    parser.add_argument("--batchsize", type=int, default=10)
    parser.add_argument("--rounds", type=int, default=5)

    # Generation parameters
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.8)
    parser.add_argument("--max_new_tokens", type=int, default=4096)

    # Generation range configuration
    parser.add_argument("--range_begin", type=int, default=0)
    parser.add_argument("--range_end", type=int, default=16)

    # Hardware configuration
    parser.add_argument("--gpu", type=str, default="4,5", help="GPU IDs to use, e.g., '0,1'")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        help="Data type for VLLM")

    # API configuration
    parser.add_argument("--google_api_key", type=str)
    parser.add_argument("--openai_api_key", type=str)
    parser.add_argument("--openai_base_url", type=str)

    # Other options
    parser.add_argument("--verbal", action="store_true",
                        help="Enable verbose output")
    parser.add_argument("--seed", type=int, default=0)

    args = parser.parse_args()

    # Set up basic configurations
    args.range = range(args.range_begin, args.range_end)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Initialize prompt format based on dataset
    PROMPT_FORMAT = PROMPT_FORMATS.get(args.dataset)

    # Reset messaging configuration
    args.system = None
    args.messages = None
    args.query = None

    # MAD debates one question at a time.
    if args.reasoning == "MAD":
        args.batchsize = 1
        args.max_num_workers = 1

    # Load dataset and prepare model
    dataset_list = read_dataset(args)
    model_name = args.model_name
    args.model_name = model_name.split("/")[-1]

    # Per-strategy shot validation; logs_tot only exists for ToT.
    logs_tot = None
    if args.reasoning in ["DiP", "SBP"]:
        args.shot = 0
    elif args.reasoning in ["CoT"]:
        if args.dataset in ["GSM8K", "GSM-Hard", "MATH"]:
            assert args.shot in [0, 1, 5], "CoT reasoning requires 0, 1, or 5 shots for GSM8K, GSM-Hard, and MATH datasets"
        else:
            assert args.shot == 0, f"CoT reasoning requires 0-shot prompting for {args.dataset} dataset"
    elif args.reasoning == "AnP":
        assert args.shot in [1, 3, 5], "AnP reasoning requires 1, 3, or 5 shots, i.e., generating 1, 3, or 5 analogous problems"
    elif args.reasoning in ["S-RF", "MAD"]:
        args.shot = 0
        assert len(args.range) == 1, "Range from the beginning to the end must be 1 for S-RF and MAD reasoning"
    elif args.reasoning == "ToT":
        assert args.shot in [3, 5, 10], "ToT reasoning requires 3, 5, or 10 shots, i.e., generating 3, 5, or 10 reasoning paths"
        logs_tot = handle_tot_reasoning(args)

    args.num = len(args.range)

    # Handle MAD reasoning specific setup
    if args.reasoning == "MAD":
        logs_DiP_0, logs_DiP_1, logs_DiP_2 = setup_mad_reasoning(args)

    print(f"{'='*10}{args.dataset}===={args.reasoning}===={args.shot}{'='*32}")

    # Process logs and validate progress
    logs_all = []
    for j in args.range:
        args.n = j
        logs_all.append(read_logs(args))
    args.n = args.range[0]
    logs = logs_all[0]
    begin_num = len(logs)

    assert begin_num < len(dataset_list), "Processing already completed. Please check logs."

    # Initialize model (load_model expects the full repo id)
    args.model_name = model_name
    load_model(args)
    args.model_name = model_name.split("/")[-1]

    # MAD resume logic: finished logs with fewer rounds than requested are
    # continued from round 0 with the existing records replayed.
    if args.reasoning == "MAD":
        if len(logs) == len(dataset_list):
            rounds = len(logs[0]["record"])
            if rounds < args.rounds:
                begin_num = 0
                args.continue_ = True
            else:
                args.continue_ = False
        else:
            args.continue_ = False

    # Batch scheduling: vllm resumes sequentially from begin_num; API models
    # may have holes in their logs, so exactly the missing items are redone.
    if args.model_type == "vllm":
        batches = [range(ii, min(len(dataset_list), ii + args.batchsize))
                   for ii in range(begin_num, len(dataset_list), args.batchsize)]
    else:
        done = {log["num"] for log in logs}
        remain_data = [i for i in range(len(dataset_list)) if i not in done]
        batches = [remain_data[i:i + args.batchsize]
                   for i in range(0, len(remain_data), args.batchsize)]

    for range_ in tqdm(batches):
        examples = prepare_batch(args, dataset_list, range_, logs_tot)

        if args.reasoning == "MAD":
            idx = range_[0]  # MAD runs with batchsize 1
            args.previous_record = [logs_DiP_0[idx]["record"],
                                    logs_DiP_1[idx]["record"],
                                    logs_DiP_2[idx]["record"]]
            if args.continue_:
                args.records = logs[idx]["record"]

        records_all = get_model_outputs(args)

        for i in range(len(range_)):
            records = records_all[i]
            for j, k in enumerate(args.range):
                new_example = examples[i].copy()
                new_example["record"] = records[j]
                logs_all[j].append(new_example)

        # Persist every sampled run after each batch so interrupts lose little.
        for j, k in enumerate(args.range):
            args.n = k
            record_logs(logs_all[j], args)
def all_equal(lst):
    """Return True if every element equals the first (vacuously True for [])."""
    return all(x == lst[0] for x in lst)


def find_most_common_elements(input_list):
    """Return (elements sharing the highest count, that count).

    Raises ValueError on an empty list (``max`` over an empty Counter).
    """
    counter = Counter(input_list)
    max_count = max(counter.values())
    most_common_elements = [element for element, count in counter.items() if count == max_count]
    return most_common_elements, max_count


def get_most_common_answer(outputs):
    """Majority vote over ``outputs``, ignoring None; ties broken at random.

    Returns None when nothing but None was produced.
    """
    outputs = [output for output in outputs if output is not None]
    if not outputs:
        return None
    most_common_elements, _ = find_most_common_elements(outputs)
    return random.choice(most_common_elements)


def get_cost(model_name, prompt_tokens, completion_tokens):
    """Dollar cost of a request; per-model prices are $ per 1M tokens.

    Unknown models deliberately fall back to gpt-4o-mini pricing (the
    original spelled that fallback as a duplicated elif/else pair).
    """
    prompt_tokens = float(prompt_tokens)
    completion_tokens = float(completion_tokens)
    if "gemini" in model_name:
        cost = prompt_tokens * 0.075 + completion_tokens * 0.3
    elif model_name == "gpt-3.5-turbo-0613":
        cost = prompt_tokens * 1.5 + completion_tokens * 2
    else:
        cost = prompt_tokens * 0.15 + completion_tokens * 0.6
    return cost / 10 ** 6
"microsoft/Phi-3.5-mini-instruct"] 57 | 58 | datasets = ["GSM8K", 59 | "GPQA", 60 | "GSM-Hard", 61 | "MATH", 62 | "MMLU-high_school_physics", 63 | "MMLU-high_school_chemistry", 64 | "MMLU-high_school_biology", 65 | "AIME_2024"] 66 | 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument("--model_name", type=str, default=model_names[0]) 69 | parser.add_argument("--dataset", type=str, default=datasets[0]) 70 | parser.add_argument("--split", type=str, default="test") 71 | parser.add_argument("--shuffle", type=bool, default=True) 72 | args = parser.parse_args() 73 | 74 | DiP_max = 100 ## maximum sampling time of DiP 75 | CoT_max = 90 ## maximum sampling time of CoT 76 | L2M_max = 75 ## maximum sampling time of L2M 77 | SBP_max = 25 ## maximum sampling time of SBP 78 | AnP_max = 40 ## maximum sampling time of AnP 79 | S_RF_max = 10 ## maximum round of S-RF 80 | MAD_max = 10 ## maximum round of MAD 81 | 82 | sampling_times = [1, 3, 5, 7, 10, 12, 15] 83 | 84 | model_names_formal = { 85 | "Qwen2.5-7B-Instruct": "Qwen2.5-7B-Instruct", 86 | "Llama-3-8B-Instruct": "LLaMA-3-8B-Instruct", 87 | "glm-4-9b-chat": "GLM-4-9B-Chat", 88 | "gemini-1.5-flash": "Gemini-1.5-Flash", 89 | "gpt-4o-mini": "GPT-4o-mini", 90 | "Phi-3.5-mini-instruct": "Phi-3.5-mini-Instruct" 91 | } 92 | 93 | marker_dict = { 94 | 'DiP': '^', 95 | 'CoT': '^', 96 | 'L2M': '^', 97 | 'ToT': 'o', 98 | 'S-RF':'o', 99 | 'SBP': '^', 100 | 'AnP': '^', 101 | 'MAD': 'o' 102 | } 103 | 104 | 105 | if "/" in args.model_name: 106 | args.model_name = args.model_name.split('/')[-1] 107 | 108 | path = os.path.join(log_path, args.dataset, f"{args.dataset}_cost.xlsx") 109 | if os.path.exists(path): 110 | wb = load_workbook(path) 111 | else: 112 | wb = Workbook() 113 | 114 | unique_labels = [] 115 | 116 | headers = ["Method", "Subject", "x-shot", "num", "prompt_tokens", "completion_tokens", "cost", "tokens", "accuracy"] 117 | 118 | if args.model_name in wb.sheetnames: 119 | ws_new = wb[args.model_name] 120 | 
ws_new.delete_rows(1, ws_new.max_row) 121 | ws_new.delete_cols(1, ws_new.max_column) 122 | else: 123 | ws_new = wb.create_sheet(args.model_name) 124 | ws_new.append(headers) 125 | 126 | try: 127 | model_name = args.model_name 128 | sheet = wb.get_sheet_by_name(model_name) 129 | 130 | tokens = [] 131 | accuracy = [] 132 | cost = [] 133 | labels = [] 134 | flag = 0 135 | for row in sheet.iter_rows(values_only=True): 136 | if flag == 0: 137 | flag = 1 138 | continue 139 | # if row[0] != "tot-io": 140 | labels.append(row[0]) 141 | tokens.append(int(row[-2])) 142 | accuracy.append(float(row[-1])) 143 | cost.append(float(row[-3])) 144 | counter = Counter(labels) 145 | assert counter["DiP"] >= DiP_max 146 | assert counter["CoT"] >= CoT_max 147 | assert counter["L2M"] >= L2M_max 148 | assert counter["SBP"] >= SBP_max 149 | assert counter["AnP"] >= AnP_max 150 | # assert counter["ToT"] >= 3 151 | # assert counter["S-RF"] >= S_RF_max 152 | # assert counter["MAD"] >= MAD_max 153 | except: 154 | ws_new.delete_rows(1, ws_new.max_row) 155 | ws_new.delete_cols(1, ws_new.max_column) 156 | reasonings = ["DiP", "CoT", "L2M", "SBP", "AnP", "S-RF", "ToT_3", "ToT_5", "ToT_10", "MAD"] 157 | 158 | for reasoning in reasonings: 159 | try: 160 | if reasoning == "DiP": 161 | args.reasoning = "DiP" 162 | args.shot = 0 163 | for N in range(0, DiP_max+1): 164 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")): 165 | break 166 | 167 | elif reasoning == "CoT": 168 | args.reasoning = "CoT" 169 | args.shot = 0 170 | for N in range(0, CoT_max+1): 171 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")): 172 | break 173 | 174 | elif "L2M" in reasoning: 175 | args.reasoning = "L2M" 176 | if args.dataset in ["GSM8K", "GSM-Hard"]: 177 | args.shot = 1 178 | else: 179 | args.shot = 0 180 | for N in range(0, L2M_max+1): 181 | if not os.path.exists(os.path.join(log_path, 
args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")): 182 | break 183 | 184 | elif reasoning == "ToT_3": 185 | args.reasoning = "ToT" 186 | args.shot = 3 187 | N = 1 188 | elif reasoning == "ToT_5": 189 | args.reasoning = "ToT" 190 | args.shot = 5 191 | N = 1 192 | elif reasoning == "ToT_10": 193 | args.reasoning = "ToT" 194 | args.shot = 10 195 | N = 1 196 | 197 | elif reasoning == "S-RF": 198 | args.reasoning = reasoning 199 | args.shot = 0 200 | N = min((len(logs[-1]["record"]) - 1) // 2, S_RF_max) 201 | 202 | elif reasoning == "SBP": 203 | args.reasoning = "SBP" 204 | args.shot = 0 205 | for N in range(0, SBP_max+1): 206 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")): 207 | break 208 | 209 | elif reasoning == "AnP": 210 | args.reasoning = "AnP" 211 | args.shot = 1 212 | for N in range(0, AnP_max+1): 213 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")): 214 | break 215 | 216 | elif reasoning == "MAD": 217 | args.reasoning = "MAD" 218 | args.shot = 0 219 | args.n = 0 220 | logs = read_logs(args) 221 | N = min(len(logs[0]["record"]), MAD_max) 222 | 223 | logs_list = [] 224 | assert N != 0 225 | for m in range(0, N): 226 | if args.reasoning in ["DiP", "CoT", "L2M", "AnP", "SBP"]: 227 | args.nums = range(0,m+1) 228 | args.n = m 229 | logs = read_logs(args) 230 | logs_list.append(logs) 231 | if (m+1) not in sampling_times: 232 | continue 233 | 234 | elif args.reasoning in ["S-RF", "MAD"]: 235 | args.n = 0 236 | args.nums = range(0, m+1) 237 | logs = read_logs(args) 238 | logs_list.append(logs) 239 | 240 | elif "ToT" in args.reasoning: 241 | args.nums = range(0, m+1) 242 | for n in range(0, 5): 243 | try: 244 | args.n = n 245 | logs = read_logs(args) 246 | logs_list.append(logs) 247 | except: 248 | break 249 | 250 | accs = [] 251 | prompt_tokens_ = 0 252 | completion_tokens_ = 0 253 | l = 
len(logs_list[0]) 254 | 255 | for count in range(0, 5): 256 | random.shuffle(logs_list) 257 | acc_num = 0 258 | if args.dataset in ["GSM8K", "GSM-Hard", "MATH", "AIME_2024"]: 259 | subject = "mathematic" 260 | for j in range(0, l): 261 | if args.dataset in ["GPQA"] or "MMLU" in args.dataset: 262 | subject = logs_list[0][j]["subject"] 263 | key = logs_list[0][j]['key'] 264 | if args.reasoning in ["DiP", "CoT", "L2M", "AnP"]: 265 | output_keys = [parse_answer(args, log[j]['record']['output']) for log in logs_list] 266 | output_keys = [output for output in output_keys if output != None] 267 | if len(output_keys) == 0: 268 | output_keys = [None] 269 | 270 | if count == 0: 271 | prompt_tokens = [log[j]['record']['usage']['prompt_tokens'] for log in logs_list[:1]] 272 | completion_tokens = [log[j]['record']['usage']['completion_tokens'] for log in logs_list] 273 | 274 | elif args.reasoning == "S-RF": 275 | output_keys = [parse_answer(args, logs[j]["record"][f"output{m+1}"]["output"])] 276 | if count == 0: 277 | prompt_tokens = [logs[j]["record"]["output0"]['usage']["prompt_tokens"]] 278 | completion_tokens = [logs[j]["record"]["output0"]['usage']["completion_tokens"]] 279 | for k in range(0, m+1): 280 | prompt_tokens += [logs[j]["record"][f"problems{k+1}"]['usage']["prompt_tokens"], logs[j]["record"][f"output{k+1}"]['usage']["prompt_tokens"]] 281 | completion_tokens += [logs[j]["record"][f"problems{k+1}"]['usage']["completion_tokens"], logs[j]["record"][f"output{k+1}"]['usage']["completion_tokens"]] 282 | 283 | elif "tot" in args.reasoning: 284 | indexes = [] 285 | for logs in logs_list: 286 | index = parse_best_solution(logs[j]["record"]["choose"]["output"]) 287 | if index != None and "0" < index and index <= str(args.shot): 288 | index = int(index) - 1 289 | else: 290 | index = random.choice(range(0, args.shot)) 291 | indexes.append(index) 292 | solutions = logs_list[0][j]["record"]["solutions"] 293 | index = get_most_common_answer(indexes) 294 | best_solution = 
solutions[index] 295 | output_keys = [parse_answer(args, best_solution["output"])] 296 | if count == 0: 297 | prompt_tokens = [] 298 | completion_tokens = [] 299 | for k in range(0, len(solutions)): 300 | solution = solutions[k] 301 | if k == 0: 302 | prompt_tokens.append(solution["usage"]["prompt_tokens"]) 303 | completion_tokens.append(solution["usage"]["completion_tokens"]) 304 | for ii, logs in enumerate(logs_list): 305 | if ii == 0: 306 | prompt_tokens.append(logs[j]["record"]["choose"]["usage"]["prompt_tokens"]) 307 | completion_tokens.append(logs[j]["record"]["choose"]["usage"]["completion_tokens"]) 308 | 309 | elif args.reasoning == "SBP": 310 | output_keys = [parse_answer(args, log[j]['record']['solution']["output"]) for log in logs_list] 311 | output_keys = [output for output in output_keys if output != None] 312 | if len(output_keys) == 0: 313 | output_keys = [None] 314 | if count == 0: 315 | prompt_tokens = [] 316 | completion_tokens = [] 317 | for k in range(0, len(logs_list)): 318 | log = logs_list[k][j] 319 | if k == 0: 320 | prompt_tokens.append(log["record"]["principles"]["usage"]["prompt_tokens"]) 321 | completion_tokens.append(log["record"]["principles"]["usage"]["completion_tokens"]) 322 | prompt_tokens.append(log["record"]["solution"]["usage"]["prompt_tokens"]) 323 | completion_tokens.append(log["record"]["solution"]["usage"]["completion_tokens"]) 324 | 325 | elif args.reasoning == "MAD": 326 | output_keys = [parse_answer(args, log["output"]) for log in logs[j]["record"][f"round{m+1}"]] 327 | output_keys = [output for output in output_keys if output != None] 328 | if len(output_keys) == 0: 329 | output_keys = [None] 330 | if count == 0: 331 | prompt_tokens = [] 332 | completion_tokens = [] 333 | for k in range(0, m+1): 334 | prompt_tokens += [log["usage"]["prompt_tokens"] for log in logs[j]["record"][f"round{k+1}"]] 335 | completion_tokens += [log["usage"]["completion_tokens"] for log in logs[j]["record"][f"round{k+1}"]] 336 | 337 | if count == 
0: 338 | prompt_tokens_ += sum(prompt_tokens) 339 | completion_tokens_ += sum(completion_tokens) 340 | 341 | most_common_elements, max_count = find_most_common_elements(output_keys) 342 | output_key = random.choice(most_common_elements) 343 | 344 | if args.dataset in ["GSM8K"]: 345 | if output_key != None and abs(float(re.sub(r"[^0-9.-]", "", str(key))) - output_key) < 10**(-4): 346 | acc_num += 1 347 | elif args.dataset in ["GSM-Hard"]: 348 | if output_key != None and abs(float(key) - output_key) < 10**(-4): 349 | acc_num += 1 350 | elif args.dataset in ["MATH", "AIME_2024"]: 351 | if is_equiv(str(key), output_key): 352 | acc_num += 1 353 | elif args.dataset in ["GPQA"] or "MMLU" in args.dataset: 354 | if key == output_key: 355 | acc_num += 1 356 | acc = acc_num / l * 100 357 | accs.append(acc) 358 | 359 | acc = sum(accs)/len(accs) 360 | total_tokens = prompt_tokens_ + completion_tokens_ 361 | cost = get_cost(args.model_name, prompt_tokens_, completion_tokens_) 362 | 363 | s = f"{args.reasoning.ljust(15)} " + "{}-shot".format(args.shot).ljust(7) + f" {str(args.nums).ljust(12)} prompt_tokens: {str(prompt_tokens_).ljust(10)} completion_tokens: {str(completion_tokens_).ljust(10)} cost:" + str("%.12f"%(cost)).ljust(20) + f"tokens: {str(total_tokens).ljust(11)}" " Acc: " + "%.4f"%acc 364 | print(s) 365 | 366 | ws_new.append([args.reasoning, subject, str(args.shot), str(args.nums[-1]+1), str(prompt_tokens_), str(completion_tokens_), str("%.15f"%(cost)), str(total_tokens), str("%.4f"%acc)]) 367 | 368 | except Exception as e: 369 | print(args.reasoning, "error") 370 | 371 | wb.save(path) 372 | 373 | model_name = args.model_name 374 | sheet = wb.get_sheet_by_name(model_name) 375 | 376 | tokens = [] 377 | accuracy = [] 378 | cost = [] 379 | labels = [] 380 | flag = 0 381 | for row in sheet.iter_rows(values_only=True): 382 | if flag == 0: 383 | flag = 1 384 | continue 385 | labels.append(row[0]) 386 | tokens.append(int(row[-2])) 387 | accuracy.append(float(row[-1])) 388 | 
cost.append(float(row[-3])) 389 | 390 | wb.close() 391 | 392 | unique_labels = ["DiP", "CoT", "L2M", "ToT", "S-RF", "SBP", "AnP", "MAD"] 393 | colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels))) 394 | 395 | plt.figure(figsize=(6, 3)) 396 | ax = plt.gca() 397 | 398 | colors = np.array([[5.00000000e-01, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00], 399 | [2.17647059e-01, 4.29120609e-01, 9.75511968e-01, 1.00000000e+00], 400 | [17.25490196e-02, 8.82927610e-01, 1, 1.00000000e+00], 401 | [3.54901961e-01, 9.74138602e-01, 7.82927610e-01, 1.00000000e+00], 402 | [6.45098039e-01, 9.74138602e-01, 6.22112817e-01, 1.00000000e+00], 403 | [1, 7.82927610e-01, 5.34676422e-01, 1.00000000e+00], 404 | [1.00000000e+00, 4.29120609e-01, 2.19946358e-01, 1.00000000e+00], 405 | [1.00000000e+00, 1.22464680e-16, 6.12323400e-17, 1.00000000e+00],]) 406 | 407 | 408 | for i, label in enumerate(unique_labels): 409 | mask = [label == lb for lb in labels] 410 | plt.scatter([cost[idx] for idx, val in enumerate(mask) if val], 411 | [accuracy[idx] for idx, val in enumerate(mask) if val], 412 | color=colors[i], label=label, s=36, marker=marker_dict[label], edgecolor='black', linewidths=0.75, zorder=5) 413 | plt.plot([cost[idx] for idx, val in enumerate(mask) if val], 414 | [accuracy[idx] for idx, val in enumerate(mask) if val], 415 | color=colors[i], linewidth=1.5, marker=marker_dict[label]) 416 | 417 | max_cost = max(cost) // 0.1 * 0.1 418 | x_ticks = [max_cost / 10 * i for i in range(1, 11)] 419 | pos = - x_ticks[0] / 1 420 | 421 | y_max = max(accuracy) 422 | y_min = min(accuracy) 423 | ind_line = (y_max - y_min) / 7.5 424 | plt.ylim(bottom=None, top=y_max + ind_line * 1.5) 425 | 426 | def custom_formatter(y, pos): 427 | if y > max(accuracy) + 0.3: 428 | return "" 429 | else: 430 | return f"{int(y)}" 431 | 432 | for x in x_ticks: 433 | plt.axvline(x=x, color='gray', linestyle='--', alpha=0.5, zorder=0) 434 | 435 | mask = [cost[idx] <= x for idx in range(len(cost))] 436 | if any(mask): 437 
| best_idx = np.argmax([accuracy[idx] for idx, val in enumerate(mask) if val]) 438 | best_label = [labels[idx] for idx, val in enumerate(mask) if val][best_idx] 439 | 440 | plt.scatter(x, y_max + ind_line, color=colors[unique_labels.index(best_label)], 441 | marker=marker_dict[best_label], s=64, zorder=8, edgecolor='black', linewidths=1) 442 | 443 | plt.axhline(y=y_max + ind_line * 0.5, color='k', linestyle='-', alpha=0.5, zorder=10, linewidth=1.5) 444 | 445 | plt.rcParams['xtick.labelsize'] = 14 446 | plt.rcParams['ytick.labelsize'] = 14 447 | plt.rcParams['font.weight'] = 'bold' 448 | plt.rcParams['axes.labelweight'] = 'bold' 449 | plt.rcParams['axes.titleweight'] = 'bold' 450 | plt.title(f"{model_names_formal[args.model_name]}", fontsize=18) 451 | plt.xlabel('Cost ($)', fontsize=14, fontweight='bold') 452 | plt.ylabel('Accuracy', fontsize=14, fontweight='bold') 453 | 454 | plt.xticks(x_ticks) 455 | 456 | yticks = ax.get_yticks() 457 | yticks = yticks[yticks <= max(accuracy)] 458 | 459 | ax.set_yticks(yticks) 460 | ax.set_yticklabels([str(int(tick)) for tick in yticks]) 461 | ax.yaxis.set_major_locator(MaxNLocator(integer=True)) 462 | ax.yaxis.set_major_formatter(FuncFormatter(custom_formatter)) 463 | 464 | plt.text(pos, y_max + ind_line, r'$\mathbf{P}_{O}^*$', fontsize=12, ha='center', va='center') 465 | plt.rcParams['font.weight'] = 'bold' 466 | plt.rcParams['axes.labelweight'] = 'bold' 467 | plt.rcParams['axes.titleweight'] = 'bold' 468 | plt.legend(loc='best', fontsize=12, framealpha=0.9, ncol=2, markerscale=1.5, handletextpad=0.5, columnspacing=1.0) 469 | 470 | plt.gca().yaxis.set_major_locator(MultipleLocator(5)) 471 | plt.savefig(os.path.join(log_path, args.dataset, args.model_name, "pics", f"Performance_cost.png"), bbox_inches='tight', dpi = 600) 472 | plt.close() 473 | print(f"Performance_cost.png saved to {os.path.join(log_path, args.dataset, args.model_name, 'pics', 'Performance_cost.png')}") 474 | 475 | 
# -------------------------------- /eval_csv_N.py --------------------------------
from openpyxl import Workbook, load_workbook
import matplotlib.pyplot as plt
import numpy as np

from dataset import *
import argparse
from collections import Counter
import random
import os
import re

from matplotlib.ticker import MaxNLocator, MultipleLocator, FuncFormatter


def all_equal(lst):
    """Return True when every element of *lst* equals the first one.

    An empty list is vacuously True (the generator never evaluates lst[0]).
    """
    return all(item == lst[0] for item in lst)


def find_most_common_elements(input_list):
    """Return (elements_with_the_highest_count, that_count) for *input_list*.

    Raises ValueError on an empty list (max() of an empty sequence), exactly
    like the straightforward Counter-based formulation.
    """
    tally = Counter(input_list)
    max_count = max(tally.values())
    winners = [elem for elem, cnt in tally.items() if cnt == max_count]
    return winners, max_count


def get_most_common_answer(outputs):
    """Majority-vote over *outputs*, ignoring None entries.

    Ties are broken uniformly at random; returns None when nothing remains
    after filtering.
    """
    valid = [o for o in outputs if o is not None]
    if not valid:
        return None
    winners, _ = find_most_common_elements(valid)
    return random.choice(winners)


def get_cost(model_name, prompt_tokens, completion_tokens):
    """Dollar cost of a call given token counts (rates are USD per 1M tokens)."""
    prompt = float(prompt_tokens)
    completion = float(completion_tokens)
    if "gemini" in model_name:
        in_rate, out_rate = 0.075, 0.3
    elif model_name == "gpt-3.5-turbo-0613":
        in_rate, out_rate = 1.5, 2
    else:
        # gpt-4o-mini pricing, also used as the default for open-weight models.
        in_rate, out_rate = 0.15, 0.6
    cost = prompt * in_rate + completion * out_rate
    cost = cost / 10 ** 6
    return cost


if __name__ == "__main__":
    model_names = ["Qwen/Qwen2.5-7B-Instruct", 
                   "meta-llama/Meta-Llama-3-8B-Instruct",
                   "THUDM/glm-4-9b-chat",
                   "gemini-1.5-flash",
                   "gpt-4o-mini",
                   "microsoft/Phi-3.5-mini-instruct"]

    datasets = ["GSM8K", 
                "GPQA",
                "GSM-Hard",
                "MATH",
                "MMLU-high_school_physics",
                "MMLU-high_school_chemistry",
                "MMLU-high_school_biology",
                "AIME_2024"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default=model_names[0])
    parser.add_argument("--dataset", type=str, default=datasets[0])
    parser.add_argument("--split", type=str, default="test")
    # NOTE(review): argparse with type=bool treats any non-empty string as True,
    # so "--shuffle False" still enables shuffling — confirm intended usage.
    parser.add_argument("--shuffle", type=bool, default=True)
    args = parser.parse_args()

    # Sampling budgets that actually get scored/plotted on the x-axis.
    sampling_times = [1, 3, 5, 7, 10, 15]
    # x position of the P_N^* label, placed left of the first tick.
    pos = - sampling_times[-1] / 30 * sampling_times[-1]/15

    # Display names used in the plot title.
    model_names_formal = {
        "Qwen2.5-7B-Instruct": "Qwen2.5-7B-Instruct",
        "Llama-3-8B-Instruct": "LLaMA-3-8B-Instruct",
        "glm-4-9b-chat": "GLM-4-9B-Chat",
        "gemini-1.5-flash": "Gemini-1.5-Flash",
        "gpt-4o-mini": "GPT-4o-mini",
        "Phi-3.5-mini-instruct": "Phi-3.5-mini-Instruct"
    }

    # Maximum sampling time for each LLM, in order to avoid loading files that are recording the current runnning results.
    N_dict = {
        "Qwen2.5-7B-Instruct": 16,
        "Llama-3-8B-Instruct": 16,
        "glm-4-9b-chat": 16,
        "gemini-1.5-flash": 16,
        "gpt-4o-mini": 16,
        "Phi-3.5-mini-instruct": 16
    }

    # Marker per prompting strategy ('^' single-query methods, 'o' iterative ones).
    marker_dict = {
        'DiP': '^',
        'CoT': '^',
        'L2M': '^',
        'ToT': 'o',
        'S-RF':'o',
        'SBP': '^',
        'AnP': '^',
        'MAD': 'o'
    }


    # Keep only the model name after the org prefix (e.g. "Qwen/X" -> "X").
    if "/" in args.model_name:
        args.model_name = args.model_name.split('/')[-1]
    N_max = N_dict[args.model_name]

    # NOTE(review): os.mkdir assumes log_path/<dataset>/<model> already exists
    # (created by the main experiment run); os.makedirs would be more robust.
    if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, "pics")):
        os.mkdir(os.path.join(log_path, args.dataset, args.model_name, "pics"))

    # NOTE(review): duplicate of the prefix-strip above — a no-op here.
    if "/" in args.model_name:
        args.model_name = args.model_name.split('/')[-1]

    # One workbook per dataset, one sheet per model.
    path = os.path.join(log_path, args.dataset, f"{args.dataset}_N.xlsx")
    if os.path.exists(path):
        wb = load_workbook(path)
    else:
        wb = Workbook()

    unique_labels = []

    headers = ["Method", "Subject", "x-shot", "num", "prompt_tokens", "completion_tokens", "cost", "tokens", "accuracy"]

    if args.model_name in wb.sheetnames:
        ws_new = wb[args.model_name]
        # ws_new.delete_rows(1, ws_new.max_row)
        # ws_new.delete_cols(1, ws_new.max_column)
    else:
        ws_new = wb.create_sheet(args.model_name)
        ws_new.append(headers)

    reasonings = ["DiP", "CoT", "L2M", "SBP", "AnP", "S-RF", "ToT_3", "ToT_5", "ToT_10", "MAD"]

    # Fast path: if the sheet already holds at least N_max rows for each of the
    # sample-based methods, reuse it. Any failure (missing sheet, short counts)
    # falls through to the except-branch below, which recomputes from raw logs.
    try:
        model_name = args.model_name
        # NOTE(review): get_sheet_by_name is deprecated in openpyxl; wb[name]
        # is the modern equivalent (used above).
        sheet = wb.get_sheet_by_name(model_name)

        tokens = []
        accuracy = []
        cost = []
        labels = []
        flag = 0   # flag skips the header row of the sheet
        for row in sheet.iter_rows(values_only=True):
            if flag == 0:
                flag = 1
                continue
            # if row[0] != "tot-io":
            labels.append(row[0])
            tokens.append(int(row[-2]))
            accuracy.append(float(row[-1]))
            cost.append(float(row[-3]))
        counter = Counter(labels)
        assert counter["DiP"] >= N_max
        assert counter["CoT"] >= N_max
        assert counter["L2M"] >= N_max
        assert counter["SBP"] >= N_max
        assert counter["AnP"] >= N_max
        # assert counter["ToT"] >= 3
        # assert counter["S-RF"] > 3
        # assert counter["MAD"] > 3
    except:
        # Recompute: for each strategy, discover how many per-sample log files
        # exist (N), then evaluate majority-vote accuracy at growing budgets.
        for reasoning in reasonings:
            try:
                if reasoning == "DiP":
                    args.reasoning = "DiP"
                    args.shot = 0
                    # Probe log files until the first missing index; the loop
                    # variable N is deliberately read after the loop ends.
                    for N in range(0, N_max+1):
                        if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
                            break
                    # N = N - 1
                elif reasoning == "CoT":
                    args.reasoning = "CoT"
                    args.shot = 0
                    for N in range(0, N_max+1):
                        if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
                            break
                    # N = N - 1
                elif reasoning == "L2M":
                    args.reasoning = "L2M"
                    # L2M uses the 1-shot prompt only on the GSM-style datasets.
                    if args.dataset in ["GSM8K", "GSM-Hard"]:
                        args.shot = 1
                    else:
                        args.shot = 0
                    for N in range(0, N_max+1):
                        if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
                            break
                    # N = N - 1
                elif reasoning == "ToT_3":
                    args.reasoning = "ToT"
                    args.shot = 3
                    N = 1
                elif reasoning == "ToT_5":
                    args.reasoning = "ToT"
                    args.shot = 5
                    N = 1
                elif reasoning == "ToT_10":
                    args.reasoning = "ToT"
                    args.shot = 10
                    N = 1
                elif reasoning == "S-RF":
                    args.reasoning = reasoning
                    args.shot = 0
                    # NOTE(review): `logs` here is left over from a previous
                    # iteration's read_logs call — TODO confirm this is the
                    # intended source of the round count for S-RF.
                    N = (len(logs[-1]["record"]) - 1) // 2

                elif reasoning == "SBP":
                    args.reasoning = "SBP"
                    args.shot = 0
                    for N in range(0, N_max+1):
                        if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")):
                            break
                    # N = N - 1
                elif reasoning == "AnP":
                    args.reasoning = "AnP"
                    args.shot = 1
for N in range(0, N_max+1): 220 | if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{N}.json")): 221 | break 222 | elif reasoning == "MAD": 223 | args.reasoning = "MAD" 224 | args.shot = 0 225 | args.n = 0 226 | logs = read_logs(args) 227 | N = len(logs[0]["record"]) 228 | 229 | logs_list = [] 230 | assert N != 0 231 | for m in range(0, N): 232 | if args.reasoning in ["DiP", "CoT", "L2M", "AnP", "SBP"]: 233 | args.nums = range(0,m+1) 234 | args.n = m 235 | logs = read_logs(args) 236 | logs_list.append(logs) 237 | if (m+1) not in sampling_times: 238 | continue 239 | elif args.reasoning in ["S-RF", "MAD"]: 240 | args.n = 0 241 | args.nums = range(0, m+1) 242 | logs = read_logs(args) 243 | logs_list.append(logs) 244 | elif args.reasoning == "ToT": 245 | args.nums = range(0, m+1) 246 | for n in range(0, 5): 247 | try: 248 | args.n = n 249 | logs = read_logs(args) 250 | logs_list.append(logs) 251 | except: 252 | break 253 | 254 | accs = [] 255 | prompt_tokens_ = 0 256 | completion_tokens_ = 0 257 | l = len(logs_list[0]) 258 | for count in range(0, 5): 259 | if args.shuffle: 260 | random.shuffle(logs_list) 261 | acc_num = 0 262 | if args.dataset in ["GSM8K", "GSM-Hard", "MATH", "AIME_2024"]: 263 | subject = "mathematic" 264 | for j in range(0, l): 265 | if args.dataset in ["GPQA"] or "MMLU" in args.dataset: 266 | subject = logs_list[0][j]["subject"] 267 | key = logs_list[0][j]['key'] 268 | if args.reasoning in ["DiP", "CoT", "L2M", "AnP"]: 269 | output_keys = [parse_answer(args, log[j]['record']['output']) for log in logs_list] 270 | output_keys = [output for output in output_keys if output != None] 271 | if len(output_keys) == 0: 272 | output_keys = [None] 273 | 274 | if count == 0: 275 | prompt_tokens = [log[j]['record']['usage']['prompt_tokens'] for log in logs_list[:1]] 276 | completion_tokens = [log[j]['record']['usage']['completion_tokens'] for log in logs_list] 277 | 278 | elif args.reasoning == "S-RF": 279 
| output_keys = [parse_answer(args, logs[j]["record"][f"output{m+1}"]["output"])] 280 | if count == 0: 281 | prompt_tokens = [logs[j]["record"]["output0"]['usage']["prompt_tokens"]] 282 | completion_tokens = [logs[j]["record"]["output0"]['usage']["completion_tokens"]] 283 | for k in range(0, m+1): 284 | prompt_tokens += [logs[j]["record"][f"problems{k+1}"]['usage']["prompt_tokens"], logs[j]["record"][f"output{k+1}"]['usage']["prompt_tokens"]] 285 | completion_tokens += [logs[j]["record"][f"problems{k+1}"]['usage']["completion_tokens"], logs[j]["record"][f"output{k+1}"]['usage']["completion_tokens"]] 286 | 287 | elif args.reasoning == "ToT": 288 | indexes = [] 289 | for logs in logs_list: 290 | index = parse_best_solution(logs[j]["record"]["choose"]["output"]) 291 | if index != None and "0" < index and index <= str(args.shot): 292 | index = int(index) - 1 293 | else: 294 | index = random.choice(range(0, args.shot)) 295 | indexes.append(index) 296 | solutions = logs_list[0][j]["record"]["solutions"] 297 | index = get_most_common_answer(indexes) 298 | best_solution = solutions[index] 299 | output_keys = [parse_answer(args, best_solution["output"])] 300 | if count == 0: 301 | prompt_tokens = [] 302 | completion_tokens = [] 303 | for k in range(0, len(solutions)): 304 | solution = solutions[k] 305 | if k == 0: 306 | prompt_tokens.append(solution["usage"]["prompt_tokens"]) 307 | completion_tokens.append(solution["usage"]["completion_tokens"]) 308 | for ii, logs in enumerate(logs_list): 309 | if ii == 0: 310 | prompt_tokens.append(logs[j]["record"]["choose"]["usage"]["prompt_tokens"]) 311 | completion_tokens.append(logs[j]["record"]["choose"]["usage"]["completion_tokens"]) 312 | 313 | elif args.reasoning == "SBP": 314 | output_keys = [parse_answer(args, log[j]['record']['solution']["output"]) for log in logs_list] 315 | output_keys = [output for output in output_keys if output != None] 316 | if len(output_keys) == 0: 317 | output_keys = [None] 318 | if count == 0: 319 | 
prompt_tokens = [] 320 | completion_tokens = [] 321 | for k in range(0, len(logs_list)): 322 | log = logs_list[k][j] 323 | if k == 0: 324 | prompt_tokens.append(log["record"]["principles"]["usage"]["prompt_tokens"]) 325 | completion_tokens.append(log["record"]["principles"]["usage"]["completion_tokens"]) 326 | prompt_tokens.append(log["record"]["solution"]["usage"]["prompt_tokens"]) 327 | completion_tokens.append(log["record"]["solution"]["usage"]["completion_tokens"]) 328 | 329 | elif args.reasoning == "MAD": 330 | output_keys = [parse_answer(args, log["output"]) for log in logs[j]["record"][f"round{m+1}"]] 331 | output_keys = [output for output in output_keys if output != None] 332 | if len(output_keys) == 0: 333 | output_keys = [None] 334 | if count == 0: 335 | prompt_tokens = [] 336 | completion_tokens = [] 337 | for k in range(0, m+1): 338 | prompt_tokens += [log["usage"]["prompt_tokens"] for log in logs[j]["record"][f"round{k+1}"]] 339 | completion_tokens += [log["usage"]["completion_tokens"] for log in logs[j]["record"][f"round{k+1}"]] 340 | if count == 0: 341 | prompt_tokens_ += sum(prompt_tokens) 342 | completion_tokens_ += sum(completion_tokens) 343 | 344 | most_common_elements, max_count = find_most_common_elements(output_keys) 345 | output_key = random.choice(most_common_elements) 346 | 347 | if args.dataset in ["GSM8K"]: 348 | if output_key != None and abs(float(re.sub(r"[^0-9.-]", "", str(key))) - output_key) < 10**(-4): 349 | acc_num += 1 350 | elif args.dataset in ["GSM-Hard"]: 351 | if output_key != None and abs(float(key) - output_key) < 10**(-4): 352 | acc_num += 1 353 | elif args.dataset in ["MATH", "AIME_2024"]: 354 | if is_equiv(str(key), output_key): 355 | acc_num += 1 356 | elif args.dataset in ["GPQA"] or "MMLU" in args.dataset: 357 | if key == output_key: 358 | acc_num += 1 359 | acc = acc_num / l * 100 360 | accs.append(acc) 361 | 362 | acc = sum(accs)/len(accs) 363 | total_tokens = prompt_tokens_ + completion_tokens_ 364 | if 
args.reasoning in ["ToT"]: 365 | cost = args.shot 366 | else: 367 | cost = m + 1 368 | 369 | 370 | s = f"{args.reasoning.ljust(15)} " + "{}-shot".format(args.shot).ljust(7) + f" {str(args.nums).ljust(12)} prompt_tokens: {str(prompt_tokens_).ljust(10)} completion_tokens: {str(completion_tokens_).ljust(10)} cost:" + str("%.12f"%(cost)).ljust(20) + f"tokens: {str(total_tokens).ljust(11)}" " Acc: " + "%.4f"%acc 371 | print(s) 372 | 373 | ws_new.append([args.reasoning, subject, str(args.shot), str(args.nums[-1]+1), str(prompt_tokens_), str(completion_tokens_), str("%.15f"%(cost)), str(total_tokens), str("%.4f"%acc)]) 374 | 375 | except Exception as e: 376 | print(args.reasoning, "error") 377 | 378 | wb.save(path) 379 | 380 | labels_num = { 381 | "DiP": 0, 382 | "CoT": 0, 383 | "L2M": 0, 384 | "ToT": 0, 385 | "S-RF": 0, 386 | "AnP": 0, 387 | "SBP": 0, 388 | "MAD": 0, 389 | } 390 | 391 | model_name = args.model_name 392 | sheet = wb.get_sheet_by_name(model_name) 393 | 394 | tokens = [] 395 | accuracy = [] 396 | cost = [] 397 | labels = [] 398 | flag = 0 399 | for row in sheet.iter_rows(values_only=True): 400 | if flag == 0: 401 | flag = 1 402 | continue 403 | name = row[0] 404 | if labels_num[row[0]] >= N_max: 405 | continue 406 | labels.append(name) 407 | labels_num[row[0]] += 1 408 | tokens.append(int(row[-2])) 409 | accuracy.append(float(row[-1])) 410 | cost.append(float(row[-3])) 411 | 412 | unique_labels = ["DiP", "CoT", "L2M", "ToT", "S-RF", "SBP", "AnP", "MAD"] 413 | wb.close() 414 | 415 | plt.figure(figsize=(6, 3)) 416 | ax = plt.gca() 417 | 418 | colors = np.array([[5.00000000e-01, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00], 419 | [2.17647059e-01, 4.29120609e-01, 9.75511968e-01, 1.00000000e+00], 420 | [17.25490196e-02, 8.82927610e-01, 1, 1.00000000e+00], 421 | [3.54901961e-01, 9.74138602e-01, 7.82927610e-01, 1.00000000e+00], 422 | [6.45098039e-01, 9.74138602e-01, 6.22112817e-01, 1.00000000e+00], 423 | [1, 7.82927610e-01, 5.34676422e-01, 1.00000000e+00], 424 
| [1.00000000e+00, 4.29120609e-01, 2.19946358e-01, 1.00000000e+00], 425 | [1.00000000e+00, 1.22464680e-16, 6.12323400e-17, 1.00000000e+00],]) 426 | 427 | for i, label in enumerate(unique_labels): 428 | mask = [label == lb for lb in labels] 429 | plt.scatter([cost[idx] for idx, val in enumerate(mask) if val], 430 | [accuracy[idx] for idx, val in enumerate(mask) if val], 431 | color=colors[i], label=label, s=36, marker=marker_dict[label], edgecolor='black', linewidths=0.75, zorder=5) 432 | plt.plot([cost[idx] for idx, val in enumerate(mask) if val], 433 | [accuracy[idx] for idx, val in enumerate(mask) if val], 434 | color=colors[i], linewidth=1.5, marker=marker_dict[label]) 435 | 436 | x_ticks = sampling_times 437 | y_max = max(accuracy) 438 | y_min = min(accuracy) 439 | ind_line = (y_max - y_min) / 7.5 440 | plt.ylim(bottom=None, top=y_max + ind_line * 1.5) 441 | 442 | def custom_formatter(y, pos): 443 | if y > max(accuracy) + 0.9: 444 | return "" 445 | else: 446 | return f"{int(y)}" 447 | 448 | for x in x_ticks: 449 | plt.axvline(x=x, color='gray', linestyle='--', alpha=0.5, zorder=0) 450 | mask = [cost[idx] == x for idx in range(len(cost))] 451 | if any(mask): 452 | best_idx = np.argmax([accuracy[idx] for idx, val in enumerate(mask) if val]) 453 | best_label = [labels[idx] for idx, val in enumerate(mask) if val][best_idx] 454 | plt.scatter(x, y_max + ind_line, color=colors[unique_labels.index(best_label)], 455 | marker=marker_dict[best_label], s=64, zorder=8, edgecolor='black', linewidths=1) 456 | 457 | plt.axhline(y=y_max + ind_line / 2, color='k', linestyle='-', alpha=0.5, zorder=10, linewidth=1.5) 458 | plt.rcParams['xtick.labelsize'] = 14 459 | plt.rcParams['ytick.labelsize'] = 14 460 | plt.rcParams['font.weight'] = 'bold' 461 | plt.rcParams['axes.labelweight'] = 'bold' 462 | plt.rcParams['axes.titleweight'] = 'bold' 463 | plt.title(f"{model_names_formal[args.model_name]}", fontsize=18) 464 | plt.xlabel('Sampling Time', fontsize=14, fontweight='bold') 465 | 
    plt.ylabel('Accuracy', fontsize=14, fontweight='bold')

    plt.xticks(x_ticks)
    # Keep only the y ticks that fall within the data range.
    yticks = ax.get_yticks()
    yticks = yticks[yticks <= max(accuracy)]

    ax.set_yticks(yticks)
    ax.set_yticklabels([str(int(tick)) for tick in yticks])
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.yaxis.set_major_formatter(FuncFormatter(custom_formatter))

    # Label the top strip that marks the per-budget best strategy.
    plt.text(pos, y_max + ind_line, r'$\mathbf{P}_{N}^*$', fontsize=12, ha='center', va='center')
    plt.rcParams['font.weight'] = 'bold'
    plt.rcParams['axes.labelweight'] = 'bold'
    plt.rcParams['axes.titleweight'] = 'bold'

    plt.legend(loc='best', fontsize=12, framealpha=0.9, ncol=2, markerscale=1.5, handletextpad=0.5, columnspacing=1.0)

    plt.savefig(os.path.join(log_path, args.dataset, args.model_name, "pics", f"Performance_N.png"), bbox_inches='tight', dpi = 600)
    plt.close()
    print(f"Performance_N.png saved to {os.path.join(log_path, args.dataset, args.model_name, 'pics', 'Performance_N.png')}")
# -------------------------------- /dataset.py --------------------------------
from datasets import load_dataset
from prompts import GSM8K, GPQA, GSM_Hard, MATH, MMLU, AIME

import os
import json
import re
import random
from collections import Counter


# NOTE(review): placeholder path — must be set to the local checkout root.
base_path = f"xxx/xxx/.../rethinking_prompting"
log_path = os.path.join(base_path, "logs")


def get_cost(model_name, prompt_tokens, completion_tokens):
    # USD cost of one call; rates below are per 1M tokens (divided out at the end).
    prompt_tokens = float(prompt_tokens)
    completion_tokens = float(completion_tokens)
    if "gemini" in model_name:
        cost = prompt_tokens * 0.075 + completion_tokens * 0.3
    elif model_name == "gpt-3.5-turbo-0613":
        cost = prompt_tokens * 1.5 + completion_tokens * 2
    elif model_name == "gpt-4o-mini":
        cost = prompt_tokens * 0.15 + completion_tokens * \
            0.6
    else:
        # Default (open-weight models): billed at gpt-4o-mini rates.
        cost = prompt_tokens * 0.15 + completion_tokens * 0.6
    cost = cost / 10 ** 6
    return cost


def find_most_common_elements(outputs):
    """Return (elements_with_max_count, max_count) after dropping None entries.

    NOTE(review): returns a bare None (not a 2-tuple) when everything is
    filtered out — callers that unpack the result must guard for that.
    """
    outputs = [output for output in outputs if output != None]
    if outputs == []:
        return None
    counter = Counter(outputs)
    max_count = max(counter.values())
    most_common_elements = [element for element, count in counter.items() if count == max_count]
    return most_common_elements, max_count


def get_unique_most_common_answer(outputs):
    """Majority vote over *outputs* (None-filtered), random tie-breaking."""
    outputs = [output for output in outputs if output != None]
    if outputs == []:
        return None
    most_common_elements, max_count = find_most_common_elements(outputs)
    most_common_answer = random.choice(most_common_elements)
    return most_common_answer


def load_GPQA_examples(dataset_list, seed: int):
    """Build GPQA multiple-choice examples with choices shuffled under *seed*.

    The answer letter is derived from the shuffled position ('A'..'D').
    """
    random.seed(seed)

    def shuffle_choices_and_create_example(row):
        list_choices = [row["Incorrect Answer 1"], row["Incorrect Answer 2"], row["Incorrect Answer 3"], row["Correct Answer"]]
        random.shuffle(list_choices)

        example = {}
        example["problem"] = row["Question"]
        example["subject"] = row["High-level domain"]
        example["choices"] = list_choices
        # chr(index + 65) maps choice position 0..3 to letter 'A'..'D'.
        example["answer"] = chr(list_choices.index(row["Correct Answer"]) + 65)
        return example

    return [shuffle_choices_and_create_example(row) for row in dataset_list]


def read_dataset(args):
    """Load the benchmark named by args.dataset as a list of example dicts."""
    dataset = args.dataset
    if dataset == "GSM8K":
        ds = load_dataset("openai/gsm8k", "main")[args.split]
        dataset_list = [d for d in ds]
    elif dataset == "GSM-Hard":
        # GSM-Hard only ships a "train" split.
        ds = load_dataset("reasoning-machines/gsm-hard")
        dataset_list = [d for d in ds["train"]]
    elif dataset == "MATH":
        ds = load_dataset("HuggingFaceH4/MATH-500")[args.split]
        dataset_list = [d for d in ds]
    elif dataset == "GPQA":
        ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]
        dataset_list = load_GPQA_examples([d for d in ds], args.seed)
    elif "MMLU" in dataset:
        # Dataset name encodes the MMLU subject, e.g. "MMLU-high_school_physics".
        subject = dataset.split("-")[-1]
        dataset_list = load_dataset("cais/mmlu", subject)[args.split]
    elif dataset == "AIME_2024":
        os.environ["HF_DATASETS_OFFLINE"] = "1"
        # Imported lazily so modelscope is only required for AIME_2024.
        from modelscope.msdatasets import MsDataset
        ds = MsDataset.load("AI-ModelScope/AIME_2024", subset_name="default", split="train")
        dataset_list = [d for d in ds]
    # NOTE(review): an unrecognized dataset name falls through and raises
    # UnboundLocalError here — an explicit ValueError would be clearer.
    return dataset_list


def examine_output(dataset, output, key):
    """True when *output* matches gold *key* under the dataset's match rule.

    Numeric datasets use an absolute tolerance of 1e-4; MATH/AIME delegate to
    is_equiv; choice datasets compare letters exactly. Returns False for None
    outputs or unknown datasets.
    """
    if dataset in ["GSM8K"]:
        # GSM8K keys embed the number in text; strip everything non-numeric.
        if output != None and abs(float(re.sub(r"[^0-9.-]", "", str(key))) - output) < 10**(-4):
            return True
    elif dataset in ["GSM-Hard"]:
        if output != None and abs(float(key) - output) < 10**(-4):
            return True
    elif dataset in ["MATH", "AIME_2024"]:
        if is_equiv(str(key), output):
            return True
    elif dataset in ["GPQA", "CommonSenseQA"] or "MMLU" in dataset:
        if key == output:
            return True
    return False


def record_logs(logs, args):
    """Overwrite the per-run JSON log file for the current reasoning/shot/n."""
    if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name)):
        os.makedirs(os.path.join(log_path, args.dataset, args.model_name))
    path = os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{args.n}.json")
    with open(path, "w") as f:
        json.dump(logs, f, indent = 4)


def record_a_logs(logs, args):
    """Append a JSON dump to the per-run log file.

    NOTE(review): appending a second json.dump to the same file produces
    concatenated JSON documents that json.loads (used by read_logs) cannot
    parse — confirm this is only ever called once per file.
    """
    if not os.path.exists(os.path.join(log_path, args.dataset)):
        os.mkdir(os.path.join(log_path, args.dataset))
    if not os.path.exists(os.path.join(log_path, args.dataset, args.model_name)):
        os.mkdir(os.path.join(log_path, args.dataset, args.model_name))
    path = os.path.join(log_path, args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{args.n}.json")
    with open(path, "a") as f:
        json.dump(logs, f, indent = 4)


def read_logs(args):
    """Load the JSON log list for the current run; empty list when missing."""
    path = os.path.join(log_path,
                        args.dataset, args.model_name, f"{args.reasoning}_{args.shot}_{args.n}.json")
    if os.path.exists(path):
        with open (path, "r") as f:
            logs = json.loads(f.read())
    else:
        logs = []
    return logs


def construct_message(args, agents, question, idx):
    """Build the MAD (multi-agent debate) follow-up prompt for problem *idx*.

    Concatenates every other agent's answer for this problem, then appends the
    dataset-specific re-answer instruction. Raises ValueError for datasets not
    supported by MAD.
    """

    prefix_string = "These are the answers to the question from other agents: "

    for agent in agents:
        agent_response = agent[idx]
        response = "\n\n One agent answer: ```{}```".format(agent_response)

        prefix_string = prefix_string + response

    if args.dataset in ["GSM8K", "GSM-Hard"]:
        prefix_string = prefix_string + "\n\n Using the solutions from other agents as additional information, can you provide your answer to the math problem? \n The original math problem is {}.".format(question) + " Your final answer should be a single numerical number, in the form \\boxed{answer}, at the end of your response."
    elif args.dataset in ["GPQA"]:
        prefix_string = prefix_string + "\n\n Using the solutions from other agents as additional information, can you provide your answer to the problem? \n The original problem is {}. ".format(question) + GPQA.prompt_format
    elif "MMLU" in args.dataset:
        prefix_string = prefix_string + "\n\n Using the solutions from other agents as additional information, can you provide your answer to the problem? \n The original problem is {}. ".format(question) + MMLU.prompt_format
    elif args.dataset in ["MATH", "AIME_2024"]:
        prefix_string = prefix_string + "\n\n Using the solutions from other agents as additional information, can you provide your answer to the math problem? \n The original math problem is {}. ".format(question) + MATH.prompt_format
    else:
        raise ValueError(f"{args.dataset} is not included in MAD!")

    return prefix_string


def create_prompt(args, index = None):
    """Build prompts for args.questions (all of them, or just *index*).

    Dispatches on dataset and reasoning strategy to the prompt templates in
    prompts/*; ToT variants additionally splice in the previously sampled
    candidate solutions from args.records_tot.
    """
    prompts = []
    if index != None:
        range_ = [index]
    else:
        range_ = range(0, len(args.questions))

    for i in range_:
        args.question = args.questions[i]
        if args.dataset == "GSM8K":
            if args.reasoning == "DiP":
                prompt = GSM8K.io.replace("{question}", args.question)
            elif args.reasoning == "CoT":
                assert args.shot in [0, 1, 5]
                if args.shot == 0:
                    prompt = GSM8K.cot_0_shot.replace("{question}", args.question)
                elif args.shot == 1:
                    prompt = GSM8K.cot_1_shot.replace("{question}", args.question)
                elif args.shot == 5:
                    prompt = GSM8K.cot_5_shot.replace("{question}", args.question)
            elif args.reasoning == "L2M":
                if args.shot == 1:
                    prompt = GSM8K.Least_to_Most_1_shot.replace("{question}", args.question)
            elif args.reasoning == "ToT":
                if args.shot == 3:
                    prompt = GSM8K.tot_3_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"]) + GSM8K.tot_post
                elif args.shot == 5:
                    prompt = GSM8K.tot_5_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"]) + GSM8K.tot_post
                elif args.shot == 10:
                    prompt = GSM8K.tot_10_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"], solution6 =
args.records_tot[i][5]["output"], solution7 = args.records_tot[i][6]["output"], solution8 = args.records_tot[i][7]["output"], solution9 = args.records_tot[i][8]["output"], solution10 = args.records_tot[i][9]["output"]) + GSM8K.tot_post 189 | elif args.reasoning == "AnP": 190 | if args.shot == 1: 191 | prompt = GSM8K.anologous_1_prompt.replace("{question}", args.question) 192 | elif args.shot == 3: 193 | prompt = GSM8K.anologous_3_prompt.replace("{question}", args.question) 194 | elif args.shot == 5: 195 | prompt = GSM8K.anologous_5_prompt.replace("{question}", args.question) 196 | elif args.reasoning == "SBP": 197 | if "principles" not in args: 198 | prompt = GSM8K.SBP_extract.replace("{question}", args.question) 199 | else: 200 | prompt = GSM8K.SBP_answer.replace("{question}", args.question) 201 | prompt = prompt.replace("{principles}", args.principles) 202 | 203 | elif args.dataset == "GSM-Hard": 204 | if args.reasoning == "DiP": 205 | prompt = GSM_Hard.io.replace("{question}", args.question) 206 | elif args.reasoning == "CoT": 207 | assert args.shot in [0, 1, 5] 208 | if args.shot == 0: 209 | prompt = GSM_Hard.cot_0_shot.replace("{question}", args.question) 210 | elif args.shot == 1: 211 | prompt = GSM_Hard.cot_1_shot.replace("{question}", args.question) 212 | elif args.shot == 5: 213 | prompt = GSM_Hard.cot_5_shot.replace("{question}", args.question) 214 | elif args.reasoning == "L2M": 215 | if args.shot == 1: 216 | prompt = GSM_Hard.Least_to_Most_1_shot.replace("{question}", args.question) 217 | elif args.reasoning == "ToT": 218 | if args.shot == 3: 219 | prompt = GSM_Hard.tot_3_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"]) + GSM_Hard.tot_post 220 | elif args.shot == 5: 221 | prompt = GSM_Hard.tot_5_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], 
solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"]) + GSM_Hard.tot_post 222 | elif args.shot == 10: 223 | prompt = GSM_Hard.tot_10_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"], solution6 = args.records_tot[i][5]["output"], solution7 = args.records_tot[i][6]["output"], solution8 = args.records_tot[i][7]["output"], solution9 = args.records_tot[i][8]["output"], solution10 = args.records_tot[i][9]["output"]) + GSM_Hard.tot_post 224 | elif args.reasoning == "AnP": 225 | if args.shot == 1: 226 | prompt = GSM_Hard.anologous_1_prompt.replace("{question}", args.question) 227 | elif args.shot == 3: 228 | prompt = GSM_Hard.anologous_3_prompt.replace("{question}", args.question) 229 | elif args.shot == 5: 230 | prompt = GSM_Hard.anologous_5_prompt.replace("{question}", args.question) 231 | elif args.reasoning == "SBP": 232 | if "principles" not in args: 233 | prompt = GSM_Hard.SBP_extract.replace("{question}", args.question) 234 | else: 235 | prompt = GSM_Hard.SBP_answer.replace("{question}", args.question) 236 | prompt = prompt.replace("{principles}", args.principles) 237 | 238 | elif args.dataset == "GPQA": 239 | args.choices = args.choiceses[i] 240 | if args.reasoning == "DiP": 241 | prompt = GPQA.io.format(question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 242 | elif args.reasoning == "CoT": 243 | assert args.shot in [0, 1, 5] 244 | if args.shot == 0: 245 | prompt = GPQA.cot_0_shot.format(question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 246 | elif args.shot == 1: 247 | prompt = 
GPQA.cot_1_shot.format(question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 248 | elif args.shot == 5: 249 | prompt = GPQA.cot_5_shot.format(question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 250 | elif args.reasoning == "L2M": 251 | if args.shot == 0: 252 | prompt = GPQA.Least_to_Most_0_shot.format(question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 253 | elif args.reasoning == "ToT": 254 | if args.shot == 3: 255 | prompt = GPQA.tot_3_solutions.format(question=args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3], solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"]) + GPQA.tot_post 256 | elif args.shot == 5: 257 | prompt = GPQA.tot_5_solutions.format(question=args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3], solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"]) + GPQA.tot_post 258 | elif args.shot == 10: 259 | prompt = GPQA.tot_10_solutions.format(question=args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3], solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"], solution6 = args.records_tot[i][5]["output"], solution7 = args.records_tot[i][6]["output"], solution8 = 
args.records_tot[i][7]["output"], solution9 = args.records_tot[i][8]["output"], solution10 = args.records_tot[i][9]["output"]) + GPQA.tot_post 260 | elif args.reasoning == "AnP": 261 | if args.shot == 1: 262 | prompt = GPQA.anologous_1_prompt.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 263 | elif args.shot == 3: 264 | prompt = GPQA.anologous_3_prompt.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 265 | elif args.shot == 5: 266 | prompt = GPQA.anologous_5_prompt.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 267 | elif args.reasoning == "SBP": 268 | if "principles" not in args: 269 | prompt = GPQA.SBP_extract.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 270 | else: 271 | prompt = GPQA.SBP_answer.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3], principles = args.principles) 272 | 273 | elif "MMLU" in args.dataset: 274 | args.choices = args.choiceses[i] 275 | if args.reasoning == "DiP": 276 | prompt = MMLU.io.format(question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 277 | elif args.reasoning == "CoT": 278 | assert args.shot in [0, 1, 5] 279 | if args.shot == 0: 280 | prompt = MMLU.cot_0_shot.format(question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 281 | elif args.reasoning == "L2M": 282 | if args.shot 
== 0: 283 | prompt = MMLU.Least_to_Most_0_shot.format(question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 284 | elif args.reasoning == "ToT": 285 | if args.shot == 3: 286 | prompt = MMLU.tot_3_solutions.format(question=args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3], solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"]) + MMLU.tot_post 287 | elif args.shot == 5: 288 | prompt = MMLU.tot_5_solutions.format(question=args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3], solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"]) + MMLU.tot_post 289 | elif args.shot == 10: 290 | prompt = MMLU.tot_10_solutions.format(question=args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3], solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"], solution6 = args.records_tot[i][5]["output"], solution7 = args.records_tot[i][6]["output"], solution8 = args.records_tot[i][7]["output"], solution9 = args.records_tot[i][8]["output"], solution10 = args.records_tot[i][9]["output"]) + MMLU.tot_post 291 | elif args.reasoning == "AnP": 292 | if args.shot == 1: 293 | prompt = MMLU.anologous_1_prompt.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 294 | elif args.shot == 
3: 295 | prompt = MMLU.anologous_3_prompt.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 296 | elif args.shot == 5: 297 | prompt = MMLU.anologous_5_prompt.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 298 | elif args.reasoning == "SBP": 299 | if "principles" not in args: 300 | if "physics" in args.subjects[i]: 301 | prompt = MMLU.SBP_extract_physics.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 302 | elif "chemistry" in args.subjects[i]: 303 | prompt = MMLU.SBP_extract_chemistry.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 304 | else: 305 | prompt = MMLU.SBP_extract.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3]) 306 | else: 307 | if "physics" in args.subjects[i]: 308 | prompt = MMLU.SBP_answer_physics.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3], principles = args.principles) 309 | elif "chemistry" in args.subjects[i]: 310 | prompt = MMLU.SBP_answer_chemistry.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], choice4 = args.choices[3], principles = args.principles) 311 | else: 312 | prompt = MMLU.SBP_answer.format(subject = args.subjects[i], question = args.question, choice1 = args.choices[0], choice2 = args.choices[1], choice3 = args.choices[2], 
choice4 = args.choices[3], principles = args.principles) 313 | 314 | elif args.dataset == "MATH": 315 | if args.reasoning == "DiP": 316 | prompt = MATH.io.replace("{question}", args.question) 317 | elif args.reasoning == "CoT": 318 | assert args.shot in [0, 1, 5] 319 | if args.shot == 0: 320 | prompt = MATH.cot_0_shot.replace("{question}", args.question) 321 | elif args.shot == 1: 322 | prompt = MATH.cot_1_shot.replace("{question}", args.question) 323 | elif args.shot == 5: 324 | prompt = MATH.cot_5_shot.replace("{question}", args.question) 325 | elif args.reasoning == "L2M": 326 | if args.shot == 0: 327 | prompt = MATH.Least_to_Most_0_shot.replace("{question}", args.question) 328 | elif args.reasoning == "ToT": 329 | if args.shot == 3: 330 | prompt = MATH.tot_3_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"]) + MATH.tot_post 331 | elif args.shot == 5: 332 | prompt = MATH.tot_5_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"]) + MATH.tot_post 333 | elif args.shot == 10: 334 | prompt = MATH.tot_10_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"], solution6 = args.records_tot[i][5]["output"], solution7 = args.records_tot[i][6]["output"], solution8 = args.records_tot[i][7]["output"], solution9 = args.records_tot[i][8]["output"], solution10 = args.records_tot[i][9]["output"]) + MATH.tot_post 335 | elif args.reasoning == "AnP": 336 | if args.shot == 1: 337 | prompt = 
MATH.anologous_1_prompt.replace("{question}", args.question) 338 | elif args.shot == 3: 339 | prompt = MATH.anologous_3_prompt.replace("{question}", args.question) 340 | elif args.shot == 5: 341 | prompt = MATH.anologous_5_prompt.replace("{question}", args.question) 342 | elif args.reasoning == "SBP": 343 | if "principles" not in args: 344 | prompt = MATH.SBP_extract.replace("{question}", args.question) 345 | else: 346 | prompt = MATH.SBP_answer.replace("{question}", args.question) 347 | prompt = prompt.replace("{principles}", args.principles) 348 | 349 | elif args.dataset == "AIME_2024": 350 | if args.reasoning == "DiP": 351 | prompt = AIME.io.replace("{question}", args.question) 352 | elif args.reasoning == "CoT": 353 | assert args.shot in [0, 1, 5] 354 | if args.shot == 0: 355 | prompt = AIME.cot_0_shot.replace("{question}", args.question) 356 | elif args.reasoning == "L2M": 357 | if args.shot == 0: 358 | prompt = AIME.Least_to_Most_0_shot.replace("{question}", args.question) 359 | elif args.reasoning == "ToT": 360 | if args.shot == 3: 361 | prompt = AIME.tot_3_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"]) + AIME.tot_post 362 | elif args.shot == 5: 363 | prompt = AIME.tot_5_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"]) + AIME.tot_post 364 | elif args.shot == 10: 365 | prompt = AIME.tot_10_solutions.format(question=args.question, solution1 = args.records_tot[i][0]["output"], solution2 = args.records_tot[i][1]["output"], solution3 = args.records_tot[i][2]["output"], solution4 = args.records_tot[i][3]["output"], solution5 = args.records_tot[i][4]["output"], solution6 = args.records_tot[i][5]["output"], solution7 = 
def _strip_string(string):
    """Normalize a LaTeX answer string for equality comparison.

    Applies, in order: whitespace/linebreak cleanup, LaTeX spacing and sizing
    command removal, degree/dollar/percent stripping, unit removal, decimal
    normalization, sqrt/frac canonicalization.  The step order matters — e.g.
    units must be removed before spaces are stripped.
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage (both escaped and unescaped forms)
    string = string.replace("\\%", "")
    string = string.replace("\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # 6+9j -> 6+9i
    # NOTE(review): this replaces EVERY "j" in the string, not just the
    # imaginary unit — text answers containing "j" are altered too.
    string = string.replace("j", "i")

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    # (only when there is exactly one "=" and a short left-hand side)
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
| 423 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 424 | string = string.replace(" .", " 0.") 425 | string = string.replace("{.", "{0.") 426 | 427 | # 6+9j -> 6+9i 428 | string = string.replace("j", "i") 429 | # if empty, return empty string 430 | 431 | if len(string) == 0: 432 | return string 433 | if string[0] == ".": 434 | string = "0" + string 435 | 436 | # to consider: get rid of e.g. "k = " or "q = " at beginning 437 | if len(string.split("=")) == 2: 438 | if len(string.split("=")[0]) <= 2: 439 | string = string.split("=")[1] 440 | 441 | # fix sqrt3 --> sqrt{3} 442 | string = _fix_sqrt(string) 443 | 444 | # remove spaces 445 | string = string.replace(" ", "") 446 | 447 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 448 | string = _fix_fracs(string) 449 | 450 | # manually change 0.5 --> \frac{1}{2} 451 | if string == "0.5": 452 | string = "\\frac{1}{2}" 453 | 454 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 455 | string = _fix_a_slash_b(string) 456 | 457 | return string 458 | 459 | 460 | def replace_sqrt_with_power(s): 461 | pattern = r"\\sqrt\{([^}]+)\}" 462 | return re.sub(pattern, r"(\1)**0.5", s) 463 | 464 | 465 | def replace_pi(s): 466 | # result = re.sub(r"(? 
1: 487 | substrs = substrs[1:] 488 | for substr in substrs: 489 | new_str += "\\frac" 490 | if substr[0] == "{": 491 | new_str += substr 492 | else: 493 | try: 494 | assert len(substr) >= 2 495 | except: 496 | return string 497 | a = substr[0] 498 | b = substr[1] 499 | if b != "{": 500 | if len(substr) > 2: 501 | post_substr = substr[2:] 502 | new_str += "{" + a + "}{" + b + "}" + post_substr 503 | else: 504 | new_str += "{" + a + "}{" + b + "}" 505 | else: 506 | if len(substr) > 2: 507 | post_substr = substr[2:] 508 | new_str += "{" + a + "}" + b + post_substr 509 | else: 510 | new_str += "{" + a + "}" + b 511 | string = new_str 512 | return string 513 | 514 | 515 | def _fix_a_slash_b(string): 516 | if len(string.split("/")) != 2: 517 | return string 518 | a = string.split("/")[0] 519 | b = string.split("/")[1] 520 | try: 521 | a = int(a) 522 | b = int(b) 523 | assert string == "{}/{}".format(a, b) 524 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 525 | return new_string 526 | except: 527 | return string 528 | 529 | 530 | def _remove_right_units(string): 531 | # "\\text{ " only ever occurs (at least in the val set) when describing units 532 | if "\\text{ " in string: 533 | splits = string.split("\\text{ ") 534 | assert len(splits) == 2 535 | return splits[0] 536 | else: 537 | return string 538 | 539 | 540 | def _fix_sqrt(string): 541 | if "\\sqrt" not in string: 542 | return string 543 | splits = string.split("\\sqrt") 544 | new_string = splits[0] 545 | for split in splits[1:]: 546 | if split[0] != "{": 547 | a = split[0] 548 | new_substr = "\\sqrt{" + a + "}" + split[1:] 549 | else: 550 | new_substr = "\\sqrt" + split 551 | new_string += new_substr 552 | return new_string 553 | 554 | 555 | def is_equiv(str1, str2, verbose=False): 556 | if str1 is None and str2 is None: 557 | print("WARNING: Both None") 558 | return True 559 | if str1 is None or str2 is None: 560 | return False 561 | 562 | try: 563 | ss1 = _strip_string(str1) 564 | ss2 = 
def is_equiv(str1, str2, verbose=False):
    """Compare two answer strings after LaTeX normalization.

    Falls back to raw string equality if normalization raises.  Two None
    answers compare equal (with a warning); one None never matches.
    """
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = _strip_string(str1)
        ss2 = _strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
    except Exception:
        # Normalization is best-effort; narrowed from a bare except so
        # KeyboardInterrupt/SystemExit still propagate.
        return str1 == str2


def last_boxed_only(sample):
    """Given a (question, answer) pair, keep only the answer's last
    \\boxed{...} or \\fbox{...} element; return None if there is none."""
    q, a = sample
    a = last_boxed_only_string(a)
    if a is None:
        return None
    return (q, a)


def last_boxed_only_string(string):
    """Return the last "\\boxed{...}" (or, failing that, "\\fbox{...}")
    substring with balanced braces, or None if absent/unclosed."""
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
    if idx < 0:
        return None

    # Scan forward, tracking brace depth, to locate this box's closing brace.
    num_left_braces_open = 0
    right_brace_idx = None
    for i in range(idx, len(string)):
        if string[i] == "{":
            num_left_braces_open += 1
        elif string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break

    if right_brace_idx is None:
        return None
    return string[idx:right_brace_idx + 1]


def remove_boxed(s):
    """Strip the "\\boxed{...}" or "\\fbox{...}" wrapper, returning the inner
    text, or None when *s* is None or not a recognized box.

    Bug fix: the original only handled the "\\boxed{" prefix, so \\fbox
    answers returned by last_boxed_only_string were silently dropped as None.
    """
    if s is None:
        return None
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left) and s.endswith("}"):
            return s[len(left):-1]
    return None
def parse_answer(args, input_str):
    """Extract the final answer from a model response, per dataset family.

    Returns a float for GSM8K / GSM-Hard, a choice letter (expected "A"-"D",
    but see the "M" sentinel below) or None for GPQA / MMLU*, and the inner
    text of the last \\boxed element for MATH / AIME_2024.

    Raises:
        ValueError: for any dataset not handled here.
    """
    solution = None
    if args.dataset in ["GSM8K", "GSM-Hard"]:
        # Cascade of progressively looser patterns, each scanned from the END
        # of the response; the first non-empty numeric hit wins.
        # 1) Non-greedy \boxed{...} contents (recursing on nested "boxed").
        pattern = r"boxed\{(.*?)\}"
        matches = re.findall(pattern, input_str)

        for match_str in matches[::-1]:
            # Keep only the right-hand side of any "=", then digits/./-.
            match_str = match_str.split("=")[-1]
            if "boxed" not in match_str:
                solution = re.sub(r"[^0-9.-]", "", match_str)
            else:
                solution = parse_answer(args, match_str)
            if solution:
                break

        # 2) Greedy \boxed{...} (catches nested closing braces the non-greedy
        #    pattern truncated).
        if solution == None or solution == "":
            pattern = r"boxed\{(.*)\}"
            matches = re.findall(pattern, input_str)

            for match_str in matches[::-1]:
                if "boxed" not in match_str:
                    solution = re.sub(r"[^0-9.-]", "", match_str)
                else:
                    solution = parse_answer(args, match_str)
                if solution:
                    break

        # 3) Any {...} group containing only number-ish characters.
        if solution == None or solution == "":
            pattern = r"\{([0-9 \-.,$]*)\}"
            matches = re.findall(pattern, input_str)

            for match_str in matches[::-1]:
                solution = re.sub(r"[^0-9.-]", "", match_str)
                if solution:
                    break

        # 4) Bold-faced text, e.g. "**42**".
        if solution == None or solution == "":
            pattern = r"\*\*(.*)\*\*"
            matches = re.findall(pattern, input_str)

            for match_str in matches[::-1]:
                solution = re.sub(r"[^0-9.-]", "", match_str)
                if solution:
                    break

        # 5) Last resort: the last number-like token anywhere in the response.
        if solution == None or solution == "":
            matches = re.findall(r"[0-9\-.,$]+", input_str)
            for match_str in matches[::-1]:
                if re.findall(r"\d+", match_str) != []:
                    solution = re.sub(r"[^0-9.-]", "", match_str)
                    # Drop a trailing sentence-ending period.
                    if solution[-1] == ".":
                        solution = solution[:-1]
                    break
        try:
            solution = float(solution)
        except:
            # Anything non-numeric (or still None) is reported as no answer.
            solution = None
    elif args.dataset in ["GPQA"] or "MMLU" in args.dataset:
        # Multiple-choice: look for "correct answer is **X**" first.
        answers = re.findall(r"correct answer is \*\*(.*)\*\*", input_str)
        # NOTE(review): this inner check duplicates the enclosing elif
        # condition and is always true here; kept from the original structure.
        if args.dataset in ["GPQA"] or "MMLU" in args.dataset:
            letters = ["A", "B", "C", "D"]
            for answer in answers[::-1]:
                # NOTE(review): answer[0] raises IndexError on an empty match
                # (e.g. "correct answer is ****") — presumably never produced.
                if answer[0] not in letters:
                    try:
                        # Try a parenthesized letter inside the bold text, "(B)".
                        solution = re.search(r"\((.)\)", answer).group(1)[0]
                        if solution in letters:
                            return solution
                        else:
                            return None
                    except:
                        # "M" is a sentinel meaning "matched but unparseable";
                        # it can leak out of the final return below.
                        solution = "M"
                else:
                    solution = answer[0]
                    return solution

            # Fallback: "correct answer is X" without bold markers.
            answers = re.findall(r"correct answer is (.?)", input_str)
            for answer in answers[::-1]:
                if answer[0] not in letters:
                    try:
                        solution = re.search(r"correct answer is \((.?)\)", input_str).group(1)[0]
                        return solution
                    except:
                        answer = "M"
                else:
                    solution = answer[0]
                    return solution
            # Fallback: any parenthesized letter "(A)".
            answers = re.findall(r"\((.)\)", input_str)
            for answer in answers[::-1]:
                if answer[0] in letters:
                    solution = answer[0]
                    return solution
            # Fallback: any braced letter "{A}" (no early return here; the
            # LAST scanned hit, i.e. the first in the string, is kept).
            answers = re.findall(r"\{(.)\}", input_str)
            for answer in answers[::-1]:
                if answer[0] in letters:
                    solution = answer[0]
    elif args.dataset in ["MATH", "AIME_2024"]:
        # LaTeX datasets: the answer is the content of the last \boxed element.
        return remove_boxed(last_boxed_only_string(input_str))
    else:
        raise ValueError(f"{args.dataset} is not in def parse_answer!")
    return solution
def parse_best_solution(input_str):
    """Extract the index (as a string) of the solution the model judged best.

    Prefers the last explicit "index of the best solution is N" statement;
    otherwise falls back to the last bold-faced integer "**N**"; None if neither
    appears.
    """
    explicit = re.findall(r"index of the best solution is (\d+)", input_str)
    if explicit:
        return explicit[-1]
    for candidate in reversed(re.findall(r"\*\*(\d+)\*\*", input_str)):
        if candidate:
            return candidate
    return None


def parse_best_method(s):
    """Return the method name following "most suitable method is ", stripped,
    up to the first punctuation mark or newline; "" when the marker is absent."""
    found = re.search(r"most suitable method is ([^.,!?;:\n]*)", s)
    return found.group(1).strip() if found else ""


def check_solution_verdict(output_str):
    """Return "right" or "wrong" parsed from the first "solution is <verdict>";
    if the model gave no verdict, return a random "random_"-prefixed guess."""
    found = re.search(r"solution is (right|wrong)", output_str)
    if found:
        return found.group(1)
    return "random_" + random.choice(["right", "wrong"])