├── .gitignore ├── LICENSE ├── PromptCoT ├── README.md ├── calc_acc.py ├── concept_encoding.py ├── concept_sampling.py ├── configs │ ├── promptcot_ds_1_5b_config.json │ ├── promptcot_ds_7b_config.json │ └── promptcot_qwq_32b_config.json ├── data │ ├── aime2024.jsonl │ ├── aime2025.jsonl │ ├── gsm8k.jsonl │ ├── math500.jsonl │ ├── mathematics_concepts.jsonl │ └── qwq │ │ ├── qwq_aime2024_test.jsonl │ │ ├── qwq_aime2025_test.jsonl │ │ ├── qwq_gsm8k_test.jsonl │ │ └── qwq_math500_test.jsonl ├── data_synthesis.py ├── eval │ ├── math_equivalence.py │ ├── qwen_math.py │ └── util.py ├── infer_longcot.py ├── problem_filtering.py ├── problem_generation.py ├── rejection_sampling_reward.py ├── requirements.txt └── train.py ├── PromptCoT_Mamba ├── README.md ├── calc_acc_aime.py ├── calc_acc_lcb.py ├── configs │ └── promptcot_mamba_7b_config.json ├── data │ ├── aime_test.jsonl │ └── livecodebench_test.jsonl ├── infer_longcot.py ├── livecodebench_v5.py ├── livecodebench_v5_utils │ ├── compute_code_generation_metrics.py │ ├── pass_k_utils.py │ ├── process_data.py │ └── testing_util.py ├── math_opensource_utils │ ├── math_equivalence.py │ ├── qwen_math.py │ └── util.py ├── requirements.txt └── train.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Xueliang Zhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PromptCoT/README.md: -------------------------------------------------------------------------------- 1 | # **PromptCoT: Synthesizing Olympiad-Level Problems for Mathematical Reasoning in Large Language Models** 2 | 3 | --- 4 | 5 | ## **Highlights** 6 | ### **✨ The Missing Piece for Test-Time Scaling** 7 | A **lightweight yet powerful problem generation model** that enables the construction of **prompt sets at any scale with sufficient quality**—perfect for initializing your **post-training project**, whether it's **Supervised Fine-Tuning (SFT) or Reinforcement Learning (RL)**. Say goodbye to the limitations of open-source data! 
8 | 9 | ### **📖 A Fully Open Project** 10 | - **📂 Open-Source Problem Generation Model** 11 | - **Model**: [Hugging Face](https://huggingface.co/xl-zhao/PromptCoT-Problem-Generation-Model) | [ModelScope](https://www.modelscope.cn/models/zhaoxlpku/PromptCoT-Problem-Generation-Model) 12 | - **Training Data**: [Hugging Face](https://huggingface.co/datasets/xl-zhao/PromptCoT-Problem-Generation-Dataset) | [ModelScope](https://www.modelscope.cn/datasets/zhaoxlpku/PromptCoT-Problem-Generation-Dataset) 13 | - **🔹 Open-Source Distilled Models for Mathematical Reasoning** 14 | - **PromptCoT-DS-1.5B** (**Distilled from DeepSeek-R1-Distill-Qwen-7B, 1.5B parameters**) 15 | [Hugging Face](https://huggingface.co/xl-zhao/PromptCoT-DS-1.5B) | [ModelScope](https://www.modelscope.cn/models/zhaoxlpku/PromptCoT-DS-1.5B) 16 | - **PromptCoT-DS-7B** (**Distilled from DeepSeek-R1-Distill-Qwen-7B, 7B parameters**) 17 | [Hugging Face](https://huggingface.co/xl-zhao/PromptCoT-DS-7B) | [ModelScope](https://www.modelscope.cn/models/zhaoxlpku/PromptCoT-DS-7B) 18 | - **[New] 🚀🚀🚀** **PromptCoT-QwQ-32B** (**Distilled from QwQ-32B, 32B parameters**) 19 | [Hugging Face](https://huggingface.co/xl-zhao/PromptCoT-QwQ-32B) | [ModelScope](https://www.modelscope.cn/models/zhaoxlpku/PromptCoT-QwQ-32B) 20 | - **Training Data for Supervised Fine-Tuning (SFT) of PromptCoT-DS Series Models** 21 | [Hugging Face](https://huggingface.co/datasets/xl-zhao/PromptCoT-DS-Dataset) | [ModelScope](https://www.modelscope.cn/datasets/zhaoxlpku/PromptCoT-DS-Dataset) 22 | - **[New] 🚀🚀🚀** **Training Data for Supervised Fine-Tuning (SFT) of PromptCoT-QwQ-32B** 23 | [Hugging Face](https://huggingface.co/datasets/xl-zhao/PromptCoT-QwQ-Dataset) | [ModelScope](https://www.modelscope.cn/datasets/zhaoxlpku/PromptCoT-QwQ-Dataset) 24 | 25 | 26 | ### **🏆 Superior Performance** 27 | - **Consistent Improvements over Deepseek Counterparts** 28 | **PromptCoT-DS-7B** surpasses **DeepSeek-R1-Distill-Qwen-7B** across all major benchmarks, achieving consistent improvements in problem-solving accuracy. The results, averaged over **8 random seeds**, highlight the following gains: 29 | - **+0.9%** absolute improvement on **MATH-500** (**93.7%** vs. **92.8%**) 30 | - **+3.2%** absolute improvement on **AIME2024** (**58.7%** vs. **55.5%**) 31 | - **+9.2%** absolute improvement on **AIME2025** (**49.2%** vs. **40.0%**) 32 | 33 | - **Competitive with 32B Models** 34 | Despite having only 7B parameters, **PromptCoT-DS-7B** achieves results comparable to larger 32B models such as **S1-32B** and **LIMO-32B**. 
35 | 36 | **Performance Comparison of Different Models** 37 | 38 | | **Model** | **GSM8K** | **MATH-500** | **AIME2024** | **AIME2025** | 39 | |----------------------------------------------|------------------|---------------------|---------------------|---------------------| 40 | | **🔹 1.5B Models** | | | | | 41 | | **DeepSeek-R1-Distill-Qwen-1.5B** | - | 83.9% | 28.9% | 28.1% | 42 | | **STILL-3-1.5B-preview** | - | 85.5% | 39.3% | - | 43 | | **DeepScaleR-1.5B-Preview** | - | 🟢 **87.8%** | 🟢 **43.1%** | 🟢 **37.1%** | 44 | | **PromptCoT-DS-1.5B** (**ours**) | 🟢 **87.6% ± 0.5%** | **85.3% ± 1.1%** | **41.2% ± 6.9%** | **36.7% ± 6.2%** | 45 | | **🔹 7B Models** | | | | | 46 | | **DeepSeek-R1-Distill-Qwen-7B** | - | 92.8% | 55.5% | 40.0% | 47 | | **Qwen2.5-7B-SimpleRL** | - | 82.4% | 26.7% | - | 48 | | **OpenThinker-7B** | - | 89.6% | 30.0% | 33.3% | 49 | | **OpenR1-Qwen-7B** | - | 90.6% | 36.7% | 40.0% | 50 | | **PromptCoT-DS-7B** (**ours**) | 🔥 **92.8% ± 0.5%** | 🔥 **93.7% ± 0.7%** | 🔥 **58.7% ± 3.1%** | 🔥 **49.2% ± 7.9%** | 51 | | **🔹 32B Models** | | | | | 52 | | **DeepSeek-R1-Distill-Qwen-32B** | - | 94.3% | 72.6% | - | 53 | | **S1-32B** | - | 93.0% | 56.7% | 26.6% | 54 | | **LIMO-32B** | - | 94.8% | 57.1% | 46.6% | 55 | | **QwQ-32B** | - | - | 82.1% | 70.8% | 56 | | **PromptCoT-QwQ-32B** (**ours**) | 🔥🔥 **96.4% ± 0.2%** | 🔥🔥 **96.7% ± 0.5%** | 🔥🔥 **83.8% ± 2.8%** | 🔥🔥 **75.4% ± 4.7%** | 57 | 58 | - **Challenging RL-Based Methods Without RL** 59 | Despite relying purely on distillation, **PromptCoT-DS-1.5B** achieves competitive results against RL-based models like **STILL-3-1.5B-preview** and **DeepScaleR-1.5B-Preview**, highlighting the strength of our problem generation pipeline. 60 | 61 | ### **⚡ Efficiency Without Compromise** 62 | Compared to **DeepScaleR-1.5B-Preview**, **PromptCoT-DS-1.5B** achieves **40+% AIME scores** while using **over 15× fewer A100 GPU hours** (240 A100 hours vs. 3,800 A100 hours). This makes **PromptCoT-DS-1.5B** a highly efficient and cost-effective solution for mathematical reasoning. 63 | 64 | 65 | --- 66 | 67 | ## **Overview** 68 | Large language models (LLMs) have demonstrated remarkable advancements in mathematical reasoning. However, acquiring **challenging and high-quality Olympiad-level problems** at scale remains a significant challenge. Existing datasets often lack the necessary complexity to further enhance the capabilities of state-of-the-art models. 69 | 70 | **PromptCoT** introduces a method to systematically generate high-quality Olympiad-level math problems by modeling the **rationale behind expert problem design**. This approach improves problem diversity and difficulty while ensuring **logically consistent problem construction**. 71 | 72 | 📄 **Paper**: [🔗 PromptCoT: Synthesizing Olympiad-Level Problems for Mathematical Reasoning in Large Language Models](http://arxiv.org/abs/2503.02324). 73 | 74 | ### **Key Features** 75 | - **Concept-Guided Problem Synthesis**: PromptCoT generates problems by systematically combining **mathematical concepts**, allowing for a **scalable** and **flexible** way to create a diverse range of challenging problems. 76 | 77 | - **Rationale-Driven Problem Formulation**: Instead of directly generating problems, PromptCoT first constructs an **intermediate reasoning process (rationale)**—a step-by-step thought process that mimics how expert problem designers craft questions. 
This rationale helps bridge the gap between abstract mathematical concepts and well-formed problems, ensuring logical consistency and problem difficulty.
78 |
79 | - **Rejection Sampling for Quality Control**: Problems undergo an automated evaluation process where multiple reward models assess their quality. Only problems receiving the highest scores are retained, ensuring the final dataset consists of **challenging and high-quality** mathematical problems.
80 |
81 | - **Scalability & Adaptability**: The method allows for **large-scale problem generation** across a wide range of mathematical domains. Additionally, the rationale-driven approach can be adapted to **other structured reasoning tasks** beyond mathematics.
82 |
83 | ---
84 |
85 | ## **Quick Start: Generating Olympiad-Level Problems**
86 | Follow these steps to generate problems using **PromptCoT**.
87 |
88 | ### **1. Install Dependencies**
89 | ```bash
90 | pip install sentence_transformers==3.2.1 scikit-learn==1.3.2 scipy==1.10.1 faiss-gpu==1.7.2 vllm==0.6.3 transformers==4.46.3 fire==0.7.0
91 | pip install str2bool
92 | ```
93 |
94 | ### **2. Generating Problems**
95 | #### **Step 1: Generate Concept Embeddings**
96 | We first encode mathematical concepts into embeddings to enable efficient sampling:
97 |
98 | ```bash
99 | python concept_encoding.py \
100 |     --data_path data/mathematics_concepts.jsonl \
101 |     --output_path data/embeddings.jsonl \
102 |     --model_path /path/to/Llama-3.1-8B \
103 |     --n_gpus 4
104 | ```
105 |
106 | #### **Step 2: Sample Concept Combinations**
107 | We then sample meaningful concept combinations for problem generation:
108 |
109 | ```bash
110 | python concept_sampling.py \
111 |     --data_path data/mathematics_concepts.jsonl \
112 |     --output_path data/problem_generation_inputs.jsonl \
113 |     --data_size 1000 \
114 |     --embed_path data/embeddings.jsonl
115 | ```
116 |
117 | #### **Step 3: Generate Math Problems**
118 | Using the **pre-trained problem generation model** – available on [**Hugging Face**](https://huggingface.co/xl-zhao/PromptCoT-Problem-Generation-Model) | [**ModelScope**](https://www.modelscope.cn/models/zhaoxlpku/PromptCoT-Problem-Generation-Model) – we generate Olympiad-level math problems:
119 |
120 | ```bash
121 | python problem_generation.py \
122 |     --data_path data/problem_generation_inputs.jsonl \
123 |     --output_path data/problem_generation_outputs.jsonl \
124 |     --model_path /path/to/problem_generation_model \
125 |     --n_gpus 4 \
126 |     --temperature 0.6 \
127 |     --max_len 4096 \
128 |     --seed 8000
129 | ```
130 |
131 | #### **Step 4: Reward-Based Filtering**
132 | To ensure high-quality problem selection, we compute **reward scores** using two evaluation models:
133 |
134 | ```bash
135 | python rejection_sampling_reward.py \
136 |     --data_path data/problem_generation_outputs.jsonl \
137 |     --output_path data/problem_generation_outputs_reward0.jsonl \
138 |     --model_path /path/to/Llama-3.1-70B-Instruct \
139 |     --n_gpus 4 \
140 |     --temperature 0.6 \
141 |     --use_chat_template True \
142 |     --seed 8000
143 |
144 | python rejection_sampling_reward.py \
145 |     --data_path data/problem_generation_outputs.jsonl \
146 |     --output_path data/problem_generation_outputs_reward1.jsonl \
147 |     --model_path /path/to/Qwen2.5-72B-Instruct \
148 |     --n_gpus 4 \
149 |     --temperature 0.6 \
150 |     --use_chat_template True \
151 |     --seed 8000
152 | ```
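Conceptually, the next step keeps a problem only when every reward model gives it a perfect rating. The snippet below is a sketch of that selection rule, not the actual `problem_filtering.py`: it assumes each scored record carries the generated `problem` text and a numeric `score` field, that the two reward files are aligned line by line, and that ratings top out at `MAX_SCORE`; all of these names are assumptions.

```python
import json

MAX_SCORE = 5  # assumed top rating; the real scale is set by the reward prompts

def load_jsonl(path):
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f]

# One scored file per reward model, as written by the two commands above.
runs = [load_jsonl(f"data/problem_generation_outputs_reward{i}.jsonl") for i in range(2)]

# Keep a problem only when every reward model assigns its top score;
# zip(*runs) assumes both files score the same problems in the same order.
kept = [
    items[0]["problem"]
    for items in zip(*runs)
    if all(item["score"] == MAX_SCORE for item in items)
]
print(f"retained {len(kept)} problems")
```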
153 |
154 | #### **Step 5: Select High-Quality Problems**
155 | To ensure only the **highest-quality** problems are used for training, we apply a filtering process based on reward scores. Problems that receive **perfect ratings from multiple evaluators** are retained.
156 |
157 | ```bash
158 | python problem_filtering.py \
159 |     --template data/problem_generation_outputs_reward{}.jsonl \
160 |     --output_path data/problem_generation_training.jsonl \
161 |     --only_perfect True \
162 |     --n_rewards 2
163 | ```
164 |
165 | 📌 **Our curated dataset of high-quality problems** (where each problem received **perfect ratings** across all evaluation criteria) is available here: [Hugging Face](https://huggingface.co/datasets/xl-zhao/PromptCoT-Problem-Generation-Dataset) | [ModelScope](https://www.modelscope.cn/datasets/zhaoxlpku/PromptCoT-Problem-Generation-Dataset)
166 |
167 | ---
168 |
169 | ## **Distillation**
170 | After generating high-quality problems, we distill the knowledge into **smaller models** using **DeepSeek-R1-Distill-Qwen-7B** as the teacher model. We train:
171 | - **PromptCoT-DS-1.5B** (Student: DeepSeek-R1-Distill-Qwen-1.5B)
172 | - **PromptCoT-DS-7B** (Student: DeepSeek-R1-Distill-Qwen-7B)
173 |
174 |
175 | ---
176 |
177 | ## **Reproducing Our Results**
178 | To reproduce the results, follow these steps.
179 |
180 | #### **Step 1: Install Dependencies**
181 | ```bash
182 | conda create -n promptcot python=3.10.14
183 | conda activate promptcot
184 | pip install -r requirements.txt --ignore-installed --no-deps
185 | ```
186 |
187 | #### **Step 2: Run Inference on Benchmark Datasets**
188 | To run inference for the PromptCoT-DS series models, use the following command:
189 |
190 | ```bash
191 | python infer_longcot.py \
192 |     --data_path data/{dataset_name}.jsonl \
193 |     --output_path data/{dataset_name}_predictions.jsonl \
194 |     --model_path /path/to/{model_name} \
195 |     --tokenizer_path /path/to/DeepSeek-R1-Distill-Qwen-1.5B \
196 |     --n_gpus 1 \
197 |     --temperature 0.6 \
198 |     --max_len 32768 \
199 |     --n 8
200 | ```
201 | where `{dataset_name}` can be:
202 | - `gsm8k`
203 | - `math500`
204 | - `aime2024`
205 | - `aime2025`
206 |
207 | and `{model_name}` can be:
208 | - `PromptCoT-DS-1.5B`
209 | - `PromptCoT-DS-7B`
210 |
211 | To run inference for PromptCoT-QwQ-32B, use the following command:
212 |
213 | ```bash
214 | python infer_longcot.py \
215 |     --data_path data/qwq/qwq_{dataset_name}_test.jsonl \
216 |     --output_path data/qwq/qwq_{dataset_name}_predictions.jsonl \
217 |     --model_path /path/to/PromptCoT-QwQ-32B \
218 |     --tokenizer_path /path/to/QwQ-32B \
219 |     --n_gpus 2 \
220 |     --temperature 0.6 \
221 |     --max_len 16384 \
222 |     --n 8
223 | ```
224 | where `{dataset_name}` can be:
225 | - `gsm8k`
226 | - `math500`
227 | - `aime2024`
228 | - `aime2025`
229 |
230 | #### **Step 3: Compute Accuracy**
231 | ```bash
232 | python calc_acc.py \
233 |     --output_path data/{dataset_name}_predictions.jsonl
234 | ```
235 |
236 | #### **[New]** **Step 4: Train with DeepSpeed**
237 |
238 | You can reproduce the training process for the models using DeepSpeed with the following commands. Make sure to replace the paths with your own data and model paths.
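Each training run below pairs a model with one of the DeepSpeed configs under `configs/`. As a quick sanity check, you can print the ZeRO stage and peak learning rate each config requests (stdlib only; run from the `PromptCoT/` directory). The per-model commands follow.

```python
import json

# The three DeepSpeed configs used by the commands below.
for name in ("promptcot_ds_1_5b_config.json",
             "promptcot_ds_7b_config.json",
             "promptcot_qwq_32b_config.json"):
    with open(f"configs/{name}", encoding="utf-8") as f:
        cfg = json.load(f)
    print(name,
          "| ZeRO stage:", cfg["zero_optimization"]["stage"],
          "| warmup_max_lr:", cfg["scheduler"]["params"]["warmup_max_lr"])
```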
239 | 240 | - **For PromptCoT-DS-1.5B**: 241 | 242 | ```bash 243 | deepspeed --num_gpus=8 train.py --bf16=True --data_path=/path/to/PromptCoT-DS-Dataset --ddp_find_unused_parameters=False --deepspeed=configs/promptcot_ds_1_5b_config.json --evaluation_strategy=no --fp16=False --gradient_accumulation_steps=8 --gradient_checkpointing=True --learning_rate=5e-06 --load_best_model_at_end=False --logging_steps=1 --model_max_length=16384 --model_name_or_path=/path/to/DeepSeek-R1-Distill-Qwen-1.5B --num_train_epochs=2 --output_dir=/path/to/PromptCoT-DS-1.5B --per_device_train_batch_size=1 --resume_from_checkpoint=False --save_steps=500 --save_strategy=steps --save_total_limit=6 --tokenizer_path=/path/to/DeepSeek-R1-Distill-Qwen-1.5B --warmup_steps=100 --weight_decay=0.01 244 | ``` 245 | 246 | - **For PromptCoT-DS-7B**: 247 | 248 | ```bash 249 | deepspeed --num_gpus=8 train.py --bf16=True --data_path=/path/to/PromptCoT-DS-Dataset --ddp_find_unused_parameters=False --deepspeed=configs/promptcot_ds_7b_config.json --evaluation_strategy=no --fp16=False --gradient_accumulation_steps=8 --gradient_checkpointing=True --learning_rate=5e-06 --load_best_model_at_end=False --logging_steps=1 --model_max_length=16384 --model_name_or_path=/path/to/DeepSeek-R1-Distill-Qwen-7B --num_train_epochs=2 --output_dir=/path/to/PromptCoT-DS-7B --per_device_train_batch_size=1 --resume_from_checkpoint=False --save_steps=500 --save_strategy=steps --save_total_limit=6 --tokenizer_path=/path/to/DeepSeek-R1-Distill-Qwen-7B --warmup_steps=100 --weight_decay=0.01 250 | ``` 251 | 252 | - **For PromptCoT-QwQ-32B**: 253 | 254 | ```bash 255 | deepspeed --num_gpus=8 train.py --bf16=True --data_path=/path/to/PromptCoT-QwQ-Dataset --ddp_find_unused_parameters=False --deepspeed=configs/promptcot_qwq_32b_config.json --evaluation_strategy=no --fp16=False --gradient_accumulation_steps=2 --gradient_checkpointing=True --learning_rate=2e-06 --load_best_model_at_end=False --logging_steps=1 --model_max_length=16384 --model_name_or_path=/path/to/QwQ-32B --num_train_epochs=2 --output_dir=/path/to/PromptCoT-QwQ-32B --per_device_train_batch_size=1 --resume_from_checkpoint=False --save_steps=500 --save_strategy=steps --save_total_limit=6 --tokenizer_path=/path/to/QwQ-32B --warmup_steps=100 --weight_decay=0.01 256 | ``` 257 | 258 | 259 | --- 260 | 261 | ## **Citation** 262 | If you find **PromptCoT** useful, please consider citing: 263 | 264 | ``` 265 | @article{zhao2025promptcot, 266 | author = {Zhao, Xueliang and Wu, Wei and Guan, Jian and Kong, Lingpeng}, 267 | title = {PromptCoT: Synthesizing Olympiad-Level Problems for Mathematical Reasoning in Large Language Models}, 268 | year = {2025}, 269 | journal = {arXiv preprint arXiv:2503.02324}, 270 | url = {http://arxiv.org/abs/2503.02324} 271 | } 272 | ``` 273 | -------------------------------------------------------------------------------- /PromptCoT/calc_acc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from tqdm import tqdm 4 | import numpy as np 5 | import re 6 | from collections import defaultdict 7 | 8 | from eval.math_equivalence import is_equiv_minerva as is_equiv 9 | from eval.util import last_boxed_only_string, first_boxed_only_string, remove_boxed 10 | from eval.qwen_math import math_equal, extract_answer, strip_string 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser(description="Evaluate large language models") 15 | parser.add_argument("--output_path", type=str, required=True, help="Directory to 
store cached outputs.") 16 | 17 | args = parser.parse_args() 18 | 19 | completions_by_prompt = defaultdict(lambda: defaultdict(list)) 20 | 21 | all_items = [] 22 | with open(args.output_path, encoding="utf-8") as f: 23 | for line in f: 24 | item = json.loads(line) 25 | all_items.append(item) 26 | 27 | if "prompt" not in item: 28 | raise ValueError(f"Item is missing required 'prompt' field") 29 | 30 | source = item["source"] 31 | prompt = item["prompt"] 32 | completions_by_prompt[source][prompt].append(item) 33 | 34 | results = defaultdict(lambda: defaultdict(dict)) 35 | 36 | prompts_by_source = defaultdict(set) 37 | 38 | for source in completions_by_prompt: 39 | for prompt in tqdm(completions_by_prompt[source]): 40 | completions = completions_by_prompt[source][prompt] 41 | 42 | prompts_by_source[source].add(prompt) 43 | 44 | for run_idx, item in enumerate(completions): 45 | completion = item["completion"] 46 | reference_solution = item.get("reference_solution", item.get("solution")) 47 | 48 | if source in ["gsm8k"]: 49 | correct = math_equal( 50 | extract_answer(completion), 51 | reference_solution.split("####")[-1].strip() 52 | ) 53 | elif source in ["math", "aime2024", "aime2025"]: 54 | correct = math_equal( 55 | extract_answer(completion), 56 | strip_string(reference_solution.split("####")[1].strip()), 57 | timeout=False, 58 | ) or is_equiv( 59 | remove_boxed(last_boxed_only_string(completion)), 60 | reference_solution.split("####")[-1].strip() if "####" in reference_solution else ( 61 | remove_boxed(last_boxed_only_string(reference_solution)), 62 | ) 63 | ) 64 | else: 65 | raise NotImplementedError(f"Source '{source}' is not implemented") 66 | 67 | results[source][run_idx][prompt] = int(correct) 68 | 69 | print("\nRESULTS BY SOURCE:") 70 | print("-" * 80) 71 | print(f"{'Source':<15} {'Accuracy':<20} {'Num Prompts':<15} {'Runs':<10}") 72 | print("-" * 80) 73 | 74 | for source in sorted(results.keys()): 75 | # We expect 8 runs (0-7) 76 | expected_runs = 8 77 | prompts = sorted(prompts_by_source[source]) 78 | 79 | run_accuracies = [] 80 | run_details = [] 81 | 82 | for run_idx in range(expected_runs): 83 | if run_idx not in results[source]: 84 | print(f"Warning: Source '{source}' is missing run index {run_idx}") 85 | continue 86 | 87 | correct_count = 0 88 | total_count = 0 89 | 90 | for prompt in prompts: 91 | if prompt in results[source][run_idx]: 92 | correct_count += results[source][run_idx][prompt] 93 | total_count += 1 94 | 95 | if total_count > 0: 96 | run_accuracy = correct_count / total_count 97 | run_accuracies.append(run_accuracy) 98 | run_details.append(f"Run {run_idx}: {run_accuracy:.4f} ({correct_count}/{total_count})") 99 | 100 | if run_accuracies: 101 | mean_accuracy = np.mean(run_accuracies) 102 | std_dev = np.std(run_accuracies, ddof=1) 103 | 104 | mean_pct = round(mean_accuracy * 100, 1) 105 | std_dev_pct = round(std_dev * 100, 1) 106 | 107 | accuracy_str = f"{mean_pct:.1f}% ± {std_dev_pct:.1f}%" 108 | 109 | print(f"{source:<15} {accuracy_str:<20} {len(prompts):<15} {len(run_accuracies)}") 110 | 111 | for detail in run_details: 112 | print(f" {detail}") 113 | 114 | print("-" * 80) 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /PromptCoT/concept_encoding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from sentence_transformers import SentenceTransformer, models 4 | import os 5 | 6 | 7 | def 
create_last_token_pooling_model(model_path): 8 | # Load the transformer model 9 | word_embedding_model = models.Transformer(model_path) 10 | 11 | # Add a Pooling layer with `last_token` pooling 12 | pooling_model = models.Pooling( 13 | word_embedding_model.get_word_embedding_dimension(), 14 | pooling_mode_mean_tokens=False, 15 | pooling_mode_cls_token=False, 16 | pooling_mode_max_tokens=False, 17 | pooling_mode_lasttoken=True # Enable last token pooling 18 | ) 19 | 20 | # Combine the transformer and pooling model into a SentenceTransformer 21 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) 22 | model.tokenizer.pad_token = model.tokenizer.eos_token 23 | return model 24 | 25 | 26 | def encode_sentences(sentences, model_path, n_gpus, batch_size): 27 | model = create_last_token_pooling_model(model_path) 28 | 29 | eos_token = model.tokenizer.eos_token 30 | if eos_token is None: 31 | raise ValueError( 32 | "The tokenizer does not have an EOS token. Please define one or use a different model." 33 | ) 34 | 35 | sentences = [f"{sentence}{eos_token}" for sentence in sentences] 36 | pool = model.start_multi_process_pool(target_devices=[f"cuda:{i}" for i in range(n_gpus)]) 37 | embeddings = model.encode_multi_process( 38 | sentences, 39 | pool, 40 | batch_size=batch_size, 41 | show_progress_bar=True, 42 | ) 43 | model.stop_multi_process_pool(pool) 44 | 45 | return [embed.tolist() for embed in embeddings] 46 | 47 | 48 | def main(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file.") 51 | parser.add_argument("--output_path", type=str, required=True, help="Directory to store cached outputs.") 52 | parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model.") 53 | parser.add_argument("--n_gpus", type=int, default=8, help="Number of GPUs to use for tensor parallelism.") 54 | parser.add_argument("--batch_size", type=int, default=512) 55 | args = parser.parse_args() 56 | 57 | sentences = set() 58 | with open(args.data_path, encoding="utf-8") as f: 59 | for line in f.readlines(): 60 | lst = json.loads(line)["concepts"] 61 | sentences.update(lst) 62 | sentences = list(sentences) 63 | 64 | embeddings = encode_sentences( 65 | sentences, 66 | args.model_path, 67 | args.n_gpus, 68 | args.batch_size 69 | ) 70 | 71 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 72 | with open(args.output_path, "w", encoding="utf-8") as f: 73 | for sentence, embedding in zip(sentences, embeddings): 74 | f.write(json.dumps({ 75 | "sentence": sentence, 76 | "embedding": embedding, 77 | }) + "\n") 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /PromptCoT/concept_sampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from data_synthesis import FastConceptSampler as ConceptSampler 3 | import json 4 | from tqdm import tqdm 5 | import random 6 | import os 7 | 8 | 9 | def load_embeddings(embed_path): 10 | embeddings = {} 11 | with open(embed_path, encoding="utf-8") as f: 12 | for line in tqdm(f.readlines(), desc="Loading embeddings"): 13 | item = json.loads(line) 14 | embeddings[item['sentence']] = item['embedding'] 15 | return embeddings 16 | 17 | 18 | def load_concept_lists(input_file): 19 | concept_lists = [] 20 | with open(input_file, encoding="utf-8") as f: 21 | for line in f.readlines(): 22 | item = json.loads(line) 
23 | lst = item["concepts"] 24 | concept_lists.append(lst) 25 | return concept_lists 26 | 27 | 28 | def generate_samples(sampler, data_size, difficulty_levels): 29 | results = [] 30 | concept_text_pool = set() 31 | 32 | with tqdm(total=data_size, desc="Generating samples") as pbar: 33 | while len(results) < data_size: 34 | concept_list = sampler.sample_concept_list(size=5, temperature=0.2) 35 | concept_text = "\n".join(f"{i + 1}. {concept}" for i, concept in enumerate(concept_list)) 36 | 37 | if concept_text in concept_text_pool: 38 | continue 39 | 40 | concept_text_pool.add(concept_text) 41 | level = random.choice(difficulty_levels) 42 | 43 | prompt = ( 44 | f"Given foundational concepts and difficulty level, identify connections and " 45 | f"develop a question that integrates these concepts with appropriate complexity.\n\n" 46 | f"Foundational Concepts:\n{concept_text}\n\nDifficulty Level: {level}" 47 | ) 48 | 49 | results.append({ 50 | "foundational_concepts": concept_list, 51 | "level": level, 52 | "prompt": prompt, 53 | }) 54 | pbar.update(1) 55 | 56 | return results 57 | 58 | 59 | def save_results(results, output_file_path): 60 | os.makedirs(os.path.dirname(output_file_path), exist_ok=True) 61 | with open(output_file_path, "w", encoding="utf-8") as f: 62 | for item in results: 63 | f.write(json.dumps(item) + "\n") 64 | 65 | 66 | def main(): 67 | parser = argparse.ArgumentParser(description='Generate problem concepts dataset') 68 | parser.add_argument( 69 | '--data_path', 70 | type=str, 71 | required=True, 72 | help='Path to input JSONL file' 73 | ) 74 | parser.add_argument( 75 | '--output_path', 76 | type=str, 77 | required=True, 78 | help='Path to output JSONL file' 79 | ) 80 | parser.add_argument( 81 | '--data_size', 82 | type=int, 83 | default=2000, 84 | help='Number of samples to generate' 85 | ) 86 | parser.add_argument( 87 | '--embed_path', 88 | type=str, 89 | default="data/embeddings.jsonl", 90 | help='Path to embeddings file' 91 | ) 92 | args = parser.parse_args() 93 | 94 | # Difficulty levels configuration 95 | difficulty_levels = [ 96 | "AMC12", "HMMT-Nov", "HMMT-Feb", "AIME", 97 | "USAJMO", "USAMO", "USOJMO", "USOMO" 98 | ] 99 | # all_levels = {'AMC10', 'HMMT-Nov', 'AIME', 'AHSME', 100 | # 'HMMT-Feb', 'USOMO', 'AMC12', 'USOJMO', 101 | # 'AJHSME', 'AMC8', 'USAMO', 'USAJMO'} 102 | 103 | # Load data 104 | embeddings = load_embeddings(args.embed_path) 105 | concept_lists = load_concept_lists(args.data_path) 106 | 107 | # Initialize sampler and generate samples 108 | sampler = ConceptSampler(concept_lists=concept_lists, concept_embeddings=embeddings) 109 | results = generate_samples(sampler, args.data_size, difficulty_levels) 110 | 111 | # Save results 112 | save_results(results, args.output_path) 113 | 114 | 115 | if __name__ == "__main__": 116 | main() -------------------------------------------------------------------------------- /PromptCoT/configs/promptcot_ds_1_5b_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": 1e-08, 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": 0, 26 | "warmup_max_lr": 5e-06, 27 | 
"warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 2, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "overlap_comm": false, 37 | "contiguous_gradients": true, 38 | "sub_group_size": 1000000000.0, 39 | "reduce_bucket_size": 1000000.0 40 | }, 41 | "steps_per_print": 1, 42 | "train_micro_batch_size_per_gpu": 1 43 | } 44 | -------------------------------------------------------------------------------- /PromptCoT/configs/promptcot_ds_7b_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": 1e-08, 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": 0, 26 | "warmup_max_lr": 5e-06, 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 2, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "overlap_comm": false, 37 | "contiguous_gradients": true, 38 | "sub_group_size": 1000000000.0, 39 | "reduce_bucket_size": 1000000.0 40 | }, 41 | "steps_per_print": 1, 42 | "train_micro_batch_size_per_gpu": 1 43 | } 44 | -------------------------------------------------------------------------------- /PromptCoT/configs/promptcot_qwq_32b_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": 1e-08, 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": 0, 26 | "warmup_max_lr": 2e-06, 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_param": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_optimizer": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": false, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1000000000.0, 43 | "reduce_bucket_size": 1000000.0, 44 | "stage3_prefetch_bucket_size": 940000.0, 45 | "stage3_param_persistence_threshold": 10000.0, 46 | "stage3_max_live_parameters": 1000000000.0, 47 | "stage3_max_reuse_distance": 1000000000.0, 48 | "stage3_gather_16bit_weights_on_model_save": true 49 | }, 50 | "steps_per_print": 1, 51 | "train_micro_batch_size_per_gpu": 1 52 | } 53 | -------------------------------------------------------------------------------- /PromptCoT/data/aime2025.jsonl: -------------------------------------------------------------------------------- 1 | {"idx": 0, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{70}\n#### 70", "source": "aime2025"} 2 | {"idx": 1, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please 
reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>On $\\triangle ABC$ points $A,D,E$, and $B$ lie that order on side $\\overline{AB}$ with $AD=4, DE=16$, and $EB=8$. Points $A,F,G$, and $C$ lie in that order on side $\\overline{AC}$ with $AF=13, FG=52$, and $GC=26$. Let $M$ be the reflection of $D$ through $F$, and let $N$ be the reflection of $G$ through $E$. Quadrilateral $DEGF$ has area 288. Find the area of heptagon $AFNBCEM$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{588}\n#### 588", "source": "aime2025"} 3 | {"idx": 2, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>The 9 members of a baseball team went to an ice cream parlor after their game. Each player had a singlescoop cone of chocolate, vanilla, or strawberry ice cream. At least one player chose each flavor, and the number of players who chose chocolate was greater than the number of players who chose vanilla, which was greater than the number of players who chose strawberry. Let $N$ be the number of different assignments of flavors to players that meet these conditions. Find the remainder when $N$ is divided by 1000.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{16}\n#### 16", "source": "aime2025"} 4 | {"idx": 3, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>Find the number of ordered pairs $(x,y)$, where both $x$ and $y$ are integers between $-100$ and $100$, inclusive, such that $12x^{2}-xy-6y^{2}=0$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{117}\n#### 117", "source": "aime2025"} 5 | {"idx": 4, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>There are $8!=40320$ eight-digit positive integers that use each of the digits $1,2,3,4,5,6,7,8$ exactly once. Let $N$ be the number of these integers that are divisible by 22. Find the difference between $N$ and 2025.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{279}\n#### 279", "source": "aime2025"} 6 | {"idx": 5, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>An isosceles trapezoid has an inscribed circle tangent to each of its four sides. The radius of the circle is 3, and the area of the trapezoid is 72. Let the parallel sides of the trapezoid have lengths $r$ and $s$, with $r \\neq s$. Find $r^{2}+s^{2}$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{504}\n#### 504", "source": "aime2025"} 7 | {"idx": 6, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>The twelve letters $A,B,C,D,E,F,G,H,I,J,K$, and $L$ are randomly grouped into six pairs of letters. The two letters in each pair are placed next to each other in alphabetical order to form six two-letter words, and those six words are listed alphabetically. For example, a possible result is $AB,CJ,DG,EK,FL,HI$. The probability that the last word listed contains $G$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. 
Find $m+n$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{821}\n#### 821", "source": "aime2025"} 8 | {"idx": 7, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>Let $k$ be real numbers such that the system $|25+20i-z|=5$ and $|z-4-k|=|z-3i-k|$ has exactly one complex solution $z$. The sum of all possible values of $k$ can be written as $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$. Here $i=\\sqrt{-1}$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{77}\n#### 77", "source": "aime2025"} 9 | {"idx": 8, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>The parabola with equation $y=x^{2}-4$ is rotated $60^{\\circ}$ counterclockwise around the origin. The unique point in the fourth quadrant where the original parabola and its image intersect has $y$-coordinate $\\frac{a-\\sqrt{b}}{c}$, where $a$, $b$, and $c$ are positive integers, and $a$ and $c$ are relatively prime. Find $a+b+c$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{62}\n#### 62", "source": "aime2025"} 10 | {"idx": 9, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>The 27 cells of a $3\\times9$ grid are filled in using the numbers 1 through 9 so that each row contains 9 different numbers, and each of the three $3\\times3$ blocks heavily outlined in the example below contains 9 different numbers, as in the first three rows of a Sudoku puzzle. \n | 4 | 2 | 8 | 9 | 6 | 3 | 1 | 7 | 5 | \n | 3 | 7 | 9 | 5 | 2 | 1 | 6 | 8 | 4 | \n | 5 | 6 | 1 | 8 | 4 | 7 | 9 | 2 | 3 | \n The number of different ways to fill such a grid can be written as $p^a\\cdot q^b\\cdot r^c\\cdot s^d$, where $p,q,r,$ and $s$ are distinct prime numbers and $a,b,c,$ and $d$ are positive integers. Find $p\\cdot a+q\\cdot b+r\\cdot c+s\\cdot d$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{81}\n#### 81", "source": "aime2025"} 11 | {"idx": 10, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>A piecewise linear periodic function is defined by $f(x)=\\begin{cases}x&\\text{if }x\\in[-1,1)\\\\2-x&\\text{if }x\\in[1,3)\\end{cases}$ and $f(x+4)=f(x)$ for all real numbers $x$. The graph of $f(x)$ has the sawtooth pattern. The parabola $x=34y^2$ intersects the graph of $f(x)$ at finitely many points. The sum of the $y$-coordinates of these intersection points can be expressed in the form $\\frac{a+b\\sqrt{c}}{d}$, where $a,b,c,$ and $d$ are positive integers, $a,b,$ and $d$ have greatest common divisor equal to 1, and $c$ is not divisible by the square of any prime. 
Find $a+b+c+d$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{259}\n#### 259", "source": "aime2025"}
12 | {"idx": 11, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>The set of points in 3-dimensional coordinate space that lie in the plane $x+y+z=75$ whose coordinates satisfy the inequalities $x-yz<y-zx<z-xy$ forms three disjoint convex regions. Exactly one of those regions has finite area. The area of this finite region can be expressed in the form $a\\sqrt{b}$, where $a$ and $b$ are positive integers and $b$ is not divisible by the square of any prime. Find $a+b$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{510}\n#### 510", "source": "aime2025"}
13 | {"idx": 12, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>Alex divides a disk into four quadrants with two perpendicular diameters intersecting at the center of the disk. He draws 25 more line segments through the disk, drawing each segment by selecting two points at random on the perimeter of the disk in different quadrants and connecting those two points. Find the expected number of regions into which these 27 line segments divide the disk.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{204}\n#### 204", "source": "aime2025"}
14 | {"idx": 13, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>Let $ABCDE$ be a convex pentagon with $AB=14, BC=7, CD=24, DE=13, EA=26,$ and $\\angle B=\\angle E=60^\\circ$. For each point $X$ in the plane, define $f(X)=AX+BX+CX+DX+EX$. The least possible value of $f(X)$ can be expressed as $m+n\\sqrt{p}$, where $m$ and $n$ are positive integers and $p$ is not divisible by the square of any prime. Find $m+n+p$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{60}\n#### 60", "source": "aime2025"}
15 | {"idx": 14, "prompt": "<\uff5cbegin\u2581of\u2581sentence\uff5c>Please reason step by step, and put your final answer within \\boxed{{}}.<\uff5cUser\uff5c>Let $N$ denote the number of ordered triples of positive integers $(a,b,c)$ such that $a,b,c\\leq3^6$ and $a^3+b^3+c^3$ is a multiple of $3^7$. Find the remainder when $N$ is divided by $1000$.<\uff5cAssistant\uff5c>", "reference_solution": "\\boxed{735}\n#### 735", "source": "aime2025"}
16 |
--------------------------------------------------------------------------------
/PromptCoT/data/qwq/qwq_aime2025_test.jsonl:
--------------------------------------------------------------------------------
1 | {"idx": 0, "prompt": "<|im_start|>user\nFind the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{70}\n#### 70", "source": "aime2025"}
2 | {"idx": 1, "prompt": "<|im_start|>user\nOn $\\triangle ABC$ points $A,D,E$, and $B$ lie that order on side $\\overline{AB}$ with $AD=4, DE=16$, and $EB=8$. Points $A,F,G$, and $C$ lie in that order on side $\\overline{AC}$ with $AF=13, FG=52$, and $GC=26$. Let $M$ be the reflection of $D$ through $F$, and let $N$ be the reflection of $G$ through $E$. Quadrilateral $DEGF$ has area 288. Find the area of heptagon $AFNBCEM$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{588}\n#### 588", "source": "aime2025"}
3 | {"idx": 2, "prompt": "<|im_start|>user\nThe 9 members of a baseball team went to an ice cream parlor after their game. Each player had a singlescoop cone of chocolate, vanilla, or strawberry ice cream.
At least one player chose each flavor, and the number of players who chose chocolate was greater than the number of players who chose vanilla, which was greater than the number of players who chose strawberry. Let $N$ be the number of different assignments of flavors to players that meet these conditions. Find the remainder when $N$ is divided by 1000.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{16}\n#### 16", "source": "aime2025"} 4 | {"idx": 3, "prompt": "<|im_start|>user\nFind the number of ordered pairs $(x,y)$, where both $x$ and $y$ are integers between $-100$ and $100$, inclusive, such that $12x^{2}-xy-6y^{2}=0$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{117}\n#### 117", "source": "aime2025"} 5 | {"idx": 4, "prompt": "<|im_start|>user\nThere are $8!=40320$ eight-digit positive integers that use each of the digits $1,2,3,4,5,6,7,8$ exactly once. Let $N$ be the number of these integers that are divisible by 22. Find the difference between $N$ and 2025.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{279}\n#### 279", "source": "aime2025"} 6 | {"idx": 5, "prompt": "<|im_start|>user\nAn isosceles trapezoid has an inscribed circle tangent to each of its four sides. The radius of the circle is 3, and the area of the trapezoid is 72. Let the parallel sides of the trapezoid have lengths $r$ and $s$, with $r \\neq s$. Find $r^{2}+s^{2}$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{504}\n#### 504", "source": "aime2025"} 7 | {"idx": 6, "prompt": "<|im_start|>user\nThe twelve letters $A,B,C,D,E,F,G,H,I,J,K$, and $L$ are randomly grouped into six pairs of letters. The two letters in each pair are placed next to each other in alphabetical order to form six two-letter words, and those six words are listed alphabetically. For example, a possible result is $AB,CJ,DG,EK,FL,HI$. The probability that the last word listed contains $G$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{821}\n#### 821", "source": "aime2025"} 8 | {"idx": 7, "prompt": "<|im_start|>user\nLet $k$ be real numbers such that the system $|25+20i-z|=5$ and $|z-4-k|=|z-3i-k|$ has exactly one complex solution $z$. The sum of all possible values of $k$ can be written as $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$. Here $i=\\sqrt{-1}$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{77}\n#### 77", "source": "aime2025"} 9 | {"idx": 8, "prompt": "<|im_start|>user\nThe parabola with equation $y=x^{2}-4$ is rotated $60^{\\circ}$ counterclockwise around the origin. The unique point in the fourth quadrant where the original parabola and its image intersect has $y$-coordinate $\\frac{a-\\sqrt{b}}{c}$, where $a$, $b$, and $c$ are positive integers, and $a$ and $c$ are relatively prime. 
Find $a+b+c$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{62}\n#### 62", "source": "aime2025"}
10 | {"idx": 9, "prompt": "<|im_start|>user\nThe 27 cells of a $3\\times9$ grid are filled in using the numbers 1 through 9 so that each row contains 9 different numbers, and each of the three $3\\times3$ blocks heavily outlined in the example below contains 9 different numbers, as in the first three rows of a Sudoku puzzle. \n | 4 | 2 | 8 | 9 | 6 | 3 | 1 | 7 | 5 | \n | 3 | 7 | 9 | 5 | 2 | 1 | 6 | 8 | 4 | \n | 5 | 6 | 1 | 8 | 4 | 7 | 9 | 2 | 3 | \n The number of different ways to fill such a grid can be written as $p^a\\cdot q^b\\cdot r^c\\cdot s^d$, where $p,q,r,$ and $s$ are distinct prime numbers and $a,b,c,$ and $d$ are positive integers. Find $p\\cdot a+q\\cdot b+r\\cdot c+s\\cdot d$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{81}\n#### 81", "source": "aime2025"}
11 | {"idx": 10, "prompt": "<|im_start|>user\nA piecewise linear periodic function is defined by $f(x)=\\begin{cases}x&\\text{if }x\\in[-1,1)\\\\2-x&\\text{if }x\\in[1,3)\\end{cases}$ and $f(x+4)=f(x)$ for all real numbers $x$. The graph of $f(x)$ has the sawtooth pattern. The parabola $x=34y^2$ intersects the graph of $f(x)$ at finitely many points. The sum of the $y$-coordinates of these intersection points can be expressed in the form $\\frac{a+b\\sqrt{c}}{d}$, where $a,b,c,$ and $d$ are positive integers, $a,b,$ and $d$ have greatest common divisor equal to 1, and $c$ is not divisible by the square of any prime. Find $a+b+c+d$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{259}\n#### 259", "source": "aime2025"}
12 | {"idx": 11, "prompt": "<|im_start|>user\nThe set of points in 3-dimensional coordinate space that lie in the plane $x+y+z=75$ whose coordinates satisfy the inequalities $x-yz<y-zx<z-xy$ forms three disjoint convex regions. Exactly one of those regions has finite area. The area of this finite region can be expressed in the form $a\\sqrt{b}$, where $a$ and $b$ are positive integers and $b$ is not divisible by the square of any prime. Find $a+b$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{510}\n#### 510", "source": "aime2025"}
13 | {"idx": 12, "prompt": "<|im_start|>user\nAlex divides a disk into four quadrants with two perpendicular diameters intersecting at the center of the disk. He draws 25 more line segments through the disk, drawing each segment by selecting two points at random on the perimeter of the disk in different quadrants and connecting those two points. Find the expected number of regions into which these 27 line segments divide the disk.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{204}\n#### 204", "source": "aime2025"}
14 | {"idx": 13, "prompt": "<|im_start|>user\nLet $ABCDE$ be a convex pentagon with $AB=14, BC=7, CD=24, DE=13, EA=26,$ and $\\angle B=\\angle E=60^\\circ$. For each point $X$ in the plane, define $f(X)=AX+BX+CX+DX+EX$. The least possible value of $f(X)$ can be expressed as $m+n\\sqrt{p}$, where $m$ and $n$ are positive integers and $p$ is not divisible by the square of any prime.
Find $m+n+p$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{60}\n#### 60", "source": "aime2025"} 15 | {"idx": 14, "prompt": "<|im_start|>user\nLet $N$ denote the number of ordered triples of positive integers $(a,b,c)$ such that $a,b,c\\leq3^6$ and $a^3+b^3+c^3$ is a multiple of $3^7$. Find the remainder when $N$ is divided by $1000$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{735}\n#### 735", "source": "aime2025"} 16 | {"idx": 15, "prompt": "<|im_start|>user\nSix points $ A, B, C, D, E, $ and $ F $ lie in a straight line in that order. Suppose that $ G $ is a point not on the line and that $ AC = 26 $, $ BD = 22 $, $ CE = 31 $, $ DF = 33 $, $ AF = 73 $, $ CG = 40 $, and $ DG = 30 $. Find the area of $ \\triangle BGE $.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{468}\n#### 468", "source": "aime2025"} 17 | {"idx": 16, "prompt": "<|im_start|>user\nFind the sum of all positive integers $ n $ such that $ n + 2 $ divides the product $ 3(n + 3)(n^2 + 9) $.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{49}\n#### 49", "source": "aime2025"} 18 | {"idx": 17, "prompt": "<|im_start|>user\nFour unit squares form a $2 \\times 2$ grid. Each of the 12 unit line segments forming the sides of the squares is colored either red or blue in such a way that each unit square has 2 red sides and 2 blue sides. Find the number of such colorings.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{82}\n#### 82", "source": "aime2025"} 19 | {"idx": 18, "prompt": "<|im_start|>user\nThe product $ \\prod_{k=4}^{63} \\frac{\\log_k(5^{k^2-1})}{\\log_{k+1}(5^{k^2-4})} = \\frac{\\log_4(5^{15})}{\\log_5(5^{12})} \\cdot \\frac{\\log_5(5^{24})}{\\log_6(5^{21})} \\cdot \\frac{\\log_6(5^{35})}{\\log_7(5^{32})} \\cdots \\frac{\\log_{63}(5^{3968})}{\\log_{64}(5^{3965})} $ is equal to $ \\frac{m}{n} $, where $ m $ and $ n $ are relatively prime positive integers. Find $ m + n $.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{106}\n#### 106", "source": "aime2025"} 20 | {"idx": 19, "prompt": "<|im_start|>user\nSuppose $ \\triangle ABC $ has angles $ \\angle BAC = 84^\\circ $, $ \\angle ABC = 60^\\circ $, and $ \\angle ACB = 36^\\circ $. Let $ D, E, $ and $ F $ be the midpoints of sides $ \\overline{BC} $, $ \\overline{AC} $, and $ \\overline{AB} $, respectively. The circumcircle of $ \\triangle DEF $ intersects $ \\overline{BD} $, $ \\overline{AE} $, and $ \\overline{AF} $ at points $ G, H, $ and $ J $, respectively. The points $ G, D, E, H, J, $ and $ F $ divide the circumcircle of $ \\triangle DEF $ into six minor arcs, as shown. 
Find $ \\widehat{DE} + 2 \\cdot \\widehat{HJ} + 3 \\cdot \\widehat{FG} $, where the arcs are measured in degrees.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{336^\\circ}\n#### 336^\\circ", "source": "aime2025"} 21 | {"idx": 20, "prompt": "<|im_start|>user\nCircle $\\omega_1$ with radius 6 centered at point $A$ is internally tangent at point $B$ to circle $\\omega_2$ with radius 15. Points $C$ and $D$ lie on $\\omega_2$ such that $\\overline{BC}$ is a diameter of $\\omega_2$ and $\\overline{BC} \\perp \\overline{AD}$. The rectangle $EFGH$ is inscribed in $\\omega_1$ such that $\\overline{EF} \\perp \\overline{BC}$, $C$ is closer to $\\overline{GH}$ than to $\\overline{EF}$, and $D$ is closer to $\\overline{FG}$ than to $\\overline{EH}$, as shown. Triangles $\\triangle DGF$ and $\\triangle CHG$ have equal areas. The area of rectangle $EFGH$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m + n$.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{293}\n#### 293", "source": "aime2025"} 22 | {"idx": 21, "prompt": "<|im_start|>user\nLet $ A $ be the set of positive integer divisors of 2025. Let $ B $ be a randomly selected subset of $ A $. The probability that $ B $ is a nonempty set with the property that the least common multiple of its elements is 2025 is $ \\frac{m}{n} $, where $ m $ and $ n $ are relatively prime positive integers. Find $ m + n $.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{237}\n#### 237", "source": "aime2025"} 23 | {"idx": 22, "prompt": "<|im_start|>user\nFrom an unlimited supply of 1-cent coins, 10-cent coins, and 25-cent coins, Silas wants to find a collection of coins that has a total value of $ N $ cents, where $ N $ is a positive integer. He uses the so-called **greedy algorithm**, successively choosing the coin of greatest value that does not cause the value of his collection to exceed $ N $. For example, to get 42 cents, Silas will choose a 25-cent coin, then a 10-cent coin, then 7 1-cent coins. However, this collection of 9 coins uses more coins than necessary to get a total of 42 cents; indeed, choosing 4 10-cent coins and 2 1-cent coins achieves the same total value with only 6 coins.\n\nIn general, the greedy algorithm succeeds for a given $ N $ if no other collection of 1-cent, 10-cent, and 25-cent coins gives a total value of $ N $ cents using strictly fewer coins than the collection given by the greedy algorithm. Find the number of values of $ N $ between 1 and 1000 inclusive for which the greedy algorithm succeeds.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{610}\n#### 610", "source": "aime2025"} 24 | {"idx": 23, "prompt": "<|im_start|>user\nThere are $ n $ values of $ x $ in the interval $ 0 < x < 2\\pi $ where $ f(x) = \\sin(7\\pi \\cdot \\sin(5x)) = 0 $. For $ t $ of these $ n $ values of $ x $, the graph of $ y = f(x) $ is tangent to the $ x $-axis. Find $ n + t $.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{149}\n#### 149", "source": "aime2025"} 25 | {"idx": 24, "prompt": "<|im_start|>user\nSixteen chairs are arranged in a row. 
Eight people each select a chair in which to sit so that no person sits next to two other people. Let $ N $ be the number of subsets of 16 chairs that could be selected. Find the remainder when $ N $ is divided by 1000.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{907}\n#### 907", "source": "aime2025"} 26 | {"idx": 25, "prompt": "<|im_start|>user\nLet $ S $ be the set of vertices of a regular 24-gon. Find the number of ways to draw 12 segments of equal lengths so that each vertex in $ S $ is an endpoint of exactly one of the 12 segments.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{113}\n#### 113", "source": "aime2025"} 27 | {"idx": 26, "prompt": "<|im_start|>user\nLet $ A_1A_2 \\ldots A_{11} $ be an 11-sided non-convex simple polygon with the following properties:\n* The area of $ A_iA_1A_{i+1} $ is 1 for each $ 2 \\leq i \\leq 10 $,\n* $ \\cos(\\angle A_iA_1A_{i+1}) = \\frac{12}{13} $ for each $ 2 \\leq i \\leq 10 $,\n* The perimeter of $ A_1A_2 \\ldots A_{11} $ is 20.\nIf $ A_1A_2 + A_1A_{11} $ can be expressed as $ \\frac{m\\sqrt{n} - p}{q} $ for positive integers $ m, n, p, q $ with $ n $ squarefree and no prime divides all of $ m, p, q$, find $ m + n + p + q $.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{19}\n#### 19", "source": "aime2025"} 28 | {"idx": 27, "prompt": "<|im_start|>user\nLet the sequence of rationals $ x_1, x_2, \\ldots $ be defined such that $ x_1 = \\frac{25}{11} $ and\n$ x_{k+1} = \\frac{1}{3} \\left( x_k + \\frac{1}{x_k} - 1 \\right). $\n$ x_{2025} $ can be expressed as $ \\frac{m}{n} $ for relatively prime positive integers $ m $ and $ n $. Find the remainder when $ m + n $ is divided by 1000.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{248}\n#### 248", "source": "aime2025"} 29 | {"idx": 28, "prompt": "<|im_start|>user\nLet $ \\triangle ABC $ be a right triangle with $ \\angle A = 90^\\circ $ and $ BC = 38 $. There exist points $ K $ and $ L $ inside the triangle such that $ AK = AL = BK = CL = KL = 14. $ The area of the quadrilateral $ BKLC $ can be expressed as $ n \\sqrt{3} $ for some positive integer $ n $. Find $ n $.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{104}\n#### 104", "source": "aime2025"} 30 | {"idx": 29, "prompt": "<|im_start|>user\nThere are exactly three positive real numbers $ k $ such that the function\n$ f(x) = \\frac{(x - 18)(x - 72)(x - 98)(x - k)}{x} $\ndefined over the positive real numbers achieves its minimum value at exactly two positive real numbers $ x $. 
Find the sum of these three values of $ k $.\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n\n", "reference_solution": "\\boxed{240}\n#### 240", "source": "aime2025"} 31 | -------------------------------------------------------------------------------- /PromptCoT/data_synthesis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics.pairwise import cosine_similarity 3 | 4 | from scipy.sparse import csr_matrix 5 | import numpy as np 6 | from collections import Counter 7 | import faiss 8 | import torch 9 | 10 | 11 | class ConceptSampler: 12 | def __init__(self, concept_lists, concept_embeddings): 13 | """ 14 | concept_lists: List[List[str]] - lists of concept names 15 | concept_embeddings: dict - mapping from concept name to vector 16 | """ 17 | self.concept_lists = concept_lists 18 | self.embeddings = concept_embeddings 19 | # Create co-occurrence statistics 20 | self.build_cooccurrence_stats() 21 | 22 | def build_cooccurrence_stats(self): 23 | """Build co-occurrence matrix and concept frequencies""" 24 | # Get unique concepts 25 | self.all_concepts = list(set(c for lst in self.concept_lists for c in lst)) 26 | self.concept_to_idx = {c: i for i, c in enumerate(self.all_concepts)} 27 | 28 | # Count individual frequencies 29 | self.concept_freq = {c: 0 for c in self.all_concepts} 30 | for lst in self.concept_lists: 31 | for c in lst: 32 | self.concept_freq[c] += 1 33 | 34 | # Build co-occurrence matrix 35 | n = len(self.all_concepts) 36 | self.cooccur_matrix = np.zeros((n, n)) 37 | for lst in self.concept_lists: 38 | for c1 in lst: 39 | for c2 in lst: 40 | if c1 != c2: 41 | i, j = self.concept_to_idx[c1], self.concept_to_idx[c2] 42 | self.cooccur_matrix[i, j] += 1 43 | 44 | def sample_concept_list(self, size=3, temperature=0.1, seed_concept=None): 45 | """ 46 | Sample a list of concepts 47 | size: desired list size 48 | temperature: controls randomness (lower = more deterministic) 49 | seed_concept: optional starting concept 50 | """ 51 | result = [] 52 | 53 | # Start with seed concept or random concept 54 | if seed_concept is None: 55 | # Sample first concept based on frequency 56 | freqs = np.array([self.concept_freq[c] for c in self.all_concepts]) 57 | probs = freqs / freqs.sum() 58 | seed_concept = np.random.choice(self.all_concepts, p=probs) 59 | 60 | result.append(seed_concept) 61 | 62 | # Sample remaining concepts 63 | while len(result) < size: 64 | scores = self.get_next_concept_scores(result) 65 | 66 | # Apply temperature 67 | scores = np.exp(scores / temperature) 68 | probs = scores / scores.sum() 69 | 70 | # Sample next concept 71 | next_concept = np.random.choice(self.all_concepts, p=probs) 72 | result.append(next_concept) 73 | 74 | return result 75 | 76 | def get_next_concept_scores(self, current_list): 77 | """Calculate scores for potential next concepts""" 78 | scores = np.zeros(len(self.all_concepts)) 79 | 80 | for concept in self.all_concepts: 81 | if concept in current_list: 82 | continue 83 | 84 | # Co-occurrence score 85 | cooccur_score = 0 86 | for c in current_list: 87 | i, j = self.concept_to_idx[concept], self.concept_to_idx[c] 88 | cooccur_score += self.cooccur_matrix[i, j] 89 | 90 | # Semantic similarity score 91 | sim_score = 0 92 | concept_vec = self.embeddings[concept] 93 | for c in current_list: 94 | sim = cosine_similarity( 95 | [concept_vec], 96 | [self.embeddings[c]] 97 | )[0][0] 98 | sim_score += sim 99 | 100 | # Combine 
scores 101 | scores[self.concept_to_idx[concept]] = cooccur_score + sim_score 102 | 103 | return scores 104 | 105 | 106 | class FastConceptSampler: 107 | def __init__(self, concept_lists, concept_embeddings): 108 | self.concept_lists = concept_lists 109 | self.embeddings = concept_embeddings 110 | self.build_cooccurrence_stats() 111 | self.setup_faiss_index() 112 | 113 | def build_cooccurrence_stats(self): 114 | # Get unique concepts and create mapping 115 | self.all_concepts = list(set(c for lst in self.concept_lists for c in lst)) 116 | self.concept_to_idx = {c: i for i, c in enumerate(self.all_concepts)} 117 | self.idx_to_concept = {i: c for c, i in self.concept_to_idx.items()} 118 | 119 | # Count frequencies using Counter 120 | self.concept_freq = Counter(c for lst in self.concept_lists for c in lst) 121 | 122 | # Build sparse co-occurrence matrix 123 | rows, cols, data = [], [], [] 124 | for lst in self.concept_lists: 125 | lst_idx = [self.concept_to_idx[c] for c in lst] 126 | for i in lst_idx: 127 | for j in lst_idx: 128 | if i != j: 129 | rows.append(i) 130 | cols.append(j) 131 | data.append(1) 132 | 133 | n = len(self.all_concepts) 134 | self.cooccur_matrix = csr_matrix((data, (rows, cols)), shape=(n, n)) 135 | 136 | def setup_faiss_index(self): 137 | # Convert embeddings to numpy array 138 | n = len(self.all_concepts) 139 | d = len(next(iter(self.embeddings.values()))) 140 | embedding_matrix = np.zeros((n, d), dtype=np.float32) 141 | 142 | for concept, idx in self.concept_to_idx.items(): 143 | embedding_matrix[idx] = self.embeddings[concept] 144 | 145 | # Normalize embeddings for cosine similarity 146 | norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True) 147 | embedding_matrix /= norms 148 | 149 | # Create FAISS index 150 | self.index = faiss.IndexFlatIP(d) # Inner product = cosine similarity for normalized vectors 151 | self.index.add(embedding_matrix) 152 | 153 | def sample_concept_list(self, size=3, temperature=0.1, seed_concept=None): 154 | result = [] 155 | 156 | if seed_concept is None: 157 | # Sample based on frequency using numpy 158 | concepts = np.array(self.all_concepts) 159 | freqs = np.array([self.concept_freq[c] for c in concepts]) 160 | probs = freqs / freqs.sum() 161 | seed_concept = np.random.choice(concepts, p=probs) 162 | 163 | result.append(seed_concept) 164 | 165 | # Pre-compute indices for current list 166 | used_indices = {self.concept_to_idx[c] for c in result} 167 | 168 | while len(result) < size: 169 | scores = self.get_next_concept_scores_fast(result, used_indices) 170 | 171 | # Temperature scaling and sampling 172 | scores = np.exp(scores / temperature) 173 | probs = scores / scores.sum() 174 | 175 | next_idx = np.random.choice(len(self.all_concepts), p=probs) 176 | next_concept = self.idx_to_concept[next_idx] 177 | 178 | result.append(next_concept) 179 | used_indices.add(next_idx) 180 | 181 | return result 182 | 183 | def get_next_concept_scores_fast(self, current_list, used_indices): 184 | n = len(self.all_concepts) 185 | scores = np.zeros(n, dtype=np.float32) 186 | 187 | # Compute co-occurrence scores using sparse matrix 188 | current_indices = [self.concept_to_idx[c] for c in current_list] 189 | cooccur_scores = self.cooccur_matrix[current_indices].sum(axis=0).A1 190 | 191 | # Compute similarity scores using FAISS 192 | query = np.mean([self.embeddings[c] for c in current_list], axis=0) 193 | query = query.reshape(1, -1).astype(np.float32) 194 | query /= np.linalg.norm(query) 195 | sim_scores, _ = self.index.search(query, n) 196 | 
sim_scores = sim_scores[0] 197 | 198 | # Combine scores 199 | scores = cooccur_scores + sim_scores 200 | 201 | # Mask out used concepts 202 | scores[list(used_indices)] = -np.inf 203 | 204 | return scores 205 | 206 | 207 | def thinking_process_generation_prompt(problem, concepts, difficulty_level): 208 | concept_text = "\n".join(f"{i+1}. {concept}" for i, concept in enumerate(concepts)) 209 | prompt = ( 210 | "Imagine you are an expert in educational problem design.\n" 211 | f"You will be shown these components:\n\n" 212 | f"Problem: {problem}\n\n" 213 | f"Fundamental Concepts:\n{concept_text}\n\n" 214 | f"Difficulty Level: {difficulty_level}\n\n" 215 | "Your task is to reverse-engineer a clear thinking process that shows how a teacher might design this problem. This thinking process should:\n" 216 | "- Show how combining the given foundational concepts naturally leads to a problem at the specified difficulty level\n" 217 | "- Include all key decisions and reasoning that shaped the problem design\n" 218 | "- (IMPORTANT) The thinking process must be so precise and detailed that another teacher following these exact steps would recreate the identical problem\n" 219 | "- (IMPORTANT) The thinking process must be so natural and logical that another teacher could derive the same thinking process using only the foundational concepts and difficulty level\n\n" 220 | "Present your answer after 'Thinking Process: ' with the complete step-by-step thinking process described above." 221 | ) 222 | 223 | return prompt 224 | 225 | 226 | def topic_extraction_prompt(problem, num_concepts): 227 | prompt = ( 228 | "As an expert in educational assessment, analyze this problem:\n\n" 229 | f"{problem}\n\n" 230 | f"Break down and identify {num_concepts} foundational concepts being tested. List these knowledge points that:\n" 231 | "- Are core curriculum concepts typically taught in standard courses\n" 232 | "- Are precise and measurable (not vague like 'understanding math')\n" 233 | "- Are essential building blocks needed to solve this problem\n" 234 | "- Represent fundamental principles rather than problem-specific techniques\n\n" 235 | f"Think through your analysis step by step, then format your response as a Python code snippet containing a list of {num_concepts} strings, where each string clearly describes one fundamental knowledge point." 236 | ) 237 | return prompt 238 | 239 | 240 | def get_catch_all_prompt(problem, num_concepts): 241 | prompt = ( 242 | "As an expert in educational assessment, analyze this problem:\n\n" 243 | f"{problem}\n\n" 244 | f"Break down and identify {num_concepts} foundational concepts being tested. List these knowledge points that:\n" 245 | "- Are core curriculum concepts typically taught in standard courses\n" 246 | "- Are precise and measurable (not vague like 'understanding math')\n" 247 | "- Are essential building blocks needed to solve this problem\n" 248 | "- Represent fundamental principles rather than problem-specific techniques\n\n" 249 | f"Return only {num_concepts} lines, each starting with 'Knowledge Point: ' followed by one fundamental concept, without any other text." 250 | ) 251 | return prompt 252 | 253 | 254 | def rationale_judgement_prompt(concepts, level, rationale_and_problem): 255 | concept_text = "\n".join(f"- {concept}" for concept in concepts) 256 | prompt = ( 257 | "As a critical expert in educational problem design, evaluate the following problem components:\n\n" 258 | f"=== GIVEN MATERIALS ===\n" 259 | f"1. 
Problem & Design Rationale:\n{rationale_and_problem}\n" 260 | " (The rationale describes the author's thinking process and justification in designing this problem)\n\n" 261 | f"2. Foundational Concepts:\n{concept_text}\n\n" 262 | f"3. Target Difficulty Level: {level}\n\n" 263 | 264 | "=== EVALUATION CRITERIA ===\n" 265 | "Rate each criterion as: [Perfect | Acceptable | Bad]\n\n" 266 | 267 | "1. FORMAT\n" 268 | "- Verify correct implementation of markup tags:\n" 269 | "  <!-- BEGIN RATIONALE --> [design thinking process] <!-- END RATIONALE -->\n" 270 | "  <!-- BEGIN PROBLEM --> [problem] <!-- END PROBLEM -->\n\n" 271 | 272 | "2. FACTUAL ACCURACY\n" 273 | "- Check for any incorrect or misleading information in both problem and rationale\n" 274 | "- Verify mathematical, scientific, or logical consistency\n\n" 275 | 276 | "3. DIFFICULTY ALIGNMENT\n" 277 | "- Assess if problem complexity matches the specified difficulty level\n" 278 | "- Evaluate if cognitive demands align with target level\n\n" 279 | 280 | "4. CONCEPT COVERAGE\n" 281 | "- Evaluate how well the problem incorporates the given foundational concepts\n" 282 | "- Check for missing concept applications\n\n" 283 | 284 | "5. SOLVABILITY\n" 285 | "- Verify if the problem has at least one valid solution\n" 286 | "- Check if all necessary information for solving is provided\n\n" 287 | 288 | "=== RESPONSE FORMAT ===\n" 289 | "For each criterion, provide:\n" 290 | "1. Rating: [Perfect | Acceptable | Bad]\n" 291 | "2. Justification: Clear explanation for the rating\n\n" 292 | 293 | "=== FINAL VERDICT ===\n" 294 | "After providing all criterion evaluations, conclude your response with:\n" 295 | "'Final Judgement: [verdict]'\n" 296 | "where verdict must be one of:\n" 297 | "- 'perfect' (if both FACTUAL ACCURACY and SOLVABILITY are Perfect, at least two other criteria are Perfect, and no Bad ratings)\n" 298 | "- 'acceptable' (if no Bad ratings and doesn't qualify for perfect)\n" 299 | "- 'bad' (if ANY Bad ratings)\n\n" 300 | "Note: The 'Final Judgement: [verdict]' line must be the final line of your response." 
301 | ) 302 | 303 | return prompt 304 | 305 | 306 | def extract_knowledge_points(response, keyword="Knowledge Point:"): 307 | knowledge_points = [] 308 | for line in response.split("\n"): 309 | if len(line.split(keyword)) >= 2: 310 | knowledge_points.append(line.split(keyword)[-1].strip()) 311 | return knowledge_points 312 | 313 | 314 | def extract_thinking_process(response, keyword="Thinking Process:"): 315 | if len(response.split(keyword)) >= 2: 316 | return response.split(keyword)[-1].strip() 317 | else: 318 | return response 319 | 320 | 321 | def count_ratings(response): 322 | ratings = ["perfect", "acceptable", "bad"] 323 | counts = {rating: response.lower().count(rating) for rating in ratings} 324 | max_count = max(counts.values()) 325 | max_ratings = [rating for rating, count in counts.items() if count == max_count] 326 | return max_ratings 327 | 328 | 329 | def extract_final_judgement(response, keyword="Final Judgement:"): 330 | if len(response.split(keyword)) >= 2: 331 | return response.split(keyword)[-1].strip().lower() 332 | else: 333 | ratings = count_ratings(response) 334 | return ratings[0] 335 | 336 | -------------------------------------------------------------------------------- /PromptCoT/eval/math_equivalence.py: -------------------------------------------------------------------------------- 1 | import re 2 | import signal 3 | from typing import Dict, List, Optional 4 | # from timeout_decorator import timeout 5 | # from timeout_decorator import timeout as timeout_decorator 6 | 7 | try: 8 | import sympy 9 | from sympy.parsing.latex import parse_latex 10 | from sympy.simplify import simplify 11 | 12 | # Set global timeout for SymPy operations 13 | sympy.TIMEOUT = 5 # Set global timeout in seconds 14 | 15 | # Or specifically for simplify operations 16 | simplify.TIMEOUT = 5 # Set timeout for simplify operations 17 | 18 | # If needed, you can also try these settings 19 | sympy.core.cache.NO_CACHE = True # Disable caching 20 | # sympy.core.evalf.maxprec = 1000 # Limit precision 21 | 22 | 23 | except ModuleNotFoundError: 24 | raise ModuleNotFoundError( 25 | "`sympy` is required for generating translation task prompt templates. 
\ 26 | please install sympy via pip install lm-eval[math] or pip install -e .[math]", 27 | ) 28 | 29 | 30 | def _fix_fracs(string): 31 | substrs = string.split("\\frac") 32 | new_str = substrs[0] 33 | if len(substrs) > 1: 34 | substrs = substrs[1:] 35 | for substr in substrs: 36 | new_str += "\\frac" 37 | if substr[0] == "{": 38 | new_str += substr 39 | else: 40 | try: 41 | assert len(substr) >= 2 42 | except: 43 | return string 44 | a = substr[0] 45 | b = substr[1] 46 | if b != "{": 47 | if len(substr) > 2: 48 | post_substr = substr[2:] 49 | new_str += "{" + a + "}{" + b + "}" + post_substr 50 | else: 51 | new_str += "{" + a + "}{" + b + "}" 52 | else: 53 | if len(substr) > 2: 54 | post_substr = substr[2:] 55 | new_str += "{" + a + "}" + b + post_substr 56 | else: 57 | new_str += "{" + a + "}" + b 58 | string = new_str 59 | return string 60 | 61 | 62 | def _fix_a_slash_b(string): 63 | if len(string.split("/")) != 2: 64 | return string 65 | a = string.split("/")[0] 66 | b = string.split("/")[1] 67 | try: 68 | a = int(a) 69 | b = int(b) 70 | assert string == "{}/{}".format(a, b) 71 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 72 | return new_string 73 | except: 74 | return string 75 | 76 | 77 | def _remove_right_units(string): 78 | # "\\text{ " only ever occurs (at least in the val set) when describing units 79 | if "\\text{ " in string: 80 | splits = string.split("\\text{ ") 81 | assert len(splits) == 2 82 | return splits[0] 83 | else: 84 | return string 85 | 86 | 87 | def _fix_sqrt(string): 88 | if "\\sqrt" not in string: 89 | return string 90 | splits = string.split("\\sqrt") 91 | new_string = splits[0] 92 | for split in splits[1:]: 93 | if split[0] != "{": 94 | a = split[0] 95 | new_substr = "\\sqrt{" + a + "}" + split[1:] 96 | else: 97 | new_substr = "\\sqrt" + split 98 | new_string += new_substr 99 | return new_string 100 | 101 | 102 | def _strip_string(string): 103 | # linebreaks 104 | string = string.replace("\n", "") 105 | # print(string) 106 | 107 | # remove inverse spaces 108 | string = string.replace("\\!", "") 109 | # print(string) 110 | 111 | # replace \\ with \ 112 | string = string.replace("\\\\", "\\") 113 | # print(string) 114 | 115 | # replace tfrac and dfrac with frac 116 | string = string.replace("tfrac", "frac") 117 | string = string.replace("dfrac", "frac") 118 | # print(string) 119 | 120 | # remove \left and \right 121 | string = string.replace("\\left", "") 122 | string = string.replace("\\right", "") 123 | # print(string) 124 | 125 | # Remove circ (degrees) 126 | string = string.replace("^{\\circ}", "") 127 | string = string.replace("^\\circ", "") 128 | 129 | # remove dollar signs 130 | string = string.replace("\\$", "") 131 | 132 | # remove units (on the right) 133 | string = _remove_right_units(string) 134 | 135 | # remove percentage 136 | string = string.replace("\\%", "") 137 | string = string.replace("\%", "") 138 | 139 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 140 | string = string.replace(" .", " 0.") 141 | string = string.replace("{.", "{0.") 142 | # if empty, return empty string 143 | if len(string) == 0: 144 | return string 145 | if string[0] == ".": 146 | string = "0" + string 147 | 148 | # to consider: get rid of e.g. 
"k = " or "q = " at beginning 149 | if len(string.split("=")) == 2: 150 | if len(string.split("=")[0]) <= 2: 151 | string = string.split("=")[1] 152 | 153 | # fix sqrt3 --> sqrt{3} 154 | string = _fix_sqrt(string) 155 | 156 | # remove spaces 157 | string = string.replace(" ", "") 158 | 159 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 160 | string = _fix_fracs(string) 161 | 162 | # manually change 0.5 --> \frac{1}{2} 163 | if string == "0.5": 164 | string = "\\frac{1}{2}" 165 | 166 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 167 | string = _fix_a_slash_b(string) 168 | 169 | return string 170 | 171 | 172 | def is_equiv(str1, str2, verbose=False): 173 | if str1 is None and str2 is None: 174 | print("WARNING: Both None") 175 | return True 176 | if str1 is None or str2 is None: 177 | return False 178 | 179 | try: 180 | ss1 = _strip_string(str1) 181 | ss2 = _strip_string(str2) 182 | if verbose: 183 | print(ss1, ss2) 184 | return ss1 == ss2 185 | except: 186 | return str1 == str2 187 | 188 | 189 | class timeout: 190 | def __init__(self, seconds=1, error_message="Timeout"): 191 | self.seconds = seconds 192 | self.error_message = error_message 193 | 194 | def handle_timeout(self, signum, frame): 195 | raise TimeoutError(self.error_message) 196 | 197 | def __enter__(self): 198 | signal.signal(signal.SIGALRM, self.handle_timeout) 199 | signal.alarm(self.seconds) 200 | 201 | def __exit__(self, type, value, traceback): 202 | signal.alarm(0) 203 | 204 | def is_equiv_minerva(x1: str, x2: str) -> bool: 205 | """ 206 | x1 and x2 are normalized latex string 207 | """ 208 | try: 209 | with timeout(seconds=5): 210 | try: 211 | parsed_x1 = parse_latex(normalize_final_answer(x1)) 212 | parsed_x2 = parse_latex(normalize_final_answer(x2)) 213 | except ( 214 | sympy.parsing.latex.errors.LaTeXParsingError, 215 | sympy.SympifyError, 216 | TypeError, 217 | AttributeError, 218 | ): 219 | # print(f"couldn't parse one of {x1} or {x2}") 220 | return False 221 | 222 | try: 223 | diff = parsed_x1 - parsed_x2 224 | except TypeError: 225 | # print(f"couldn't subtract {x1} and {x2}") 226 | return False 227 | 228 | try: 229 | if sympy.simplify(diff) == 0: 230 | return True 231 | else: 232 | return False 233 | except ValueError: 234 | # print( 235 | # f"Had some trouble simplifying when comparing {x1} and {x2}" 236 | # ) 237 | return False 238 | except TimeoutError: 239 | # print(f"Timed out comparing {x1} and {x2}") 240 | return False 241 | except ImportError as e: 242 | print(e) 243 | raise 244 | except Exception as e: 245 | # print(f"Failed comparing {x1} and {x2} with {e}") 246 | return False 247 | 248 | 249 | SUBSTITUTIONS = [ 250 | ("an ", ""), 251 | ("a ", ""), 252 | (".$", "$"), 253 | ("\\$", ""), 254 | (r"\ ", ""), 255 | (" ", ""), 256 | ("mbox", "text"), 257 | (",\\text{and}", ","), 258 | ("\\text{and}", ","), 259 | ("\\text{m}", "\\text{}"), 260 | ] 261 | REMOVED_EXPRESSIONS = [ 262 | "square", 263 | "ways", 264 | "integers", 265 | "dollars", 266 | "mph", 267 | "inches", 268 | "ft", 269 | "hours", 270 | "km", 271 | "units", 272 | "\\ldots", 273 | "sue", 274 | "points", 275 | "feet", 276 | "minutes", 277 | "digits", 278 | "cents", 279 | "degrees", 280 | "cm", 281 | "gm", 282 | "pounds", 283 | "meters", 284 | "meals", 285 | "edges", 286 | "students", 287 | "childrentickets", 288 | "multiples", 289 | "\\text{s}", 290 | "\\text{.}", 291 | "\\text{\ns}", 292 | 
"\\text{}^2", 293 | "\\text{}^3", 294 | "\\text{\n}", 295 | "\\text{}", 296 | r"\mathrm{th}", 297 | r"^\circ", 298 | r"^{\circ}", 299 | r"\;", 300 | r",\!", 301 | "{,}", 302 | '"', 303 | "\\dots", 304 | ] 305 | 306 | 307 | def normalize_final_answer(final_answer: str) -> str: 308 | """ 309 | Normalize a final answer to a quantitative reasoning question. 310 | 311 | Copied character for character from appendix D of Lewkowycz et al. (2022) 312 | """ 313 | final_answer = final_answer.split("=")[-1] 314 | 315 | for before, after in SUBSTITUTIONS: 316 | final_answer = final_answer.replace(before, after) 317 | for expr in REMOVED_EXPRESSIONS: 318 | final_answer = final_answer.replace(expr, "") 319 | 320 | # Extract answer that is in LaTeX math, is bold, 321 | # is surrounded by a box, etc. 322 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) 323 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) 324 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) 325 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) 326 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) 327 | 328 | # Normalize shorthand TeX: 329 | # \fracab -> \frac{a}{b} 330 | # \frac{abc}{bef} -> \frac{abc}{bef} 331 | # \fracabc -> \frac{a}{b}c 332 | # \sqrta -> \sqrt{a} 333 | # \sqrtab -> sqrt{a}b 334 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) 335 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) 336 | final_answer = final_answer.replace("$", "") 337 | 338 | # Normalize 100,000 -> 100000 339 | if final_answer.replace(",", "").isdigit(): 340 | final_answer = final_answer.replace(",", "") 341 | 342 | return final_answer -------------------------------------------------------------------------------- /PromptCoT/eval/util.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | 3 | def remove_boxed(s): 4 | left = "\\boxed{" 5 | try: 6 | assert s[:len(left)] == left 7 | assert s[-1] == "}" 8 | return s[len(left):-1] 9 | except: 10 | return None 11 | 12 | def last_boxed_only(sample): 13 | """ 14 | Given a (q,a) sample, filter the answers so that they only contain 15 | the last \boxed{...} or \fbox{...} element 16 | """ 17 | q, a = sample 18 | a = last_boxed_only_string(a) 19 | if a == None: 20 | return None 21 | return (q, a) 22 | 23 | 24 | def last_boxed_only_string(string): 25 | idx = string.rfind("\\boxed") 26 | if idx < 0: 27 | idx = string.rfind("\\fbox") 28 | if idx < 0: 29 | return None 30 | 31 | i = idx 32 | right_brace_idx = None 33 | num_left_braces_open = 0 34 | while i < len(string): 35 | if string[i] == "{": 36 | num_left_braces_open += 1 37 | if string[i] == "}": 38 | num_left_braces_open -= 1 39 | if num_left_braces_open == 0: 40 | right_brace_idx = i 41 | break 42 | i += 1 43 | 44 | if right_brace_idx == None: 45 | retval = None 46 | else: 47 | retval = string[idx:right_brace_idx + 1] 48 | 49 | return retval 50 | 51 | 52 | def first_boxed_only_string(string): 53 | # Find the first occurrence of \boxed or \fbox 54 | idx_boxed = string.find("\\boxed") 55 | idx_fbox = string.find("\\fbox") 56 | 57 | # Determine which comes first (if either exists) 58 | if idx_boxed < 0 and idx_fbox < 0: 59 | return None 60 | elif idx_boxed < 0: 61 | idx = idx_fbox 62 | elif idx_fbox < 0: 63 | idx = idx_boxed 64 | else: 65 | idx = min(idx_boxed, idx_fbox) 66 | 67 | # Find matching closing brace 68 | i = idx 69 | right_brace_idx = None 70 | 
num_left_braces_open = 0 71 | 72 | while i < len(string): 73 | if string[i] == "{": 74 | num_left_braces_open += 1 75 | elif string[i] == "}": 76 | num_left_braces_open -= 1 77 | if num_left_braces_open == 0: 78 | right_brace_idx = i 79 | break 80 | i += 1 81 | 82 | if right_brace_idx is None: 83 | return None 84 | 85 | return string[idx:right_brace_idx + 1] 86 | 87 | 88 | def only_until_first_boxed_from_tokens(string, tokens): 89 | idx = string.find("\\boxed") 90 | if idx < 0: 91 | idx = string.find("\\fbox") 92 | if idx < 0: 93 | return None 94 | 95 | cum_length = 0 96 | for i, t in enumerate(tokens): 97 | cum_length += len(t) 98 | if cum_length >= idx: 99 | break 100 | 101 | return tokens[:i] 102 | 103 | 104 | def clean_numbers(sample): 105 | if not sample: 106 | return None 107 | new_sample = list() 108 | for s in sample: 109 | new_sample.append(_clean_numbers(s)) 110 | 111 | return tuple(new_sample) 112 | 113 | 114 | def _clean_numbers(string): 115 | """ 116 | Clean numbers in the given string 117 | 118 | >>> _clean_numbers("Hello 123") 119 | 'Hello 123' 120 | >>> _clean_numbers("Hello 1234") 121 | 'Hello 1,234' 122 | >>> _clean_numbers("Hello 1234324asdasd") 123 | 'Hello 1,234,324asdasd' 124 | """ 125 | num_prev_digits = 0 126 | new_string = "" 127 | for i, c in enumerate(string): 128 | # isdigit() doesn't work here because of weird unicode chars. 129 | if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}: 130 | num_prev_digits += 1 131 | else: 132 | if num_prev_digits > 3: 133 | # Some fixing 134 | string_number = new_string[-num_prev_digits:] 135 | new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number)) 136 | num_prev_digits = 0 137 | new_string += c 138 | 139 | if num_prev_digits > 3: 140 | # Some fixing 141 | string_number = new_string[-num_prev_digits:] 142 | new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number)) 143 | 144 | return new_string -------------------------------------------------------------------------------- /PromptCoT/infer_longcot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | from vllm import LLM, SamplingParams 5 | from transformers import AutoTokenizer 6 | from str2bool import str2bool 7 | import os 8 | import re 9 | 10 | 11 | def is_valid_think_format(text): 12 | # Check basic pattern first 13 | pattern = r'^<think>(.*?)</think>(.+)$' # Note the .+ for non-empty after content 14 | match = re.match(pattern, text, re.DOTALL) 15 | 16 | if not match: 17 | return False 18 | 19 | # Extract the content inside and after the think tags 20 | inside_content = match.group(1) 21 | after_content = match.group(2) 22 | 23 | # Verify inside content is not empty 24 | if not inside_content.strip(): 25 | return False 26 | 27 | # Verify neither part contains additional think tags 28 | if '<think>' in inside_content or '</think>' in inside_content: 29 | return False 30 | 31 | if '<think>' in after_content or '</think>' in after_content: 32 | return False 33 | 34 | return True 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser(description="Evaluate large language models on critical datasets.") 39 | parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file.") 40 | parser.add_argument("--output_path", type=str, required=True, help="Directory to store cached outputs.") 41 | parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model.") 42 | parser.add_argument("--tokenizer_path", 
type=str, default=None, help="Path to the pretrained model.") 43 | parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type to use for the model (e.g., fp16, bf16, etc.).") 44 | parser.add_argument("--n_gpus", type=int, default=8, help="Number of GPUs to use for tensor parallelism.") 45 | parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature for generation.") 46 | parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling for generation.") 47 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 48 | parser.add_argument("--max_len", type=int, default=32768, help="Maximum number of tokens to generate.") 49 | parser.add_argument("--use_chat_template", type=str2bool, default=False) 50 | parser.add_argument("--n", type=int, default=8) 51 | parser.add_argument("--max_retries", type=int, default=8) 52 | 53 | args = parser.parse_args() 54 | 55 | if args.tokenizer_path is None: 56 | args.tokenizer_path = args.model_path 57 | 58 | # Load the tokenizer for LLaMA or any model 59 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 60 | 61 | # Load inference framework 62 | model = LLM( 63 | model=args.model_path, 64 | tokenizer=args.tokenizer_path, 65 | tokenizer_mode="slow", 66 | dtype=args.dtype, 67 | tensor_parallel_size=args.n_gpus, 68 | enforce_eager=True, 69 | ) 70 | 71 | items = [] 72 | completions = [] 73 | seed = 0 74 | for _ in range(args.n): 75 | prompts = [] 76 | with open(args.data_path, encoding="utf-8") as f: 77 | for line in f.readlines(): 78 | item = json.loads(line) 79 | prompt = item["prompt"] 80 | if args.use_chat_template: 81 | messages = [ 82 | {"role": "user", "content": prompt} 83 | ] 84 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 85 | prompts.append(prompt) 86 | items.append(item) 87 | 88 | with torch.no_grad(): 89 | # Initialize with None for each prompt position 90 | output_lst = [None] * len(prompts) 91 | pending_prompts = prompts.copy() 92 | pending_indices = list(range(len(prompts))) 93 | 94 | retry_count = 0 95 | 96 | while pending_prompts and retry_count < args.max_retries: 97 | sampling_params = SamplingParams( 98 | temperature=args.temperature, 99 | top_p=args.top_p, 100 | max_tokens=args.max_len, 101 | repetition_penalty=args.repetition_penalty, 102 | seed=seed, 103 | ) 104 | seed += 1 105 | 106 | # Generate completions for remaining prompts 107 | batch_outputs = model.generate(pending_prompts, sampling_params) 108 | batch_texts = [completion.outputs[0].text for completion in batch_outputs] 109 | 110 | # Process current batch results 111 | still_pending_prompts = [] 112 | still_pending_indices = [] 113 | 114 | for i, (idx, text) in enumerate(zip(pending_indices, batch_texts)): 115 | if is_valid_think_format(text): 116 | # If valid, add to results at original position 117 | output_lst[idx] = text 118 | else: 119 | # If invalid, keep for retry 120 | still_pending_prompts.append(pending_prompts[i]) 121 | still_pending_indices.append(idx) 122 | # Store the invalid output in case we reach max retries 123 | output_lst[idx] = text 124 | 125 | # Update for next iteration 126 | pending_prompts = still_pending_prompts 127 | pending_indices = still_pending_indices 128 | retry_count += 1 129 | 130 | completions.extend(output_lst) 131 | 132 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 133 | with open(args.output_path, "w", encoding="utf-8") as f: 134 | for item, completion in zip(items, completions): 135 | 
item["completion"] = completion 136 | f.write(json.dumps(item) + "\n") 137 | 138 | 139 | if __name__ == "__main__": 140 | main() 141 | -------------------------------------------------------------------------------- /PromptCoT/problem_filtering.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from typing import List, Dict, Tuple 4 | from dataclasses import dataclass 5 | from str2bool import str2bool 6 | 7 | 8 | @dataclass 9 | class ProcessedItem: 10 | prompt: str 11 | completion: str 12 | 13 | 14 | def parse_judgement(judgement_text: str) -> str: 15 | judgement_text = judgement_text.lower() 16 | if judgement_text.startswith("perfect") or "perfect" in judgement_text: 17 | return "perfect" 18 | elif judgement_text.startswith("acceptable") or "acceptable" in judgement_text: 19 | return "acceptable" 20 | return "bad" 21 | 22 | 23 | def load_and_process_file(file_path: str) -> Tuple[List[Dict], List[str]]: 24 | items = [] 25 | ratings = [] 26 | with open(file_path, encoding="utf-8") as f: 27 | for line in f.readlines(): 28 | json_obj = json.loads(line) 29 | judgement = parse_judgement(json_obj["judgement"]) 30 | assert judgement in ["perfect", "acceptable", "bad"] 31 | ratings.append(judgement) 32 | items.append(json_obj) 33 | return items, ratings 34 | 35 | 36 | def process_items(items: List[Dict], ratings_list: List[List[str]], 37 | only_perfect: bool) -> List[ProcessedItem]: 38 | processed_items = [] 39 | n_rewards = len(ratings_list) 40 | 41 | for idx in range(len(items)): 42 | all_perfect = all( 43 | ratings_list[reward_idx][idx] == "perfect" 44 | for reward_idx in range(n_rewards) 45 | ) 46 | has_perfect_no_bad = ( 47 | any(ratings_list[reward_idx][idx] == "perfect" 48 | for reward_idx in range(n_rewards)) and 49 | all(ratings_list[reward_idx][idx] != "bad" 50 | for reward_idx in range(n_rewards)) 51 | ) 52 | 53 | if (only_perfect and all_perfect) or (not only_perfect and has_perfect_no_bad): 54 | processed_items.append(ProcessedItem( 55 | prompt=items[idx]["prompt"], 56 | completion=items[idx]["rationale_and_problem"] 57 | )) 58 | 59 | return processed_items 60 | 61 | 62 | def main(): 63 | parser = argparse.ArgumentParser(description='Process and filter completion data') 64 | parser.add_argument('--template', type=str, 65 | required=True, 66 | help='Template for input files') 67 | parser.add_argument('--output_path', type=str, 68 | required=True, 69 | help='Path for output file') 70 | parser.add_argument('--only_perfect', type=str2bool, 71 | default=True, 72 | help='Only include items rated as perfect by all rewards') 73 | parser.add_argument('--n_rewards', type=int, default=2, 74 | help='Number of reward models') 75 | args = parser.parse_args() 76 | 77 | items_list = [] 78 | ratings_list = [] 79 | 80 | # Load and process files for each reward model 81 | for reward_idx in range(args.n_rewards): 82 | file_path = args.template.format(reward_idx) 83 | items, ratings = load_and_process_file(file_path) 84 | if reward_idx == 0: 85 | items_list = items 86 | ratings_list.append(ratings) 87 | 88 | # Verify consistency 89 | assert all(len(ratings) == len(items_list) 90 | for ratings in ratings_list) 91 | 92 | # Process items 93 | processed_items = process_items( 94 | items_list, ratings_list, 95 | args.only_perfect 96 | ) 97 | final_items = [vars(item) for item in processed_items] 98 | 99 | # Save results 100 | with open(args.output_path, "w", encoding="utf-8") as f: 101 | for item in final_items: 102 | 
f.write(json.dumps(item) + "\n") 103 | 104 | 105 | if __name__ == "__main__": 106 | main() -------------------------------------------------------------------------------- /PromptCoT/problem_generation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | from vllm import LLM, SamplingParams 5 | from transformers import AutoTokenizer 6 | from str2bool import str2bool 7 | import os 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file.") 13 | parser.add_argument("--output_path", type=str, required=True, help="Directory to store cached outputs.") 14 | parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model.") 15 | parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to the pretrained model.") 16 | parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type to use for the model (e.g., fp16, bf16, etc.).") 17 | parser.add_argument("--n_gpus", type=int, default=8, help="Number of GPUs to use for tensor parallelism.") 18 | parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature for generation.") 19 | parser.add_argument("--top_p", type=float, default=1.0, help="Top-p sampling for generation.") 20 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 21 | parser.add_argument("--max_len", type=int, default=2048, help="Maximum number of tokens to generate.") 22 | parser.add_argument("--use_chat_template", type=str2bool, default=False) 23 | parser.add_argument("--seed", type=int, default=42) 24 | 25 | args = parser.parse_args() 26 | 27 | if args.tokenizer_path is None: 28 | args.tokenizer_path = args.model_path 29 | 30 | # Load the tokenizer for LLaMA or any model 31 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 32 | 33 | # Load inference framework 34 | model = LLM( 35 | model=args.model_path, 36 | tokenizer=args.tokenizer_path, 37 | tokenizer_mode="slow", 38 | dtype=args.dtype, 39 | tensor_parallel_size=args.n_gpus, 40 | enforce_eager=True, 41 | ) 42 | 43 | # Setup sampling parameters for model generation 44 | sampling_params = SamplingParams( 45 | temperature=args.temperature, 46 | top_p=args.top_p, 47 | max_tokens=args.max_len, 48 | repetition_penalty=args.repetition_penalty, 49 | seed=args.seed, 50 | ) 51 | 52 | prompts = [] 53 | items = [] 54 | with open(args.data_path, encoding="utf-8") as f: 55 | for line in f.readlines(): 56 | item = json.loads(line) 57 | prompt = item["prompt"] 58 | if args.use_chat_template: 59 | messages = [ 60 | {"role": "user", "content": prompt} 61 | ] 62 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 63 | prompts.append(prompt) 64 | items.append(item) 65 | 66 | with torch.no_grad(): 67 | completions = model.generate(prompts, sampling_params) 68 | completions = [completion.outputs[0].text for completion in completions] 69 | 70 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 71 | with open(args.output_path, "w", encoding="utf-8") as f: 72 | for item, completion in zip(items, completions): 73 | item["completion"] = completion 74 | f.write(json.dumps(item) + "\n") 75 | 76 | 77 | if __name__ == "__main__": 78 | main() 79 | -------------------------------------------------------------------------------- /PromptCoT/rejection_sampling_reward.py: 
-------------------------------------------------------------------------------- 1 | import fire 2 | import os 3 | from transformers import AutoTokenizer 4 | from vllm import LLM, SamplingParams 5 | import json 6 | import torch 7 | from data_synthesis import rationale_judgement_prompt, extract_final_judgement 8 | 9 | 10 | def main( 11 | data_path, 12 | output_path, 13 | model_path, 14 | dtype="bfloat16", 15 | n_gpus=8, 16 | temperature=0.0, 17 | max_len=2048, 18 | use_chat_template=False, 19 | seed=42, 20 | ): 21 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 22 | 23 | tokenizer = AutoTokenizer.from_pretrained(model_path) 24 | 25 | model = LLM( 26 | model=model_path, 27 | tokenizer=model_path, 28 | tokenizer_mode="slow", 29 | dtype=dtype, 30 | tensor_parallel_size=n_gpus, 31 | seed=seed, 32 | ) 33 | sampling_params = SamplingParams( 34 | temperature=temperature, 35 | max_tokens=max_len, 36 | ) 37 | 38 | prompts = [] 39 | items = [] 40 | rationales = [] 41 | 42 | with open(data_path, encoding="utf-8") as f: 43 | for line in f.readlines(): 44 | item = json.loads(line) 45 | foundational_concepts = item["foundational_concepts"] 46 | level = item["level"] 47 | rationale_and_problem = item["completion"] 48 | prompt = rationale_judgement_prompt( 49 | concepts=foundational_concepts, 50 | level=level, 51 | rationale_and_problem=rationale_and_problem, 52 | ) 53 | if use_chat_template: 54 | prompt = tokenizer.apply_chat_template( 55 | [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True) 56 | prompts.append(prompt) 57 | items.append(item) 58 | rationales.append(rationale_and_problem) 59 | 60 | with torch.no_grad(): 61 | completions = model.generate(prompts, sampling_params) 62 | completions = [completion.outputs[0].text for completion in completions] 63 | 64 | for item, rationale, completion in zip(items, rationales, completions): 65 | final_judgement = extract_final_judgement(completion) 66 | item["rationale_and_problem"] = rationale 67 | item["judgement"] = final_judgement 68 | item["completion"] = completion 69 | 70 | with open(output_path, "w", encoding="utf-8") as f: 71 | for item in items: 72 | f.write(json.dumps(item) + "\n") 73 | 74 | 75 | if __name__ == "__main__": 76 | fire.Fire(main) -------------------------------------------------------------------------------- /PromptCoT/requirements.txt: -------------------------------------------------------------------------------- 1 | vllm==0.4.3 2 | aiohttp==3.10.5 3 | aiohttp-jinja2==1.6 4 | cmake==3.30.3 5 | fastapi==0.115.8 6 | filelock==3.13.1 7 | lm-format-enforcer==0.10.1 8 | ninja==1.11.1.1 9 | numpy==1.23.5 10 | nvidia-ml-py==12.560.30 11 | openai==1.47.0 12 | outlines==0.0.34 13 | prometheus-fastapi-instrumentator==7.0.0 14 | psutil==5.9.0 15 | py-cpuinfo==9.0.0 16 | pydantic==2.10.6 17 | pydantic_core==2.27.2 18 | ray==2.36.1 19 | requests==2.32.3 20 | requests-oauthlib==2.0.0 21 | sentencepiece==0.2.0 22 | tiktoken==0.9.0 23 | tokenizers==0.21.0 24 | torch==2.3.0 25 | torchaudio==2.3.0 26 | torchdata==0.6.1 27 | torchtext==0.18.0 28 | torchvision==0.18.0 29 | transformers==4.49.0 30 | uvicorn==0.34.0 31 | xformers==0.0.26 -------------------------------------------------------------------------------- /PromptCoT/train.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import random 4 | from dataclasses import dataclass, field 5 | from typing import Optional, Dict, Sequence 6 | 7 | import torch 8 | import datetime 9 | 10 | 
torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=12000)) 11 | 12 | 13 | import transformers 14 | from datasets import load_dataset 15 | 16 | IGNORE_INDEX = -100 17 | 18 | 19 | @dataclass 20 | class ModelArguments: 21 | model_name_or_path: Optional[str] = field(default="/ossfs/workspace/nas/xueliang/hf_models/Meta-Llama-3.1-8B") 22 | tokenizer_path: Optional[str] = field(default="/ossfs/workspace/nas/xueliang/hf_models/Meta-Llama-3.1-8B") 23 | 24 | 25 | @dataclass 26 | class DataArguments: 27 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 28 | 29 | 30 | @dataclass 31 | class TrainingArguments(transformers.TrainingArguments): 32 | cache_dir: Optional[str] = field(default=None) 33 | optim: str = field(default="adamw_torch") 34 | model_max_length: int = field( 35 | default=4096, 36 | metadata={"help": "Maximum sequence length."} 37 | ) 38 | resume_from_checkpoint: Optional[bool] = field(default=None) 39 | 40 | 41 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 42 | tokenized_list = [ 43 | tokenizer( 44 | text, 45 | return_tensors="pt", 46 | padding="longest", 47 | max_length=tokenizer.model_max_length, 48 | truncation=True, 49 | ) 50 | for text in strings 51 | ] 52 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 53 | input_ids_lens = labels_lens = [ 54 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 55 | ] 56 | return dict( 57 | input_ids=input_ids, 58 | labels=labels, 59 | input_ids_lens=input_ids_lens, 60 | labels_lens=labels_lens, 61 | ) 62 | 63 | 64 | def preprocess( 65 | sources: Sequence[str], 66 | targets: Sequence[str], 67 | tokenizer: transformers.PreTrainedTokenizer, 68 | ) -> Dict: 69 | examples = [s + t for s, t in zip(sources, targets)] 70 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 71 | input_ids = examples_tokenized["input_ids"] 72 | labels = copy.deepcopy(input_ids) 73 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 74 | label[:source_len] = IGNORE_INDEX 75 | return dict(input_ids=input_ids, labels=labels) 76 | 77 | 78 | @dataclass 79 | class DataCollatorForSuperviseDataset(object): 80 | tokenizer: transformers.PreTrainedTokenizer 81 | 82 | def __call__(self, items: Sequence[Dict]) -> Dict[str, torch.Tensor]: 83 | input_ids, labels = tuple([item[key] for item in items] for key in ("input_ids", "labels")) 84 | input_ids = [torch.tensor(x) for x in input_ids] 85 | input_ids = torch.nn.utils.rnn.pad_sequence( 86 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 87 | ) 88 | labels = [torch.tensor(x) for x in labels] 89 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 90 | return dict( 91 | input_ids=input_ids, 92 | labels=labels, 93 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id) 94 | ) 95 | 96 | 97 | def train_tokenize_function(examples, tokenizer): 98 | sources = [prompt for prompt in examples["prompt"]] 99 | targets = [f"{output}{tokenizer.eos_token}" for output in examples["completion"]] 100 | return preprocess(sources, targets, tokenizer) 101 | 102 | 103 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 104 | state_dict = trainer.model.state_dict() 105 | if trainer.args.should_save: 106 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 
107 | del state_dict 108 | trainer._save(output_dir, state_dict=cpu_state_dict) 109 | 110 | 111 | def train(): 112 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 113 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 114 | 115 | model = transformers.AutoModelForCausalLM.from_pretrained( 116 | model_args.model_name_or_path, 117 | torch_dtype=torch.bfloat16, 118 | ) 119 | tokenizer = transformers.AutoTokenizer.from_pretrained( 120 | model_args.tokenizer_path, 121 | model_max_length=training_args.model_max_length, 122 | use_fast=False 123 | ) 124 | if tokenizer.pad_token is None: 125 | tokenizer.pad_token = tokenizer.eos_token 126 | tokenizer.pad_token_id = tokenizer.eos_token_id 127 | 128 | raw_train_dataset = load_dataset( 129 | "json", 130 | data_files=data_args.data_path, 131 | split="train", 132 | ) 133 | 134 | if training_args.local_rank > 0: 135 | torch.distributed.barrier() 136 | 137 | train_dataset = raw_train_dataset.map( 138 | train_tokenize_function, 139 | batched=True, 140 | batch_size=4096, 141 | num_proc=64, 142 | remove_columns=raw_train_dataset.column_names, 143 | desc="Running tokenizer on train dataset", 144 | fn_kwargs={ 145 | "tokenizer": tokenizer, 146 | } 147 | ) 148 | 149 | if training_args.local_rank == 0: 150 | torch.distributed.barrier() 151 | 152 | if training_args.local_rank == 0: 153 | print(len(train_dataset)) 154 | for index in random.sample(range(len(train_dataset)), 3): 155 | print(f"Sample {index} of the training set: {train_dataset[index]}.") 156 | 157 | data_collator = DataCollatorForSuperviseDataset(tokenizer=tokenizer) 158 | trainer = transformers.Trainer( 159 | model=model, 160 | tokenizer=tokenizer, 161 | args=training_args, 162 | train_dataset=train_dataset, 163 | data_collator=data_collator, 164 | ) 165 | 166 | trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 167 | trainer.save_state() 168 | safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) 169 | 170 | 171 | if __name__ == "__main__": 172 | train() 173 | -------------------------------------------------------------------------------- /PromptCoT_Mamba/README.md: -------------------------------------------------------------------------------- 1 | # **Scaling Reasoning without Attention** 2 | 3 | --- 4 | 5 | ## 🚀 Overview 6 | 7 | **PromptCoT-Mamba** establishes the first **attention-free foundation model** capable of surpassing strong Transformer baselines across a broad suite of competition-level math and code reasoning tasks. Built on the **Mamba-2** architecture and trained through a structured, two-stage curriculum using the [**PromptCoT**](http://arxiv.org/abs/2503.02324) pipeline, it delivers **high accuracy with constant-memory inference**, eliminating the need for KV caching. 
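To get intuition for what constant-memory inference buys at these lengths, here is a back-of-the-envelope Python sketch of how a Transformer's KV cache grows with sequence length. The hyperparameters below are illustrative assumptions for a generic 7B-scale model, not the exact PromptCoT-Mamba or baseline configurations:

```python
# Rough KV-cache footprint of a hypothetical 7B Transformer in bf16.
# layers/kv_heads/head_dim/dtype_bytes are illustrative assumptions.
def kv_cache_bytes(seq_len, layers=32, kv_heads=32, head_dim=128, dtype_bytes=2):
    # 2x accounts for storing both keys and values at every layer
    return 2 * layers * kv_heads * head_dim * dtype_bytes * seq_len

for seq_len in (4_096, 32_768, 65_536):
    print(f"seq_len={seq_len:>6}: ~{kv_cache_bytes(seq_len) / 2**30:.0f} GiB per sequence")
# seq_len=  4096: ~2 GiB; seq_len= 32768: ~16 GiB; seq_len= 65536: ~32 GiB.
# A Mamba-2 layer instead carries a fixed-size recurrent state, so its
# memory cost does not grow with the number of generated tokens.
```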
8 | 9 | 10 | 11 | ## 📈 Key Results 12 | 13 | ### 🔹 General Performance 14 | 15 | | Model | MATH-500 | AIME 24 | AIME 25 | OlympiadBench | HumanEval | HumanEval+ | Livecodebench | 16 | | ---------------------- | -------- | -------- | -------- | ------------- | --------- | ---------- | ------------- | 17 | | **PromptCoT-Mamba-7B** | 84.6 | **35.2 🔥** | **24.6 🔥** | 50.7 | 81.7 | 75.0 | **29.9🔥** | 18 | | Gemma3-27B | **89.0** | 32.6 | 24.0 | **54.2** | **86.0** | **78.0** | 26.9 | 19 | | Gemma3-12B | 83.8 | 22.9 | 19.2 | 49.9 | 81.1 | 73.2 | 22.2 | 20 | | Sky-T1-7B | 85.0 | 19.2 | 19.2 | 49.2 | 41.5 | 37.2 | 18.3 | 21 | | S1.1-7B | 82.0 | 19.2 | 17.5 | 43.1 | 64.0 | 56.7 | 13.3 | 22 | | Bespoke-Stratos-7B | 81.2 | 18.3 | 16.3 | 45.0 | 73.2 | 68.3 | 8.6 | 23 | | Nemotron-H-8B | 77.6 | -- | -- | -- | 79.3 | 74.4 | -- | 24 | | M1-3B | 81.7 | 23.0 | 22.0 | 43.6 | -- | -- | -- | 25 | 26 | > 🔍 **PromptCoT-Mamba-7B** consistently outperforms all 7B-scale Transformer and hybrid Mamba-Transformer baselines across all tasks. 27 | 28 | --- 29 | 30 | ### 🔹 Math Specialization vs. Generalist 31 | 32 | | Model | MATH-500 | AIME 24 | AIME 25 | OlympiadBench | HumanEval | HumanEval+ | Livecodebench | 33 | | --------------------------- | -------- | -------- | -------- | ------------- | --------- | ---------- | ------------- | 34 | | **PromptCoT-Mamba-Math-7B** | **88.0 🔥** | **42.9 🔥** | **30.8 🔥** | **52.1 🔥** | 71.3 | 66.5 | 20.3 | 35 | | PromptCoT-Mamba-7B | 84.6 | 35.2 | 24.6 | 50.7 | **81.7** | **75.0** | **29.9** | 36 | 37 | > 🎯 The math-specialized variant improves AIME 24 by **+7.7%** and AIME 25 by **+6.2%**, with a slight trade-off in code-related performance. 38 | 39 | --- 40 | 41 | ### ⚡ Inference Efficiency 42 | 43 | Using `vLLM` under constrained memory, PromptCoT-Mamba-7B demonstrates substantial speedups over the S1.1-7B Transformer baseline: 44 | 45 | * 💡 **3.66× faster** at long-sequence generation on **24GB GPU** 46 | * 💡 **1.69× faster** under **72GB memory** 47 | 48 | > ⚙️ Practical for cost-sensitive or long-context inference workloads at scale. 
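As a quick sanity check before running the full evaluation scripts below, here is a minimal generation sketch using the same vLLM interface that `infer_longcot.py` relies on; the model path and prompt are placeholders:

```python
from vllm import LLM, SamplingParams

# Placeholder path; substitute your local PromptCoT-Mamba-7B checkpoint.
llm = LLM(model="/path/to/PromptCoT-Mamba-7B", tensor_parallel_size=1, dtype="bfloat16")
params = SamplingParams(temperature=0.8, repetition_penalty=1.1, max_tokens=4096)
outputs = llm.generate(["Find the remainder when 7^2025 is divided by 1000."], params)
print(outputs[0].outputs[0].text)
```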
49 | 50 | 51 | ## 🧪 Quick Start 52 | 53 | ### Install Requirements 54 | 55 | ```bash 56 | pip install -r requirements.txt 57 | ``` 58 | 59 | ### Training with DeepSpeed 60 | 61 | ```bash 62 | deepspeed --num_gpus=8 train.py \ 63 | --adam_beta1=0.9 \ 64 | --adam_beta2=0.95 \ 65 | --bf16=True \ 66 | --data_path=/path/to/sft_data \ 67 | --ddp_find_unused_parameters=False \ 68 | --deepspeed=configs/promptcot_mamba_7b_config.json \ 69 | --fp16=False \ 70 | --gradient_accumulation_steps=8 \ 71 | --gradient_checkpointing=True \ 72 | --learning_rate=5e-06 \ 73 | --max_grad_norm=1.0 \ 74 | --max_length=20480 \ 75 | --model_name_or_path=/path/to/PromptCoT-Mamba-7B \ 76 | --num_train_epochs=1 \ 77 | --output_dir=/path/to/output_dir \ 78 | --per_device_train_batch_size=1 \ 79 | --save_steps=1000 \ 80 | --save_strategy=steps \ 81 | --save_total_limit=10 \ 82 | --tokenizer_path=/path/to/PromptCoT-Mamba-7B \ 83 | --warmup_steps=100 \ 84 | --weight_decay=0.01 85 | ``` 86 | 87 | ### AIME Test 88 | ```bash 89 | python infer_longcot.py \ 90 | --data_path data/aime_test.jsonl \ 91 | --output_path data/aime_test_predictions.jsonl \ 92 | --model_path /path/to/PromptCoT-Mamba-7B \ 93 | --n_gpus 1 \ 94 | --temperature 0.8 \ 95 | --repetition_penalty 1.1 \ 96 | --max_len 65536 \ 97 | --n 16 \ 98 | --max_retries 1 \ 99 | --use_mamba2 True 100 | 101 | python calc_acc_aime.py \ 102 | --input_path data/aime_test_predictions.jsonl \ 103 | --expected_runs 16 104 | ``` 105 | 106 | ### LiveCodeBench Test 107 | 108 | ```bash 109 | python infer_longcot.py \ 110 | --data_path data/livecodebench_test.jsonl \ 111 | --output_path data/livecodebench_test_predictions.jsonl \ 112 | --model_path /path/to/PromptCoT-Mamba-7B \ 113 | --n_gpus 1 \ 114 | --temperature 0.8 \ 115 | --repetition_penalty 1.1 \ 116 | --max_len 65536 \ 117 | --n 8 \ 118 | --max_retries 1 \ 119 | --use_mamba2 True 120 | 121 | python calc_acc_lcb.py \ 122 | --input_path data/livecodebench_test_predictions.jsonl \ 123 | --cache_path cache/livecodebench_test_predictions.jsonl 124 | ``` 125 | 126 | 127 | ## 📜 Citation 128 | 129 | ```bibtex 130 | @article{zhao2025scaling, 131 | author = {Xueliang Zhao and Wei Wu and Lingpeng Kong}, 132 | title = {Scaling Reasoning without Attention}, 133 | journal = {arXiv preprint arXiv:2505.22425}, 134 | year = {2025}, 135 | url = {https://arxiv.org/abs/2505.22425} 136 | } 137 | ``` 138 | -------------------------------------------------------------------------------- /PromptCoT_Mamba/calc_acc_aime.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from tqdm import tqdm 4 | import numpy as np 5 | import re 6 | from collections import defaultdict 7 | 8 | from math_opensource_utils.math_equivalence import is_equiv_minerva as is_equiv 9 | from math_opensource_utils.util import last_boxed_only_string, first_boxed_only_string, remove_boxed 10 | from math_opensource_utils.qwen_math import math_equal, extract_answer, strip_string 11 | 12 | 13 | def extract_parentheses(text): 14 | pattern = r'The correct answer is \((.*?)\)' 15 | match = re.search(pattern, text) 16 | if match: 17 | return match.group(1) 18 | return '' 19 | 20 | 21 | def is_valid(completion): 22 | if "<think>" not in completion or "</think>" not in completion: 23 | return False 24 | think = completion.split("<think>")[1].split("</think>")[0] 25 | solution = completion.split("</think>")[1] 26 | final_answer = remove_boxed(last_boxed_only_string(solution)) 27 | if final_answer is None: 28 | return False 29 | if final_answer.strip() == "": 
30 | return False 31 | return True 32 | 33 | 34 | def calculate_stats(results): 35 | """Calculate mean and standard error from a list of binary outcomes""" 36 | if not results: 37 | return 0.0, 0.0 38 | 39 | mean = np.mean(results) 40 | std_err = np.std(results, ddof=1) / np.sqrt(len(results)) if len(results) > 1 else 0.0 41 | 42 | return mean, std_err 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser(description="Evaluate large language models on critical datasets.") 47 | parser.add_argument("--input_path", type=str, required=True, help="Path to the predictions file (jsonl).") 48 | parser.add_argument("--expected_runs", type=int, default=8) 49 | 50 | args = parser.parse_args() 51 | 52 | # First, organize completions by source and prompt to determine run_index 53 | # source -> prompt -> list of completions 54 | completions_by_prompt = defaultdict(lambda: defaultdict(list)) 55 | 56 | # Read all items first to organize by prompt 57 | all_items = [] 58 | with open(args.input_path, encoding="utf-8") as f: 59 | for line in f: 60 | item = json.loads(line) 61 | all_items.append(item) 62 | 63 | if "prompt" not in item: 64 | raise ValueError("Item is missing required 'prompt' field") 65 | 66 | # Store completion information by source and prompt 67 | source = item["source"] 68 | prompt = item["prompt"] 69 | completions_by_prompt[source][prompt].append(item) 70 | 71 | # Now, determine run_index based on order of completions for each prompt 72 | # source -> run_index -> prompt -> correct or not 73 | results = defaultdict(lambda: defaultdict(dict)) 74 | 75 | # Keep track of prompts by source 76 | prompts_by_source = defaultdict(set) 77 | 78 | for source in completions_by_prompt: 79 | for prompt in tqdm(completions_by_prompt[source]): 80 | # Sort completions if there's any specific ordering required 81 | # (assuming they're already in the right order in the file) 82 | completions = completions_by_prompt[source][prompt] 83 | 84 | prompts_by_source[source].add(prompt) 85 | 86 | # Process each completion with its calculated run_index 87 | for run_idx, item in enumerate(completions): 88 | completion = item["completion"] 89 | reference_solution = item.get("reference_solution", item.get("solution")) 90 | 91 | valid = is_valid(completion) 92 | 93 | if source in ["aime2024", "aime2025", "math500"]: 94 | correct = math_equal( 95 | extract_answer(completion), 96 | strip_string(reference_solution.split("####")[1].strip()), 97 | timeout=False, 98 | ) or is_equiv( 99 | remove_boxed(last_boxed_only_string(completion)), 100 | reference_solution.split("####")[-1].strip() if "####" in reference_solution else ( 101 | remove_boxed(last_boxed_only_string(reference_solution)) 102 | ) 103 | ) 104 | else: 105 | raise NotImplementedError(f"Source '{source}' is not implemented") 106 | 107 | # Store the correctness result with calculated run_index 108 | results[source][run_idx][prompt] = int(correct) 109 | 110 | # Calculate and print statistics for each source 111 | print("\nRESULTS BY SOURCE:") 112 | print("-" * 80) 113 | print(f"{'Source':<15} {'Accuracy':<20} {'Num Prompts':<15} {'Runs':<10}") 114 | print("-" * 80) 115 | 116 | for source in sorted(results.keys()): 117 | # Expected number of runs per prompt (default 8, i.e. run indices 0-7) 118 | expected_runs = args.expected_runs 119 | prompts = sorted(prompts_by_source[source]) 120 | 121 | # Calculate accuracy for each run 122 | run_accuracies = [] 123 | run_details = [] 124 | 125 | for run_idx in range(expected_runs): 126 | if run_idx not in results[source]: 127 | print(f"Warning: Source '{source}' is 
missing run index {run_idx}") 128 | continue 129 | 130 | correct_count = 0 131 | total_count = 0 132 | 133 | for prompt in prompts: 134 | # Some prompts might be missing in some runs 135 | if prompt in results[source][run_idx]: 136 | correct_count += results[source][run_idx][prompt] 137 | total_count += 1 138 | 139 | # Calculate accuracy for this run 140 | if total_count > 0: 141 | run_accuracy = correct_count / total_count 142 | run_accuracies.append(run_accuracy) 143 | run_details.append(f"Run {run_idx}: {run_accuracy:.4f} ({correct_count}/{total_count})") 144 | 145 | # Calculate stats across all runs 146 | if run_accuracies: 147 | mean_accuracy = np.mean(run_accuracies) 148 | # std_error = np.std(run_accuracies, ddof=1) / np.sqrt(len(run_accuracies)) # Standard error of the mean 149 | std_dev = np.std(run_accuracies, ddof=1) 150 | 151 | # Format the result as mean ± standard deviation, in percent 152 | # accuracy_str = f"{mean_accuracy:.2%} ± {std_dev:.2%}" 153 | 154 | mean_pct = round(mean_accuracy * 100, 1) 155 | std_dev_pct = round(std_dev * 100, 1) 156 | 157 | accuracy_str = f"{mean_pct:.1f}% ± {std_dev_pct:.1f}%" 158 | 159 | print(f"{source:<15} {accuracy_str:<20} {len(prompts):<15} {len(run_accuracies)}") 160 | 161 | # Print individual run details 162 | for detail in run_details: 163 | print(f" {detail}") 164 | 165 | print("-" * 80) 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /PromptCoT_Mamba/calc_acc_lcb.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from tqdm import tqdm 4 | import os 5 | import shutil 6 | import copy 7 | 8 | from livecodebench_v5 import compute_scores as compute_scores_livecodebench_v5 9 | 10 | 11 | def get_after_think(text): 12 | parts = text.split("</think>", 1) # keep only the solution that follows the reasoning block 13 | if len(parts) > 1: 14 | return parts[1] 15 | else: 16 | return text 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser(description="Evaluate model outputs") 21 | parser.add_argument("--input_path", type=str, required=True, help="Path to input jsonl file") 22 | parser.add_argument("--cache_path", type=str, required=True, help="Path to save cache results") 23 | # parser.add_argument("--task_name", type=str, required=True, help="Task should be in ['math_opensource/aime24', 'math_opensource/aime25' ,'livecodebench', 'ifeval']") 24 | args = parser.parse_args() 25 | 26 | os.makedirs(os.path.dirname(args.cache_path), exist_ok=True) 27 | 28 | if os.path.exists(args.cache_path): 29 | if os.path.isdir(args.cache_path): 30 | shutil.rmtree(args.cache_path) # Remove directory and all contents 31 | else: 32 | os.remove(args.cache_path) # Remove file 33 | 34 | data = [] 35 | with open(args.input_path, 'r', encoding='utf-8') as f: 36 | for line in f.readlines(): 37 | item = json.loads(line) 38 | if "completion" in item: 39 | completion = item.pop("completion") 40 | item["gen"] = [completion] 41 | data.append(item) 42 | elif "completions" in item: 43 | completions = item.pop("completions") 44 | for completion in completions: 45 | item_copy = copy.deepcopy(item) 46 | item_copy["gen"] = [completion] 47 | data.append(item_copy) 48 | else: 49 | raise NotImplementedError("each record must contain a 'completion' or 'completions' field") 50 | 51 | for item in data: 52 | item["task"] = "livecodebench" 53 | temp = get_after_think(item['gen'][0]) 54 | item['gen'][0] = temp 55 | acc = compute_scores_livecodebench_v5(data, args.cache_path) 56 | print(f"Input: {args.input_path}, Pass@1: {acc}") 57 | 
print("Evaluation complete!") 58 | 59 | 60 | if __name__ == "__main__": 61 | main() -------------------------------------------------------------------------------- /PromptCoT_Mamba/configs/promptcot_mamba_7b_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "initial_scale_power": 12, 6 | "loss_scale_window": 200, 7 | "hysteresis": 5, 8 | "consecutive_hysteresis": true, 9 | "min_loss_scale": 1 10 | }, 11 | "bf16": { 12 | "enabled": "auto" 13 | }, 14 | "optimizer": { 15 | "type": "AdamW", 16 | "params": { 17 | "lr": "auto", 18 | "betas": "auto", 19 | "eps": 1e-08, 20 | "weight_decay": "auto" 21 | } 22 | }, 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": 0, 27 | "warmup_max_lr": 5e-06, 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | "zero_optimization": { 32 | "stage": 2, 33 | "offload_optimizer": { 34 | "device": "cpu", 35 | "pin_memory": true 36 | }, 37 | "overlap_comm": false, 38 | "contiguous_gradients": false, 39 | "sub_group_size": 1000000000.0, 40 | "reduce_bucket_size": 1000000.0 41 | }, 42 | "steps_per_print": 1, 43 | "train_micro_batch_size_per_gpu": 1 44 | } 45 | -------------------------------------------------------------------------------- /PromptCoT_Mamba/infer_longcot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | 5 | from vllm import LLM, SamplingParams 6 | from transformers import AutoTokenizer 7 | from str2bool import str2bool 8 | import os 9 | import re 10 | 11 | 12 | def is_valid_think_format(text): 13 | # Check basic pattern first 14 | pattern = r'^<think>(.*?)</think>(.+)$' # Note the .+ for non-empty after content 15 | match = re.match(pattern, text, re.DOTALL) 16 | 17 | if not match: 18 | return False 19 | 20 | # Extract the content inside and after the think tags 21 | inside_content = match.group(1) 22 | after_content = match.group(2) 23 | 24 | # Verify inside content is not empty 25 | if not inside_content.strip(): 26 | return False 27 | 28 | # Verify neither part contains additional think tags 29 | if '<think>' in inside_content or '</think>' in inside_content: 30 | return False 31 | 32 | if '<think>' in after_content or '</think>' in after_content: 33 | return False 34 | 35 | return True 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser(description="Evaluate large language models on critical datasets.") 40 | parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file.") 41 | parser.add_argument("--output_path", type=str, required=True, help="Path to write the generated completions (jsonl).") 42 | parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model.") 43 | parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to the tokenizer (defaults to --model_path).") 44 | parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type to use for the model (e.g., fp16, bf16, etc.).") 45 | parser.add_argument("--n_gpus", type=int, default=8, help="Number of GPUs to use for tensor parallelism.") 46 | parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature for generation.") 47 | parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling for generation.") 48 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 49 | parser.add_argument("--max_len", type=int, default=32768, help="Maximum number of tokens to generate.") 50 | 
parser.add_argument("--use_chat_template", type=str2bool, default=False) 51 | parser.add_argument("--n", type=int, default=8) 52 | parser.add_argument("--max_retries", type=int, default=8) 53 | parser.add_argument("--use_mamba2", type=str2bool, default=False) 54 | 55 | args = parser.parse_args() 56 | 57 | if args.tokenizer_path is None: 58 | args.tokenizer_path = args.model_path 59 | 60 | # Load the tokenizer for LLaMA or any model 61 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 62 | 63 | # Load inference framework 64 | if not args.use_mamba2: 65 | model = LLM( 66 | model=args.model_path, 67 | tokenizer=args.tokenizer_path, 68 | tokenizer_mode="slow", 69 | dtype=args.dtype, 70 | tensor_parallel_size=args.n_gpus, 71 | enforce_eager=True, # new 72 | disable_custom_all_reduce=True, # todo: need check 73 | ) 74 | else: 75 | model = LLM( 76 | model=args.model_path, 77 | tokenizer=args.tokenizer_path, 78 | max_model_len=args.max_len + 4096, 79 | tokenizer_mode="slow", 80 | dtype=args.dtype, 81 | tensor_parallel_size=args.n_gpus, 82 | ) 83 | 84 | items = [] 85 | completions = [] 86 | seed = 0 87 | for _ in range(args.n): 88 | prompts = [] 89 | with open(args.data_path, encoding="utf-8") as f: 90 | for line in f.readlines(): 91 | item = json.loads(line) 92 | prompt = item["prompt"] 93 | if args.use_chat_template: 94 | messages = [ 95 | {"role": "user", "content": prompt} 96 | ] 97 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 98 | prompts.append(prompt) 99 | items.append(item) 100 | 101 | with torch.no_grad(): 102 | # Initialize with None for each prompt position 103 | output_lst = [None] * len(prompts) 104 | pending_prompts = prompts.copy() 105 | pending_indices = list(range(len(prompts))) 106 | 107 | retry_count = 0 108 | 109 | while pending_prompts and retry_count < args.max_retries: 110 | sampling_params = SamplingParams( 111 | temperature=args.temperature, 112 | top_p=args.top_p, 113 | max_tokens=args.max_len, 114 | repetition_penalty=args.repetition_penalty, 115 | seed=seed, 116 | ) 117 | seed += 1 118 | 119 | # Generate completions for remaining prompts 120 | batch_outputs = model.generate(pending_prompts, sampling_params) 121 | batch_texts = [completion.outputs[0].text.strip() for completion in batch_outputs] 122 | batch_texts = [f"<think>\n{text}" if not text.startswith("<think>") else text for text in batch_texts] # ensure every completion opens with the think tag 123 | 124 | # Process current batch results 125 | still_pending_prompts = [] 126 | still_pending_indices = [] 127 | 128 | for i, (idx, text) in enumerate(zip(pending_indices, batch_texts)): 129 | if is_valid_think_format(text): 130 | # If valid, add to results at original position 131 | output_lst[idx] = text 132 | else: 133 | # If invalid, keep for retry 134 | still_pending_prompts.append(pending_prompts[i]) 135 | still_pending_indices.append(idx) 136 | # Store the invalid output in case we reach max retries 137 | output_lst[idx] = text 138 | 139 | # Update for next iteration 140 | pending_prompts = still_pending_prompts 141 | pending_indices = still_pending_indices 142 | retry_count += 1 143 | 144 | completions.extend(output_lst) 145 | 146 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 147 | with open(args.output_path, "w", encoding="utf-8") as f: 148 | for item, completion in zip(items, completions): 149 | item["completion"] = completion 150 | f.write(json.dumps(item) + "\n") 151 | 152 | 153 | if __name__ == "__main__": 154 | main() 155 | 
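156 | # Expected output shape (assumed from is_valid_think_format above, not a documented contract): 157 | # every accepted completion looks like 158 | # <think> 159 | # ...long-form reasoning... 160 | # </think> 161 | # ...final solution (e.g. ending in \boxed{...} for math tasks)... 162 | # Each of the --n sampling passes appends one record per prompt, so the output jsonl 163 | # holds n * len(dataset) lines; calc_acc_aime.py regroups them by (source, prompt) 164 | # and infers each run index from their order in the file.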
-------------------------------------------------------------------------------- /PromptCoT_Mamba/livecodebench_v5.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from collections import defaultdict 3 | from datetime import datetime 4 | 5 | import os 6 | import hashlib 7 | import json 8 | import logging 9 | import multiprocessing 10 | from multiprocessing.pool import ThreadPool 11 | import numpy as np 12 | from statistics import mean 13 | from tqdm import tqdm 14 | import copy 15 | 16 | from livecodebench_v5_utils.compute_code_generation_metrics import _temp_run 17 | 18 | LIVECODEBENCH_TESTS = os.getenv("LIVECODEBENCH_TESTS", "data/livecodebench_v5_tests") 19 | 20 | def _extract_code(text: str) -> str: 21 | outputlines = text.split("\n") 22 | indexlines = [i for i, line in enumerate(outputlines) if "```" in line] 23 | if len(indexlines) < 2: 24 | return "" 25 | return "\n".join(outputlines[indexlines[-2] + 1:indexlines[-1]]) 26 | 27 | def preprocess(job): 28 | tests = job['tests'] 29 | raw_gen = job['gen'] if isinstance(job['gen'], str) else job['gen'][0] 30 | gen_code = _extract_code(raw_gen) 31 | 32 | return tests, gen_code 33 | 34 | def work(job): 35 | tests, generation = preprocess(job) 36 | res = check_correctness( 37 | tests=tests, 38 | generation=generation, 39 | ) 40 | assert res['md5'] == tests['md5'], "test md5 mismatched" 41 | return res, job 42 | 43 | def compute_scores(jobs, cache_path): 44 | with ThreadPool(max(1, int(os.cpu_count() * 0.5))) as pool: 45 | for res, job in tqdm(pool.imap_unordered(work, jobs), total=len(jobs)): 46 | extraction_failed = 0 47 | ispass = res['ispass'] 48 | metadata = res['metadata'] 49 | extraction_failed = metadata.get("error_code", 0) == -1 50 | results = res['results'] 51 | 52 | job.update({ 53 | "pass-1": ispass, 54 | "results": results, 55 | "metadata": metadata, 56 | "extraction_failed": extraction_failed, 57 | }) 58 | save_cache(job, cache_path) 59 | with open(cache_path, "r") as f: 60 | jobs = [json.loads(l) for l in f] 61 | 62 | # Retry all timeout jobs sequentially (without using multiprocessing) 63 | new_jobs = [] 64 | for job in tqdm(jobs, desc="Processing jobs"): 65 | error_code = job["metadata"].get("error_code", 0) 66 | if error_code == -3: 67 | res, job = work(job) 68 | job.update(res) 69 | new_jobs.append(job) 70 | save_cache(job, cache_path.replace(".jsonl", "_try2.jsonl")) 71 | else: 72 | new_jobs.append(job) 73 | 74 | return mean(x['pass-1'] for x in new_jobs) 75 | def check_correctness(tests: dict, generation: str, timeout: int = 30, debug: bool = False): 76 | """Check correctness of code generation with a global timeout. 
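As a rough sketch of the timing model (inferred from the join call below, not a documented guarantee), the parent process waits about (timeout + 1) * num_test_cases + 5 seconds in total before killing the worker.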
77 | The global timeout is to catch some extreme/rare cases not handled by the timeouts 78 | inside `run_test`""" 79 | 80 | tests_path = Path(LIVECODEBENCH_TESTS) / tests['fname'] 81 | with open(tests_path, "r") as f: 82 | sample = json.load(f) 83 | 84 | md5 = calculate_string_md5(json.dumps(sample)) 85 | 86 | manager = multiprocessing.Manager() 87 | result = manager.list() 88 | metadata_list = manager.list() 89 | p = multiprocessing.Process( 90 | target=_temp_run, 91 | args=(sample, generation, debug, result, metadata_list, timeout), 92 | ) 93 | p.start() 94 | p.join(timeout=(timeout + 1) * len(json.loads(sample["input_output"])["inputs"]) + 5) 95 | if p.is_alive(): 96 | p.kill() 97 | if not result: 98 | in_outs = json.loads(sample["input_output"]) 99 | # consider that all tests failed 100 | result = [[-1 for i in range(len(in_outs["inputs"]))]] 101 | metadata_list = [{"error_code": -3}] 102 | if debug: 103 | print(f"global timeout") 104 | 105 | res, metadata = result[0], metadata_list[0] 106 | fixed = [] 107 | for e in res: 108 | if isinstance(e, np.ndarray): 109 | e = e.item(0) 110 | if isinstance(e, np.bool_): 111 | e = bool(e) 112 | if e != True and e != False: 113 | e = False 114 | fixed.append(e) 115 | res = fixed 116 | # print(res) 117 | if not np.all(res): 118 | # print("fail") 119 | return dict(ispass=0, md5=md5, results=res, metadata=metadata) 120 | else: 121 | # print("pass") 122 | return dict(ispass=1, md5=md5, results=res, metadata=metadata) 123 | 124 | def calculate_string_md5(input_string: str): 125 | md5 = hashlib.md5() 126 | md5.update(input_string.encode('utf-8')) 127 | return md5.hexdigest() 128 | 129 | def save_cache(job, cache_path): 130 | with open(cache_path, "a") as g: 131 | g.write(json.dumps(job, ensure_ascii=False) + "\n") 132 | g.flush() -------------------------------------------------------------------------------- /PromptCoT_Mamba/livecodebench_v5_utils/compute_code_generation_metrics.py: -------------------------------------------------------------------------------- 1 | # borrowed and extended from 2 | # https://github.com/Naman-ntc/codescratch/blob/main/evaluation/bigcode-evaluation-harness/lm_eval/tasks/custom_metrics/apps_custom_metrics/utils.py 3 | 4 | import os 5 | import sys 6 | from typing import List 7 | 8 | sys.set_int_max_str_digits(50000) 9 | 10 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 11 | import json 12 | import multiprocessing 13 | from collections import defaultdict 14 | from concurrent.futures import ProcessPoolExecutor, as_completed 15 | 16 | import numpy as np 17 | from tqdm import tqdm 18 | 19 | from livecodebench_v5_utils.testing_util import run_test 20 | from livecodebench_v5_utils.pass_k_utils import compute_metrics_from_results 21 | 22 | 23 | def _temp_run(sample, generation, debug, result, metadata_list, timeout): 24 | res, metadata = run_test(sample, test=generation, debug=debug, timeout=timeout) 25 | result.append(res) 26 | metadata_list.append(metadata) 27 | 28 | 29 | def check_correctness(sample, generation, timeout, debug=True): 30 | """Check correctness of code generation with a global timeout. 
31 | The global timeout is to catch some extreme/rare cases not handled by the timeouts 32 | inside `run_test`""" 33 | 34 | manager = multiprocessing.Manager() 35 | result = manager.list() 36 | metadata_list = manager.list() 37 | p = multiprocessing.Process( 38 | target=_temp_run, 39 | args=(sample, generation, debug, result, metadata_list, timeout), 40 | ) 41 | p.start() 42 | p.join(timeout=(timeout + 1) * len(json.loads(sample["input_output"])["inputs"]) + 5) 43 | if p.is_alive(): 44 | p.kill() 45 | if not result: 46 | in_outs = json.loads(sample["input_output"]) 47 | # consider that all tests failed 48 | result = [[-1 for i in range(len(in_outs["inputs"]))]] 49 | if debug: 50 | print(f"global timeout") 51 | return result[0], metadata_list[0] 52 | 53 | 54 | def evaluate_generations_by_problem(args): 55 | problem_generations: List[str] = args[0] 56 | sample = args[1] 57 | debug: bool = args[2] 58 | timeout: int = args[3] 59 | 60 | res = [] 61 | metadata = [] 62 | for o_idx, o in enumerate(problem_generations): 63 | curr_res = [-2] 64 | try: 65 | curr_res, curr_metadata = check_correctness(sample, o, timeout=timeout, debug=debug) 66 | if debug: 67 | print(f"\nSuccessful compilation of task {o_idx}!") 68 | fixed = [] 69 | for e in curr_res: 70 | if isinstance(e, np.ndarray): 71 | e = e.item(0) 72 | if isinstance(e, np.bool_): 73 | e = bool(e) 74 | fixed.append(e) 75 | curr_res = fixed 76 | if not np.all(curr_res): 77 | if debug: 78 | print(f"Results were not True for all test cases {curr_res=}\n") 79 | except Exception as e: 80 | if debug: 81 | print(f"Compilation failed, test framework exception = {repr(e)}{e}\n") 82 | # break 83 | curr_metadata = { 84 | "error": repr(e), 85 | "error_code": -5, 86 | "error_message": "TestRunnerError", 87 | } 88 | finally: 89 | assert isinstance(curr_res, list), curr_res 90 | assert isinstance(curr_metadata, dict), curr_metadata 91 | res.append(curr_res) 92 | metadata.append(curr_metadata) 93 | if debug: 94 | for i, r in enumerate(problem_generations): 95 | print("Sample\n") 96 | print(r) 97 | print("\n") 98 | print("Result\n") 99 | print(res[i]) 100 | print("*" * 30 + "\n\n") 101 | return res, metadata 102 | 103 | 104 | def evaluate_generations( 105 | samples_list: list, 106 | generations_list: List[List[str]], 107 | debug: bool = False, 108 | num_process_evaluate: int = 16, 109 | timeout=6, 110 | ): 111 | """We take the list of code generations and try to compile them 112 | and the run their corresponding unit tests which are retrieved from the APPS dataset. 
113 | 114 | Args: 115 | generations: list of code generations (same order as samples in APPS dataset) 116 | level: difficulty level used in the generation, can be "all", "introductory", "interview" or "competition" 117 | 118 | Returns: 119 | results: dictionary of results, key is the problem index, value is a list of results for each generation 120 | """ 121 | 122 | # generations are code generations in the same order of the dataset 123 | 124 | inputs = [[(generations_list[index], samples_list[index], debug, timeout), index] for index in range(len(generations_list))] 125 | 126 | with tqdm(total=len(inputs)) as pbar: 127 | with ProcessPoolExecutor(max_workers=1 if debug else num_process_evaluate) as executor: 128 | futures = {executor.submit(evaluate_generations_by_problem, arg): index for arg, index in inputs} 129 | 130 | results = {} 131 | metadata = {} 132 | for future in as_completed(futures): 133 | index = futures[future] 134 | results[index], metadata[index] = future.result() 135 | pbar.update(1) 136 | 137 | assert len(results) == len(inputs), f"results = {len(results)} inputs = {len(inputs)} {results=}" 138 | # results = {i: r for r, (_, i) in zip(results, inputs)} 139 | 140 | return results, metadata 141 | 142 | 143 | def codegen_metrics( 144 | samples_list, 145 | generations_list, 146 | k_list=[1, 5, 10, 20, 40, 50, 75, 100, 125, 150, 200, 500, 1000], 147 | num_process_evaluate=16, 148 | timeout=6, 149 | debug=False, 150 | ): 151 | 152 | samples_linear = [] 153 | generations_linear = [] 154 | remap_index = [] 155 | results = defaultdict(list) 156 | metadatas = defaultdict(list) 157 | for idx, (sample, generation_list) in enumerate(zip(samples_list, generations_list)): 158 | assert isinstance(generation_list, list), generations_list[0] 159 | for generation in generation_list: 160 | assert isinstance(generation, str), generations_list[0] 161 | samples_linear.append(sample) 162 | generations_linear.append([generation]) 163 | remap_index.append(idx) 164 | 165 | print(f"Evaluating {len(samples_linear)}...") 166 | 167 | results_linear, metadatas_linear = evaluate_generations( 168 | samples_linear, 169 | generations_linear, 170 | debug=debug, 171 | num_process_evaluate=num_process_evaluate, 172 | timeout=timeout, 173 | ) 174 | 175 | for idx, sub_results in sorted(results_linear.items(), key=lambda x: x[0]): 176 | results[remap_index[idx]].append(sub_results[0]) 177 | 178 | for idx, sub_metadatas in sorted(metadatas_linear.items(), key=lambda x: x[0]): 179 | metadatas[remap_index[idx]].append(sub_metadatas[0]) 180 | 181 | metrics = compute_metrics_from_results(results, k_list=k_list) 182 | 183 | final_metadata = [] 184 | for key in sorted(list(metadatas.keys())): 185 | final_metadata.append(metadatas[key]) 186 | for i in range(len(final_metadata)): 187 | if type(final_metadata[i]) is not list: 188 | final_metadata[i] = [json.dumps(final_metadata[i])] 189 | else: 190 | final_metadata[i] = [json.dumps(x) for x in final_metadata[i]] 191 | 192 | assert len(final_metadata[i]) == len(generations_list[0]), f"{len(final_metadata[i])=}" 193 | 194 | return [metrics, results, final_metadata] 195 | 196 | 197 | if __name__ == "__main__": 198 | # print( 199 | # check_correctness( 200 | # { 201 | # "input_output": json.dumps( 202 | # { 203 | # "inputs": [ 204 | # json.dumps([1] * 100000) 205 | # + "\n" 206 | # + json.dumps([100000, -100000] * (100000 // 2)) 207 | # ], 208 | # "outputs": [json.dumps([100000, 0] * (100000 // 2))], 209 | # "fn_name": "mostFrequentIDs", 210 | # } 211 | # ) 212 | # }, 213 | # 
"class Solution:\n def mostFrequentIDs(self, nums: List[int], freq: List[int]) -> List[int]:\n from collections import defaultdict\n \n # Count of each ID\n count = defaultdict(int)\n # How many IDs exist for a given frequency\n freq_of_count = defaultdict(int)\n \n max_freq = 0\n ans = []\n \n for i in range(len(nums)):\n x = nums[i]\n change = freq[i]\n \n old_freq = count[x]\n new_freq = old_freq + change\n \n # If there was an old frequency, decrease its usage\n if old_freq > 0:\n freq_of_count[old_freq] -= 1\n if freq_of_count[old_freq] == 0:\n del freq_of_count[old_freq]\n \n # Update with the new frequency\n count[x] = new_freq\n freq_of_count[new_freq] += 1\n \n # Update max_freq if needed\n if new_freq > max_freq:\n max_freq = new_freq\n \n # If the collection at max_freq is empty, reduce max_freq until we find a non-empty bin\n while max_freq > 0 and max_freq not in freq_of_count:\n max_freq -= 1\n \n # If the collection is empty, max_freq will be 0\n ans.append(max_freq)\n \n return ans", 214 | # 6, 215 | # debug=True, 216 | # ) 217 | # ) 218 | 219 | print( 220 | check_correctness( 221 | {"input_output": json.dumps({ 222 | "inputs": ")))))", 223 | "outputs": "0", 224 | },)}, 225 | "\nMOD = 998244353\n\nS = input().strip()\nn = len(S)\n\nif n % 2 != 0:\n print(0)\n exit()\n\n# Initialize DP table\ndp = [[0] * (n + 2) for _ in range(n + 1)]\ndp[0][0] = 1\n\nfor i in range(1, n + 1):\n c = S[i-1]\n for b in range(n + 1):\n if dp[i-1][b] == 0:\n continue\n if c == '(':\n new_b = b + 1\n if new_b <= n:\n dp[i][new_b] = (dp[i][new_b] + dp[i-1][b]) % MOD\n elif c == ')':\n if b > 0:\n new_b = b - 1\n dp[i][new_b] = (dp[i][new_b] + dp[i-1][b]) % MOD\n else: # '?'\n # Replace with '('\n new_b = b + 1\n if new_b <= n:\n dp[i][new_b] = (dp[i][new_b] + dp[i-1][b]) % MOD\n # Replace with ')'\n if b > 0:\n new_b = b - 1\n dp[i][new_b] = (dp[i][new_b] + dp[i-1][b]) % MOD\n\nprint(dp[n][0] % MOD)\n", 226 | 6, 227 | debug=True, 228 | )) -------------------------------------------------------------------------------- /PromptCoT_Mamba/livecodebench_v5_utils/pass_k_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def estimate_pass_at_k(num_samples, num_correct, k): 5 | """Estimates pass@k of each problem and returns them in an array.""" 6 | 7 | def estimator(n: int, c: int, k: int) -> float: 8 | """Calculates 1 - comb(n - c, k) / comb(n, k).""" 9 | if n - c < k: 10 | return 1.0 11 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 12 | 13 | import itertools 14 | 15 | if isinstance(num_samples, int): 16 | num_samples_it = itertools.repeat(num_samples, len(num_correct)) 17 | else: 18 | assert len(num_samples) == len(num_correct) 19 | num_samples_it = iter(num_samples) 20 | 21 | return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]) 22 | 23 | 24 | def compute_metrics_from_results(results, k_list=[1, 5]): 25 | total = [] 26 | correct = [] 27 | task_ids = [] 28 | for task_id, res in results.items(): 29 | all_correct = [] 30 | for generation in res: 31 | gen = np.array(generation) 32 | all_correct.append(np.all(gen > 0)) 33 | task_ids.append(task_id) 34 | total.append(len(all_correct)) 35 | correct.append(sum(all_correct)) 36 | total = np.array(total) 37 | correct = np.array(correct) 38 | ks = k_list 39 | detail_pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).tolist() for k in ks if (total >= k).all()} 40 | pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, 
correct, k).mean() for k in ks if (total >= k).all()} 41 | detail_metrics = {k: dict(zip(task_ids, v)) for k, v in detail_pass_at_k.items()} 42 | pass_at_k["detail"] = detail_metrics 43 | return pass_at_k 44 | 45 | 46 | def extract_instance_results(results): 47 | instance_wise_grades = {} 48 | for task_id, res in results.items(): 49 | instance_wise_grades[task_id] = [] 50 | for generation in res: 51 | instance_wise_grades[task_id].append(all([g > 0 for g in generation])) 52 | 53 | instance_wise_grades = [v for _, v in sorted(instance_wise_grades.items(), key=lambda item: item[0])] 54 | return instance_wise_grades -------------------------------------------------------------------------------- /PromptCoT_Mamba/livecodebench_v5_utils/process_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import zlib 3 | import pickle 4 | import base64 5 | import hashlib 6 | from enum import Enum 7 | from datetime import datetime 8 | from dataclasses import dataclass 9 | 10 | from pathlib import Path 11 | from datasets import load_dataset 12 | from tqdm import tqdm 13 | 14 | 15 | class Platform(Enum): 16 | LEETCODE = "leetcode" 17 | CODEFORCES = "codeforces" 18 | ATCODER = "atcoder" 19 | 20 | 21 | class Difficulty(Enum): 22 | EASY = "easy" 23 | MEDIUM = "medium" 24 | HARD = "hard" 25 | 26 | 27 | class TestType(Enum): 28 | STDIN = "stdin" 29 | FUNCTIONAL = "functional" 30 | 31 | 32 | @dataclass 33 | class Test: 34 | input: str 35 | output: str 36 | testtype: TestType 37 | 38 | def __post_init__(self): 39 | self.testtype = TestType(self.testtype) 40 | # if self.testtype == TestType.FUNCTIONAL: 41 | # self.input = json.loads(self.input) 42 | # self.output = json.loads(self.output) 43 | 44 | 45 | @dataclass 46 | class CodeGenerationProblem: 47 | question_title: str 48 | question_content: str 49 | platform: Platform 50 | question_id: str 51 | contest_id: str 52 | contest_date: datetime 53 | starter_code: str 54 | difficulty: Difficulty 55 | public_test_cases: list[Test] 56 | private_test_cases: list[Test] 57 | metadata: dict 58 | 59 | def __post_init__(self): 60 | self.platform = Platform(self.platform) 61 | self.difficulty = Difficulty(self.difficulty) 62 | self.contest_date = datetime.fromisoformat(self.contest_date) 63 | 64 | self.public_test_cases = json.loads(self.public_test_cases) # type: ignore 65 | self.public_test_cases = [Test(**t) for t in self.public_test_cases] 66 | 67 | try: 68 | self.private_test_cases = json.loads(self.private_test_cases) # type: ignore 69 | except: 70 | self.private_test_cases = json.loads(pickle.loads(zlib.decompress(base64.b64decode(self.private_test_cases.encode("utf-8")) # type: ignore 71 | ))) # type: ignore 72 | self.private_test_cases = [Test(**t) for t in self.private_test_cases] 73 | 74 | self.metadata = json.loads(self.metadata) # type: ignore 75 | 76 | def insert_output(self, output_list: list[str], code_list: list[str]) -> dict: 77 | return { 78 | "question_title": self.question_title, 79 | "question_content": self.question_content, 80 | "platform": self.platform.value, 81 | "question_id": self.question_id, 82 | "contest_id": self.contest_id, 83 | "contest_date": self.contest_date.isoformat(), 84 | "starter_code": self.starter_code, 85 | "difficulty": self.difficulty.value, 86 | "output_list": output_list, 87 | "code_list": code_list, 88 | } 89 | 90 | def insert_output_evaluation( 91 | self, 92 | output_list: list[str], 93 | code_list: list[str], 94 | graded_list: list[bool], 95 | **kwargs, 96 | ) -> dict: 
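# Here "pass@1" is the plain fraction of this problem's graded generations that passed all tests, not the unbiased pass@k estimator implemented in pass_k_utils.py.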
97 | output = self.insert_output(output_list, code_list) 98 | output["graded_list"] = graded_list 99 | output["pass@1"] = graded_list.count(True) / len(graded_list) 100 | for k, v in kwargs.items(): 101 | output[k] = v 102 | return output 103 | 104 | def get_evaluation_sample(self): 105 | return { 106 | "input_output": json.dumps({ 107 | "inputs": [t.input for t in self.public_test_cases + self.private_test_cases], 108 | "outputs": [t.output for t in self.public_test_cases + self.private_test_cases], 109 | "fn_name": self.metadata.get("func_name", None), 110 | }), 111 | } 112 | 113 | 114 | class PromptConstants: 115 | SYSTEM_MESSAGE_GENERIC = f"You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests." 116 | 117 | SYSTEM_MESSAGE_GEMINI = f"You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests. Do NOT use system calls like `exit` in the generated program. Ensure that the first code block contains the solution." 118 | 119 | SYSTEM_MESSAGE_GEMINITHINK = f"You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests." 120 | 121 | SYSTEM_MESSAGE_DEEPSEEK = f"You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you answer questions related to computer science." 122 | 123 | SYSTEM_MESSAGE_CODEQWEN = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user" 124 | 125 | FORMATTING_MESSAGE_WITH_STARTER_CODE = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters." 126 | 127 | FORMATTING_WITHOUT_STARTER_CODE = "Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs, it reads the inputs, runs the algorithm and writes output to STDOUT." 128 | 129 | 130 | def load_code_generation_dataset(release_version="release_v5") -> list[CodeGenerationProblem]: 131 | dataset = load_dataset("livecodebench/code_generation_lite", split="test", version_tag=release_version, trust_remote_code=True) 132 | dataset = [CodeGenerationProblem(**p) for p in dataset] # type: ignore 133 | print(f"Loaded {len(dataset)} problems") 134 | return dataset 135 | 136 | 137 | def get_qwen_question_template_answer(question: CodeGenerationProblem): 138 | prompt = "You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests. 
You will NOT return anything except for the program.\n\n" 139 | prompt += f"Question: {question.question_content}\n\n" 140 | if question.starter_code: 141 | prompt += f"{PromptConstants.FORMATTING_MESSAGE_WITH_STARTER_CODE}\n" 142 | prompt += f"```python\n{question.starter_code}\n```\n\n" 143 | else: 144 | prompt += f"{PromptConstants.FORMATTING_WITHOUT_STARTER_CODE}\n" 145 | prompt += f"```python\n# YOUR CODE HERE\n```\n\n" 146 | return prompt 147 | 148 | 149 | def get_qwen_reasoning_question_template_answer(question: CodeGenerationProblem): 150 | prompt = "You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests.\n\n" 151 | prompt += f"Question: {question.question_content}\n\n" 152 | if question.starter_code: 153 | prompt += f"{PromptConstants.FORMATTING_MESSAGE_WITH_STARTER_CODE}\n" 154 | prompt += f"```python\n{question.starter_code}\n```\n\n" 155 | else: 156 | prompt += f"{PromptConstants.FORMATTING_WITHOUT_STARTER_CODE}\n" 157 | prompt += f"```python\n# YOUR CODE HERE\n```\n\n" 158 | return prompt 159 | 160 | 161 | def calculate_string_md5(input_string: str): 162 | md5 = hashlib.md5() 163 | md5.update(input_string.encode('utf-8')) 164 | return md5.hexdigest() 165 | 166 | 167 | if __name__ == "__main__": 168 | 169 | output_livecodebench_v5_tests_dir = "/home/data/public/data/eval/code/livecodebench_v5_tests" 170 | output_livecodebench_v5_data_path = "/home/data/public/data/eval/code/livecodebench_v5.jsonl" 171 | Path(output_livecodebench_v5_tests_dir).mkdir(parents=True, exist_ok=True) 172 | Path(output_livecodebench_v5_data_path).parent.mkdir(parents=True, exist_ok=True) 173 | 174 | dataset = load_code_generation_dataset(release_version="release_v5") 175 | num_samples = 10 176 | 177 | livecodebench_v5_inputs_outputs = [] 178 | livecodebench_v5_dataset = [] 179 | 180 | # template for general language model 181 | # prompt_template = get_qwen_question_template_answer 182 | # template for reasoning model 183 | prompt_template = get_qwen_reasoning_question_template_answer 184 | 185 | for global_id, sample in enumerate(tqdm(dataset)): 186 | inputs_outputs = sample.get_evaluation_sample() 187 | livecodebench_v5_dataset.append({ 188 | "global_id": global_id, 189 | "question_id": sample.question_id, 190 | "contest_id": sample.contest_id, 191 | "contest_date": sample.contest_date.isoformat(), 192 | "prompt": prompt_template(sample), 193 | "tests": { 194 | "fname": f"{global_id}.json", 195 | "md5": calculate_string_md5(json.dumps(inputs_outputs)), 196 | }, 197 | "tags": "coding,en,python,core", 198 | "task": "livecodebench_v5", 199 | "source": "livecodebench_v5", 200 | "beam_size": num_samples, 201 | # "eval_args": eval_args, 202 | }) 203 | livecodebench_v5_inputs_outputs.append(inputs_outputs) 204 | 205 | # save test cases 206 | with open(Path(output_livecodebench_v5_tests_dir) / f"{global_id}.json", "w") as f: 207 | json.dump(inputs_outputs, f) 208 | 209 | # save dataset 210 | with open(output_livecodebench_v5_data_path, "w") as f: 211 | for sample in livecodebench_v5_dataset: 212 | f.write(json.dumps(sample) + "\n") -------------------------------------------------------------------------------- /PromptCoT_Mamba/livecodebench_v5_utils/testing_util.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import sys 4 | import faulthandler 5 | import platform 6 | 7 | # used for debugging to time steps 8 | from datetime import datetime 9 | 
from typing import Tuple, List 10 | 11 | # to run the solution files we're using a timing based approach 12 | import signal 13 | 14 | import numpy as np 15 | 16 | from io import StringIO 17 | 18 | # used for testing the code that reads from input 19 | from unittest.mock import patch, mock_open 20 | 21 | # from pyext import RuntimeModule 22 | from types import ModuleType 23 | 24 | from enum import Enum 25 | from decimal import Decimal 26 | import time 27 | 28 | import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n" 29 | 30 | 31 | def truncatefn(s, length=300): 32 | if isinstance(s, str): 33 | pass 34 | else: 35 | s = str(s) 36 | if len(s) <= length: 37 | return s 38 | 39 | return s[:length // 2] + "...(truncated) ..." + s[-length // 2:] 40 | 41 | 42 | class CODE_TYPE(Enum): 43 | call_based = 0 44 | standard_input = 1 45 | 46 | 47 | # stuff for setting up signal timer 48 | class TimeoutException(Exception): 49 | pass 50 | 51 | 52 | def timeout_handler(signum, frame): 53 | print("timeout occured: alarm went off") 54 | raise TimeoutException 55 | 56 | 57 | # used to capture stdout as a list 58 | # from https://stackoverflow.com/a/16571630/6416660 59 | # alternative use redirect_stdout() from contextlib 60 | class Capturing(list): 61 | 62 | def __enter__(self): 63 | self._stdout = sys.stdout 64 | sys.stdout = self._stringio = StringIO() 65 | # Make closing the StringIO a no-op 66 | self._stringio.close = lambda x: 1 67 | return self 68 | 69 | def __exit__(self, *args): 70 | self.append(self._stringio.getvalue()) 71 | del self._stringio # free up some memory 72 | sys.stdout = self._stdout 73 | 74 | 75 | def clean_if_name(code: str) -> str: 76 | try: 77 | astree = ast.parse(code) 78 | last_block = astree.body[-1] 79 | if isinstance(last_block, ast.If): 80 | condition = last_block.test 81 | if ast.unparse(condition).strip() == "__name__ == '__main__'": 82 | code = ( 83 | ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) # type: ignore 84 | ) 85 | except: 86 | pass 87 | 88 | return code 89 | 90 | 91 | def make_function(code: str) -> str: 92 | try: 93 | import_stmts = [] 94 | all_other_stmts = [] 95 | astree = ast.parse(code) 96 | for stmt in astree.body: 97 | if isinstance(stmt, (ast.Import, ast.ImportFrom)): 98 | import_stmts.append(stmt) 99 | else: 100 | all_other_stmts.append(stmt) 101 | 102 | function_ast = ast.FunctionDef( 103 | name="wrapped_function", 104 | args=ast.arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), 105 | body=all_other_stmts, 106 | decorator_list=[], 107 | lineno=-1, 108 | ) 109 | main_code = ( 110 | import_string + "\n" + ast.unparse(import_stmts) # type: ignore 111 | + "\n" + ast.unparse(function_ast) # type: ignore 112 | ) 113 | return main_code 114 | except Exception as e: 115 | return code 116 | 117 | 118 | def call_method(method, inputs): 119 | 120 | if isinstance(inputs, list): 121 | inputs = 
"\n".join(inputs) 122 | 123 | inputs_line_iterator = iter(inputs.split("\n")) 124 | 125 | # sys.setrecursionlimit(10000) 126 | 127 | # @patch('builtins.input', side_effect=inputs.split("\n")) 128 | @patch("builtins.open", mock_open(read_data=inputs)) 129 | @patch("sys.stdin", StringIO(inputs)) 130 | @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator)) 131 | @patch("sys.stdin.readlines", lambda *args: inputs.split("\n")) 132 | @patch("sys.stdin.read", lambda *args: inputs) 133 | # @patch('sys.stdout.write', print) 134 | def _inner_call_method(_method): 135 | try: 136 | return _method() 137 | except SystemExit as e: 138 | pass 139 | finally: 140 | pass 141 | 142 | return _inner_call_method(method) 143 | 144 | 145 | def get_function(compiled_sol, fn_name: str): # type: ignore 146 | try: 147 | assert hasattr(compiled_sol, fn_name) 148 | return getattr(compiled_sol, fn_name) 149 | except Exception as e: 150 | return 151 | 152 | 153 | def compile_code(code: str, timeout: int): 154 | signal.alarm(timeout) 155 | try: 156 | tmp_sol = ModuleType("tmp_sol", "") 157 | exec(code, tmp_sol.__dict__) 158 | if "class Solution" in code: 159 | # leetcode wraps solutions in `Solution` 160 | # this is a hack to check if it is leetcode solution or not 161 | # currently livecodebench only supports LeetCode but 162 | # else condition allows future extensibility to other platforms 163 | compiled_sol = tmp_sol.Solution() 164 | else: 165 | # do nothing in the other case since function is accesible 166 | compiled_sol = tmp_sol 167 | 168 | assert compiled_sol is not None 169 | finally: 170 | signal.alarm(0) 171 | 172 | return compiled_sol 173 | 174 | 175 | def convert_line_to_decimals(line: str) -> Tuple[bool, List[Decimal]]: 176 | try: 177 | decimal_line = [Decimal(elem) for elem in line.split()] 178 | except: 179 | return False, [] 180 | return True, decimal_line 181 | 182 | 183 | def get_stripped_lines(val: str): 184 | ## you don't want empty lines to add empty list after splitlines! 185 | val = val.strip() 186 | 187 | return [val_line.strip() for val_line in val.split("\n")] 188 | 189 | 190 | def grade_call_based(code: str, all_inputs: list, all_outputs: list, fn_name: str, timeout: int): 191 | # call-based clean up logic 192 | # need to wrap in try-catch logic after to catch the correct errors, but for now this is fine. 
193 | code = import_string + "\n\n" + code 194 | compiled_sol = compile_code(code, timeout) 195 | 196 | if compiled_sol is None: 197 | return 198 | 199 | method = get_function(compiled_sol, fn_name) 200 | 201 | if method is None: 202 | return 203 | 204 | all_inputs = [[json.loads(line) for line in inputs.split("\n")] for inputs in all_inputs] 205 | 206 | all_outputs = [json.loads(output) for output in all_outputs] 207 | 208 | total_execution = 0 209 | all_results = [] 210 | for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)): 211 | signal.alarm(timeout) 212 | faulthandler.enable() 213 | try: 214 | # can lock here so time is useful 215 | start = time.time() 216 | prediction = method(*gt_inp) 217 | total_execution += time.time() - start 218 | signal.alarm(0) 219 | 220 | # don't penalize model if it produces tuples instead of lists 221 | # ground truth sequences are not tuples 222 | if isinstance(prediction, tuple): 223 | prediction = list(prediction) 224 | 225 | tmp_result = prediction == gt_out 226 | 227 | # handle floating point comparisons 228 | 229 | all_results.append(tmp_result) 230 | 231 | if not tmp_result: 232 | return all_results, { 233 | "output": truncatefn(prediction), 234 | "inputs": truncatefn(gt_inp), 235 | "expected": truncatefn(gt_out), 236 | "error_code": -2, 237 | "error_message": "Wrong Answer", 238 | } 239 | except Exception as e: 240 | signal.alarm(0) 241 | if "timeoutexception" in repr(e).lower(): 242 | all_results.append(-3) 243 | return all_results, { 244 | "error": repr(e), 245 | "error_code": -3, 246 | "error_message": "Time Limit Exceeded", 247 | "inputs": truncatefn(gt_inp), 248 | "expected": truncatefn(gt_out), 249 | } 250 | else: 251 | all_results.append(-4) 252 | return all_results, { 253 | "error": repr(e), 254 | "error_code": -4, 255 | "error_message": "Runtime Error", 256 | "inputs": truncatefn(gt_inp), 257 | "expected": truncatefn(gt_out), 258 | } 259 | 260 | finally: 261 | signal.alarm(0) 262 | faulthandler.disable() 263 | 264 | return all_results, {"execution time": total_execution} 265 | 266 | 267 | def grade_stdio( 268 | code: str, 269 | all_inputs: list, 270 | all_outputs: list, 271 | timeout: int, 272 | ): 273 | ## runtime doesn't interact well with __name__ == '__main__' 274 | code = clean_if_name(code) 275 | 276 | ## we wrap the given code inside another function 277 | code = make_function(code) 278 | 279 | compiled_sol = compile_code(code, timeout) 280 | if compiled_sol is None: 281 | return 282 | 283 | method = get_function(compiled_sol, "wrapped_function") 284 | 285 | if method is None: 286 | return 287 | 288 | all_results = [] 289 | total_execution_time = 0 290 | for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)): 291 | signal.alarm(timeout) 292 | faulthandler.enable() 293 | 294 | signal.alarm(timeout) 295 | with Capturing() as captured_output: 296 | try: 297 | start = time.time() 298 | call_method(method, gt_inp) 299 | total_execution_time += time.time() - start 300 | # reset the alarm 301 | signal.alarm(0) 302 | except Exception as e: 303 | signal.alarm(0) 304 | if "timeoutexception" in repr(e).lower(): 305 | all_results.append(-3) 306 | return all_results, { 307 | "error": repr(e), 308 | "error_code": -3, 309 | "error_message": "Time Limit Exceeded", 310 | "inputs": truncatefn(gt_inp), 311 | "expected": truncatefn(gt_out), 312 | } 313 | else: 314 | all_results.append(-4) 315 | return all_results, { 316 | "error": repr(e), 317 | "error_code": -4, 318 | "error_message": "Runtime Error", 319 | "inputs": 
truncatefn(gt_inp), 320 | "expected": truncatefn(gt_out), 321 | } 322 | 323 | finally: 324 | signal.alarm(0) 325 | faulthandler.disable() 326 | 327 | prediction = captured_output[0] 328 | 329 | stripped_prediction_lines = get_stripped_lines(prediction) 330 | stripped_gt_out_lines = get_stripped_lines(gt_out) 331 | 332 | ## WA happens in multiple circumstances 333 | ## so cache the return to make it clean! 334 | WA_send_args = { 335 | "output": truncatefn(prediction), 336 | "inputs": truncatefn(gt_inp), 337 | "expected": truncatefn(gt_out), 338 | "error_code": -2, 339 | } 340 | 341 | if len(stripped_prediction_lines) != len(stripped_gt_out_lines): 342 | all_results.append(-2) 343 | WA_send_args["error_message"] = "Wrong answer: mismatched output length" 344 | return all_results, WA_send_args 345 | 346 | for output_line_idx, ( 347 | stripped_prediction_line, 348 | stripped_gt_out_line, 349 | ) in enumerate(zip(stripped_prediction_lines, stripped_gt_out_lines)): 350 | WA_send_args["error_message"] = (f"Wrong answer at {output_line_idx=}: {truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}") 351 | 352 | ## CASE 1: exact match 353 | if stripped_prediction_line == stripped_gt_out_line: 354 | continue 355 | 356 | ## CASE 2: element-wise comparision 357 | ## if there are floating elements 358 | ## use `decimal` library for good floating point comparision 359 | ## otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True 360 | ## note that we should always be able to convert to decimals 361 | 362 | success, decimal_prediction_line = convert_line_to_decimals(stripped_prediction_line) 363 | if not success: 364 | all_results.append(-2) 365 | return all_results, WA_send_args 366 | success, decimal_gtout_line = convert_line_to_decimals(stripped_gt_out_line) 367 | if not success: 368 | all_results.append(-2) 369 | return all_results, WA_send_args 370 | 371 | if decimal_prediction_line == decimal_gtout_line: 372 | continue 373 | 374 | all_results.append(-2) 375 | return all_results, WA_send_args 376 | all_results.append(True) 377 | 378 | return all_results, {"execution time": total_execution_time} 379 | 380 | 381 | def run_test(sample, test=None, debug=False, timeout=6): 382 | """ 383 | if test(generated_code) is not None it'll try to run the code. 384 | otherwise it'll just return an input and output pair. 385 | """ 386 | signal.signal(signal.SIGALRM, timeout_handler) 387 | 388 | # Disable functionalities that can make destructive changes to the test. 
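# For reference, a minimal sample accepted here has the assumed shape (mirroring the __main__ demos in compute_code_generation_metrics.py): # {"input_output": json.dumps({"inputs": [...], "outputs": [...], "fn_name": "..."})} # where "fn_name" is only present for call-based (LeetCode-style) problems; without it, the code is graded in stdin/stdout mode.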
389 | # max memory is set to 4GB 390 | reliability_guard() 391 | 392 | if debug: 393 | print(f"start = {datetime.now().time()}") 394 | 395 | try: 396 | in_outs = json.loads(sample["input_output"]) 397 | except ValueError as e: 398 | raise e 399 | in_outs = None 400 | 401 | if in_outs: 402 | if in_outs.get("fn_name") is None: 403 | which_type = CODE_TYPE.standard_input # Standard input 404 | method_name = None 405 | 406 | else: 407 | which_type = CODE_TYPE.call_based # Call-based 408 | method_name = in_outs["fn_name"] 409 | 410 | if debug: 411 | print(f"loaded input_output = {datetime.now().time()}") 412 | 413 | if test is None: 414 | assert False, "should not happen: test code is none" 415 | return in_outs, {"error": "no test code provided"} 416 | elif test is not None: 417 | results = [] 418 | sol = import_string 419 | if debug: 420 | print(f"loading test code = {datetime.now().time()}") 421 | 422 | if which_type == CODE_TYPE.call_based: 423 | signal.alarm(timeout) 424 | try: 425 | results, metadata = grade_call_based( 426 | code=test, 427 | all_inputs=in_outs["inputs"], 428 | all_outputs=in_outs["outputs"], 429 | fn_name=method_name, 430 | timeout=timeout, 431 | ) 432 | return results, metadata 433 | except Exception as e: 434 | return [-4], { 435 | "error_code": -4, 436 | "error_message": f"Error during testing: {e}", 437 | } 438 | finally: 439 | signal.alarm(0) 440 | elif which_type == CODE_TYPE.standard_input: 441 | # sol 442 | # if code has if __name__ == "__main__": then remove it 443 | 444 | signal.alarm(timeout) 445 | try: 446 | results, metadata = grade_stdio( 447 | code=test, 448 | all_inputs=in_outs["inputs"], 449 | all_outputs=in_outs["outputs"], 450 | timeout=timeout, 451 | ) 452 | return results, metadata 453 | except Exception as e: 454 | return [-4], { 455 | "error_code": -4, 456 | "error_message": f"Error during testing: {e}", 457 | } 458 | finally: 459 | signal.alarm(0) 460 | 461 | 462 | def reliability_guard(maximum_memory_bytes=None): 463 | """ 464 | This disables various destructive functions and prevents the generated code 465 | from interfering with the test (e.g. fork bomb, killing other processes, 466 | removing filesystem files, etc.) 467 | WARNING 468 | This function is NOT a security sandbox. Untrusted code, including, model- 469 | generated code, should not be blindly executed outside of one. See the 470 | Codex paper for more information about OpenAI's code sandbox, and proceed 471 | with caution. 
472 | """ 473 | 474 | if maximum_memory_bytes is not None: 475 | import resource 476 | 477 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 478 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 479 | if not platform.uname().system == "Darwin": 480 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 481 | 482 | faulthandler.disable() 483 | 484 | import builtins 485 | 486 | # builtins.exit = None 487 | builtins.quit = None 488 | 489 | import os 490 | 491 | os.environ["OMP_NUM_THREADS"] = "1" 492 | 493 | os.kill = None 494 | os.system = None 495 | os.putenv = None 496 | os.remove = None 497 | os.removedirs = None 498 | os.rmdir = None 499 | os.fchdir = None 500 | os.setuid = None 501 | os.fork = None 502 | os.forkpty = None 503 | os.killpg = None 504 | os.rename = None 505 | os.renames = None 506 | os.truncate = None 507 | os.replace = None 508 | os.unlink = None 509 | os.fchmod = None 510 | os.fchown = None 511 | os.chmod = None 512 | os.chown = None 513 | os.chroot = None 514 | os.fchdir = None 515 | os.lchflags = None 516 | os.lchmod = None 517 | os.lchown = None 518 | os.getcwd = None 519 | os.chdir = None 520 | 521 | import shutil 522 | 523 | shutil.rmtree = None 524 | shutil.move = None 525 | shutil.chown = None 526 | 527 | import subprocess 528 | 529 | subprocess.Popen = None # type: ignore 530 | 531 | __builtins__["help"] = None 532 | 533 | import sys 534 | 535 | sys.modules["ipdb"] = None 536 | sys.modules["joblib"] = None 537 | sys.modules["resource"] = None 538 | sys.modules["psutil"] = None 539 | sys.modules["tkinter"] = None -------------------------------------------------------------------------------- /PromptCoT_Mamba/math_opensource_utils/math_equivalence.py: -------------------------------------------------------------------------------- 1 | import re 2 | import signal 3 | from typing import Dict, List, Optional 4 | # from timeout_decorator import timeout 5 | # from timeout_decorator import timeout as timeout_decorator 6 | 7 | try: 8 | import sympy 9 | from sympy.parsing.latex import parse_latex 10 | from sympy.simplify import simplify 11 | 12 | # Set global timeout for SymPy operations 13 | sympy.TIMEOUT = 5 # Set global timeout in seconds 14 | 15 | # Or specifically for simplify operations 16 | simplify.TIMEOUT = 5 # Set timeout for simplify operations 17 | 18 | # If needed, you can also try these settings 19 | sympy.core.cache.NO_CACHE = True # Disable caching 20 | # sympy.core.evalf.maxprec = 1000 # Limit precision 21 | 22 | 23 | except ModuleNotFoundError: 24 | raise ModuleNotFoundError( 25 | "`sympy` is required for generating translation task prompt templates. 
\ 26 | please install sympy via pip install lm-eval[math] or pip install -e .[math]", 27 | ) 28 | 29 | 30 | def _fix_fracs(string): 31 | substrs = string.split("\\frac") 32 | new_str = substrs[0] 33 | if len(substrs) > 1: 34 | substrs = substrs[1:] 35 | for substr in substrs: 36 | new_str += "\\frac" 37 | if substr[0] == "{": 38 | new_str += substr 39 | else: 40 | try: 41 | assert len(substr) >= 2 42 | except: 43 | return string 44 | a = substr[0] 45 | b = substr[1] 46 | if b != "{": 47 | if len(substr) > 2: 48 | post_substr = substr[2:] 49 | new_str += "{" + a + "}{" + b + "}" + post_substr 50 | else: 51 | new_str += "{" + a + "}{" + b + "}" 52 | else: 53 | if len(substr) > 2: 54 | post_substr = substr[2:] 55 | new_str += "{" + a + "}" + b + post_substr 56 | else: 57 | new_str += "{" + a + "}" + b 58 | string = new_str 59 | return string 60 | 61 | 62 | def _fix_a_slash_b(string): 63 | if len(string.split("/")) != 2: 64 | return string 65 | a = string.split("/")[0] 66 | b = string.split("/")[1] 67 | try: 68 | a = int(a) 69 | b = int(b) 70 | assert string == "{}/{}".format(a, b) 71 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 72 | return new_string 73 | except: 74 | return string 75 | 76 | 77 | def _remove_right_units(string): 78 | # "\\text{ " only ever occurs (at least in the val set) when describing units 79 | if "\\text{ " in string: 80 | splits = string.split("\\text{ ") 81 | assert len(splits) == 2 82 | return splits[0] 83 | else: 84 | return string 85 | 86 | 87 | def _fix_sqrt(string): 88 | if "\\sqrt" not in string: 89 | return string 90 | splits = string.split("\\sqrt") 91 | new_string = splits[0] 92 | for split in splits[1:]: 93 | if split[0] != "{": 94 | a = split[0] 95 | new_substr = "\\sqrt{" + a + "}" + split[1:] 96 | else: 97 | new_substr = "\\sqrt" + split 98 | new_string += new_substr 99 | return new_string 100 | 101 | 102 | def _strip_string(string): 103 | # linebreaks 104 | string = string.replace("\n", "") 105 | # print(string) 106 | 107 | # remove inverse spaces 108 | string = string.replace("\\!", "") 109 | # print(string) 110 | 111 | # replace \\ with \ 112 | string = string.replace("\\\\", "\\") 113 | # print(string) 114 | 115 | # replace tfrac and dfrac with frac 116 | string = string.replace("tfrac", "frac") 117 | string = string.replace("dfrac", "frac") 118 | # print(string) 119 | 120 | # remove \left and \right 121 | string = string.replace("\\left", "") 122 | string = string.replace("\\right", "") 123 | # print(string) 124 | 125 | # Remove circ (degrees) 126 | string = string.replace("^{\\circ}", "") 127 | string = string.replace("^\\circ", "") 128 | 129 | # remove dollar signs 130 | string = string.replace("\\$", "") 131 | 132 | # remove units (on the right) 133 | string = _remove_right_units(string) 134 | 135 | # remove percentage 136 | string = string.replace("\\%", "") 137 | string = string.replace("\%", "") 138 | 139 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 140 | string = string.replace(" .", " 0.") 141 | string = string.replace("{.", "{0.") 142 | # if empty, return empty string 143 | if len(string) == 0: 144 | return string 145 | if string[0] == ".": 146 | string = "0" + string 147 | 148 | # to consider: get rid of e.g. 
"k = " or "q = " at beginning 149 | if len(string.split("=")) == 2: 150 | if len(string.split("=")[0]) <= 2: 151 | string = string.split("=")[1] 152 | 153 | # fix sqrt3 --> sqrt{3} 154 | string = _fix_sqrt(string) 155 | 156 | # remove spaces 157 | string = string.replace(" ", "") 158 | 159 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 160 | string = _fix_fracs(string) 161 | 162 | # manually change 0.5 --> \frac{1}{2} 163 | if string == "0.5": 164 | string = "\\frac{1}{2}" 165 | 166 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 167 | string = _fix_a_slash_b(string) 168 | 169 | return string 170 | 171 | 172 | def is_equiv(str1, str2, verbose=False): 173 | if str1 is None and str2 is None: 174 | print("WARNING: Both None") 175 | return True 176 | if str1 is None or str2 is None: 177 | return False 178 | 179 | try: 180 | ss1 = _strip_string(str1) 181 | ss2 = _strip_string(str2) 182 | if verbose: 183 | print(ss1, ss2) 184 | return ss1 == ss2 185 | except: 186 | return str1 == str2 187 | 188 | 189 | class timeout: 190 | def __init__(self, seconds=1, error_message="Timeout"): 191 | self.seconds = seconds 192 | self.error_message = error_message 193 | 194 | def handle_timeout(self, signum, frame): 195 | raise TimeoutError(self.error_message) 196 | 197 | def __enter__(self): 198 | signal.signal(signal.SIGALRM, self.handle_timeout) 199 | signal.alarm(self.seconds) 200 | 201 | def __exit__(self, type, value, traceback): 202 | signal.alarm(0) 203 | 204 | def is_equiv_minerva(x1: str, x2: str) -> bool: 205 | """ 206 | x1 and x2 are normalized latex string 207 | """ 208 | try: 209 | with timeout(seconds=5): 210 | try: 211 | parsed_x1 = parse_latex(normalize_final_answer(x1)) 212 | parsed_x2 = parse_latex(normalize_final_answer(x2)) 213 | except ( 214 | sympy.parsing.latex.errors.LaTeXParsingError, 215 | sympy.SympifyError, 216 | TypeError, 217 | AttributeError, 218 | ): 219 | # print(f"couldn't parse one of {x1} or {x2}") 220 | return False 221 | 222 | try: 223 | diff = parsed_x1 - parsed_x2 224 | except TypeError: 225 | # print(f"couldn't subtract {x1} and {x2}") 226 | return False 227 | 228 | try: 229 | if sympy.simplify(diff) == 0: 230 | return True 231 | else: 232 | return False 233 | except ValueError: 234 | # print( 235 | # f"Had some trouble simplifying when comparing {x1} and {x2}" 236 | # ) 237 | return False 238 | except TimeoutError: 239 | # print(f"Timed out comparing {x1} and {x2}") 240 | return False 241 | except ImportError as e: 242 | print(e) 243 | raise 244 | except Exception as e: 245 | # print(f"Failed comparing {x1} and {x2} with {e}") 246 | return False 247 | 248 | 249 | SUBSTITUTIONS = [ 250 | ("an ", ""), 251 | ("a ", ""), 252 | (".$", "$"), 253 | ("\\$", ""), 254 | (r"\ ", ""), 255 | (" ", ""), 256 | ("mbox", "text"), 257 | (",\\text{and}", ","), 258 | ("\\text{and}", ","), 259 | ("\\text{m}", "\\text{}"), 260 | ] 261 | REMOVED_EXPRESSIONS = [ 262 | "square", 263 | "ways", 264 | "integers", 265 | "dollars", 266 | "mph", 267 | "inches", 268 | "ft", 269 | "hours", 270 | "km", 271 | "units", 272 | "\\ldots", 273 | "sue", 274 | "points", 275 | "feet", 276 | "minutes", 277 | "digits", 278 | "cents", 279 | "degrees", 280 | "cm", 281 | "gm", 282 | "pounds", 283 | "meters", 284 | "meals", 285 | "edges", 286 | "students", 287 | "childrentickets", 288 | "multiples", 289 | "\\text{s}", 290 | "\\text{.}", 291 | "\\text{\ns}", 292 | 
"\\text{}^2", 293 | "\\text{}^3", 294 | "\\text{\n}", 295 | "\\text{}", 296 | r"\mathrm{th}", 297 | r"^\circ", 298 | r"^{\circ}", 299 | r"\;", 300 | r",\!", 301 | "{,}", 302 | '"', 303 | "\\dots", 304 | ] 305 | 306 | 307 | def normalize_final_answer(final_answer: str) -> str: 308 | """ 309 | Normalize a final answer to a quantitative reasoning question. 310 | 311 | Copied character for character from appendix D of Lewkowycz et al. (2022) 312 | """ 313 | final_answer = final_answer.split("=")[-1] 314 | 315 | for before, after in SUBSTITUTIONS: 316 | final_answer = final_answer.replace(before, after) 317 | for expr in REMOVED_EXPRESSIONS: 318 | final_answer = final_answer.replace(expr, "") 319 | 320 | # Extract answer that is in LaTeX math, is bold, 321 | # is surrounded by a box, etc. 322 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) 323 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) 324 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) 325 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) 326 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) 327 | 328 | # Normalize shorthand TeX: 329 | # \fracab -> \frac{a}{b} 330 | # \frac{abc}{bef} -> \frac{abc}{bef} 331 | # \fracabc -> \frac{a}{b}c 332 | # \sqrta -> \sqrt{a} 333 | # \sqrtab -> sqrt{a}b 334 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) 335 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) 336 | final_answer = final_answer.replace("$", "") 337 | 338 | # Normalize 100,000 -> 100000 339 | if final_answer.replace(",", "").isdigit(): 340 | final_answer = final_answer.replace(",", "") 341 | 342 | return final_answer -------------------------------------------------------------------------------- /PromptCoT_Mamba/math_opensource_utils/util.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | 3 | def remove_boxed(s): 4 | left = "\\boxed{" 5 | try: 6 | assert s[:len(left)] == left 7 | assert s[-1] == "}" 8 | return s[len(left):-1] 9 | except: 10 | return None 11 | 12 | def last_boxed_only(sample): 13 | """ 14 | Given a (q,a) sample, filter the answers so that they only contain 15 | the last \boxed{...} or \fbox{...} element 16 | """ 17 | q, a = sample 18 | a = last_boxed_only_string(a) 19 | if a == None: 20 | return None 21 | return (q, a) 22 | 23 | 24 | def last_boxed_only_string(string): 25 | idx = string.rfind("\\boxed") 26 | if idx < 0: 27 | idx = string.rfind("\\fbox") 28 | if idx < 0: 29 | return None 30 | 31 | i = idx 32 | right_brace_idx = None 33 | num_left_braces_open = 0 34 | while i < len(string): 35 | if string[i] == "{": 36 | num_left_braces_open += 1 37 | if string[i] == "}": 38 | num_left_braces_open -= 1 39 | if num_left_braces_open == 0: 40 | right_brace_idx = i 41 | break 42 | i += 1 43 | 44 | if right_brace_idx == None: 45 | retval = None 46 | else: 47 | retval = string[idx:right_brace_idx + 1] 48 | 49 | return retval 50 | 51 | 52 | def first_boxed_only_string(string): 53 | # Find the first occurrence of \boxed or \fbox 54 | idx_boxed = string.find("\\boxed") 55 | idx_fbox = string.find("\\fbox") 56 | 57 | # Determine which comes first (if either exists) 58 | if idx_boxed < 0 and idx_fbox < 0: 59 | return None 60 | elif idx_boxed < 0: 61 | idx = idx_fbox 62 | elif idx_fbox < 0: 63 | idx = idx_boxed 64 | else: 65 | idx = min(idx_boxed, idx_fbox) 66 | 67 | # Find matching closing brace 68 | i = idx 69 | 
87 | 88 | def only_until_first_boxed_from_tokens(string, tokens): 89 | idx = string.find("\\boxed") 90 | if idx < 0: 91 | idx = string.find("\\fbox") 92 | if idx < 0: 93 | return None 94 | 95 | cum_length = 0 96 | for i, t in enumerate(tokens): 97 | cum_length += len(t) 98 | if cum_length >= idx: 99 | break 100 | 101 | return tokens[:i] 102 | 103 | 104 | def clean_numbers(sample): 105 | if not sample: 106 | return None 107 | new_sample = list() 108 | for s in sample: 109 | new_sample.append(_clean_numbers(s)) 110 | 111 | return tuple(new_sample) 112 | 113 | 114 | def _clean_numbers(string): 115 | """ 116 | Clean numbers in the given string 117 | 118 | >>> _clean_numbers("Hello 123") 119 | 'Hello 123' 120 | >>> _clean_numbers("Hello 1234") 121 | 'Hello 1,234' 122 | >>> _clean_numbers("Hello 1234324asdasd") 123 | 'Hello 1,234,324asdasd' 124 | """ 125 | num_prev_digits = 0 126 | new_string = "" 127 | for i, c in enumerate(string): 128 | # isdigit() doesn't work here because of weird unicode chars. 129 | if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}: 130 | num_prev_digits += 1 131 | else: 132 | if num_prev_digits > 3: 133 | # Some fixing 134 | string_number = new_string[-num_prev_digits:] 135 | new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number)) 136 | num_prev_digits = 0 137 | new_string += c 138 | 139 | if num_prev_digits > 3: 140 | # Some fixing 141 | string_number = new_string[-num_prev_digits:] 142 | new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number)) 143 | 144 | return new_string -------------------------------------------------------------------------------- /PromptCoT_Mamba/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.6.0 2 | addict==2.4.0 3 | aiohappyeyeballs==2.6.1 4 | aiohttp==3.11.14 5 | aiohttp-cors==0.8.0 6 | aiosignal==1.3.2 7 | airportsdata==20250224 8 | annotated-types==0.7.0 9 | anyio==4.9.0 10 | astor==0.8.1 11 | attrs==25.3.0 12 | blake3==1.0.4 13 | cachetools==5.5.2 14 | causal-conv1d==1.5.0 15 | certifi==2025.1.31 16 | charset-normalizer==3.4.1 17 | click==8.1.8 18 | cloudpickle==3.1.1 19 | cmake==3.31.6 20 | colorful==0.5.6 21 | compressed-tensors==0.9.1 22 | cupy-cuda12x==13.4.1 23 | datasets==3.2.0 24 | deepspeed==0.16.7 25 | depyf==0.18.0 26 | dill==0.3.8 27 | diskcache==5.6.3 28 | distlib==0.3.9 29 | distro==1.9.0 30 | dnspython==2.7.0 31 | einops==0.8.1 32 | email_validator==2.2.0 33 | fastapi==0.115.11 34 | fastapi-cli==0.0.7 35 | fastrlock==0.8.3 36 | filelock==3.18.0 37 | fire==0.7.0 38 | flash_attn==2.7.4.post1 39 | frozenlist==1.5.0 40 | fsspec==2024.9.0 41 | gguf==0.10.0 42 | google-api-core==2.24.2 43 | google-auth==2.38.0 44 | googleapis-common-protos==1.69.2 45 | grpcio==1.71.0 46 | h11==0.14.0 47 | hjson==3.1.0 48 | httpcore==1.0.7 49 | httptools==0.6.4 50 | httpx==0.28.1 51 | huggingface-hub==0.29.3 52 | idna==3.10 53 | immutables==0.20 54 | importlib_metadata==8.6.1 55 | iniconfig==2.1.0 56 | intel-cmplr-lib-ur==2025.0.5 57 | intel-openmp==2025.0.5 58 | interegular==0.3.3 59 | Jinja2==3.1.6 60 | jiter==0.9.0 61 | 
joblib==1.4.2 62 | jsonschema==4.23.0 63 | jsonschema-specifications==2024.10.1 64 | lark==1.2.2 65 | llvmlite==0.43.0 66 | lm-format-enforcer==0.10.11 67 | mamba-ssm==2.2.3 68 | mambapy==1.2.0 69 | markdown-it-py==3.0.0 70 | MarkupSafe==3.0.2 71 | mdurl==0.1.2 72 | mistral_common==1.5.4 73 | mkl==2025.0.1 74 | mkl-include==2025.0.1 75 | modelscope==1.24.0 76 | mpmath==1.3.0 77 | msgpack==1.1.0 78 | msgspec==0.19.0 79 | multidict==6.2.0 80 | multiprocess==0.70.16 81 | nest-asyncio==1.6.0 82 | networkx==3.4.2 83 | ninja==1.11.1.3 84 | numba==0.60.0 85 | numpy==1.26.4 86 | nvidia-cublas-cu12==12.1.3.1 87 | nvidia-cuda-cupti-cu12==12.1.105 88 | nvidia-cuda-nvrtc-cu12==12.1.105 89 | nvidia-cuda-runtime-cu12==12.1.105 90 | nvidia-cudnn-cu12==9.1.0.70 91 | nvidia-cufft-cu12==11.0.2.54 92 | nvidia-curand-cu12==10.3.2.106 93 | nvidia-cusolver-cu12==11.4.5.107 94 | nvidia-cusparse-cu12==12.1.0.106 95 | nvidia-ml-py==12.570.86 96 | nvidia-nccl-cu12==2.21.5 97 | nvidia-nvjitlink-cu12==12.8.93 98 | nvidia-nvtx-cu12==12.1.105 99 | openai==1.66.3 100 | opencensus==0.11.4 101 | opencensus-context==0.1.3 102 | opencv-python-headless==4.11.0.86 103 | outlines==0.1.11 104 | outlines_core==0.1.26 105 | packaging==24.2 106 | pandas==2.2.3 107 | partial-json-parser==0.2.1.1.post5 108 | peft==0.15.2 109 | pillow==11.1.0 110 | platformdirs==4.3.6 111 | pluggy==1.5.0 112 | prometheus-fastapi-instrumentator==7.0.2 113 | prometheus_client==0.21.1 114 | propcache==0.3.0 115 | proto-plus==1.26.1 116 | protobuf==6.30.1 117 | psutil==7.0.0 118 | py-cpuinfo==9.0.0 119 | py-spy==0.4.0 120 | pyarrow==19.0.1 121 | pyasn1==0.6.1 122 | pyasn1_modules==0.4.1 123 | pybind11==2.13.6 124 | pycountry==24.6.1 125 | pydantic==2.10.6 126 | pydantic_core==2.27.2 127 | Pygments==2.19.1 128 | pytest==8.3.5 129 | python-dateutil==2.9.0.post0 130 | python-dotenv==1.0.1 131 | python-multipart==0.0.20 132 | pytz==2025.1 133 | PyYAML==6.0.2 134 | pyzmq==26.3.0 135 | ray==2.40.0 136 | referencing==0.36.2 137 | regex==2024.11.6 138 | requests==2.32.3 139 | rich==14.0.0 140 | rich-toolkit==0.14.1 141 | rpds-py==0.23.1 142 | rsa==4.9 143 | safetensors==0.5.3 144 | scikit-learn==1.6.1 145 | scipy==1.15.2 146 | sentencepiece==0.2.0 147 | shellingham==1.5.4 148 | simplejson==3.20.1 149 | six==1.17.0 150 | smart-open==7.1.0 151 | sniffio==1.3.1 152 | sortedcontainers==2.4.0 153 | starlette==0.46.1 154 | str2bool==1.1 155 | sympy==1.13.1 156 | tbb==2022.0.0 157 | tcmlib==1.2.0 158 | termcolor==2.5.0 159 | threadpoolctl==3.6.0 160 | tiktoken==0.9.0 161 | timm==1.0.15 162 | tokenizers==0.21.1 163 | torch==2.5.1 164 | torchaudio==2.5.1 165 | torchvision==0.20.1 166 | tqdm==4.66.5 167 | transformers==4.49.0 168 | triton==3.1.0 169 | trl==0.16.1 170 | typer==0.15.2 171 | typing_extensions==4.12.2 172 | tzdata==2025.1 173 | umf==0.9.1 174 | urllib3==2.3.0 175 | uvicorn==0.34.0 176 | uvloop==0.21.0 177 | virtualenv==20.29.3 178 | vllm==0.7.3 179 | watchfiles==1.0.4 180 | websockets==15.0.1 181 | wrapt==1.17.2 182 | xformers==0.0.28.post3 183 | xgrammar==0.1.11 184 | xxhash==3.5.0 185 | yarl==1.18.3 186 | zipp==3.21.0 187 | -------------------------------------------------------------------------------- /PromptCoT_Mamba/train.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, load_from_disk 2 | import warnings 3 | from dataclasses import dataclass, field 4 | from typing import Iterable, Optional, Union, Dict, Any, List, Tuple, Sequence 5 | 6 | import sys 7 | import copy 8 | import 
torch 9 | import datetime 10 | 11 | torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=12000)) 12 | 13 | 14 | import random 15 | import yaml 16 | import transformers 17 | from transformers.hf_argparser import DataClass, DataClassType 18 | import os 19 | 20 | from trl import SFTTrainer 21 | 22 | IGNORE_INDEX = -100 23 | 24 | @dataclass 25 | class DataArguments: 26 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 27 | 28 | 29 | @dataclass 30 | class ModelConfig: 31 | model_name_or_path: Optional[str] = field(default="/ossfs/workspace/nas/xueliang/hf_models/Meta-Llama-3.1-8B") 32 | tokenizer_path: Optional[str] = field(default="/ossfs/workspace/nas/xueliang/hf_models/Meta-Llama-3.1-8B") 33 | 34 | 35 | @dataclass 36 | class SFTConfig(transformers.TrainingArguments): 37 | # Parameters that control the model 38 | model_init_kwargs: Optional[Dict[str, Any]] = field( 39 | default=None, 40 | metadata={ 41 | "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " 42 | "the `SFTTrainer` is provided as a string." 43 | }, 44 | ) 45 | 46 | # Parameters that control the data preprocessing 47 | dataset_text_field: str = field( 48 | default="text", 49 | metadata={"help": "Name of the column that contains text data in the dataset."}, 50 | ) 51 | dataset_kwargs: Optional[Dict[str, Any]] = field( 52 | default=None, 53 | metadata={ 54 | "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " 55 | "`skip_prepare_dataset`." 56 | }, 57 | ) 58 | dataset_num_proc: Optional[int] = field( 59 | default=None, 60 | metadata={"help": "Number of processes to use for processing the dataset."}, 61 | ) 62 | pad_token: Optional[str] = field( 63 | default=None, 64 | metadata={ 65 | "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " 66 | "is also `None`, it falls back to `processing_class.eos_token`." 67 | }, 68 | ) 69 | max_length: Optional[int] = field( 70 | default=1024, 71 | metadata={ 72 | "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from" 73 | "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " 74 | "sequence length." 75 | }, 76 | ) 77 | packing: bool = field( 78 | default=False, 79 | metadata={ 80 | "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define " 81 | "sequence length." 82 | }, 83 | ) 84 | padding_free: bool = field( 85 | default=False, 86 | metadata={ 87 | "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " 88 | "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, " 89 | "this is only supported with the `flash_attention_2` attention implementation, which can efficiently " 90 | "handle the flattened batch structure." 91 | }, 92 | ) 93 | eval_packing: Optional[bool] = field( 94 | default=None, 95 | metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, 96 | ) 97 | 98 | # Parameters that control the training 99 | learning_rate: float = field( 100 | default=2.0e-5, 101 | metadata={ 102 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " 103 | "`TrainingArguments`." 
104 | }, 105 | ) 106 | 107 | # Deprecated parameters 108 | dataset_batch_size: Optional[int] = field( 109 | default=None, 110 | metadata={ 111 | "help": "This parameter is deprecated and will be removed in version 0.18.0. You can safely remove this " 112 | "parameter from your configuration." 113 | }, 114 | ) 115 | num_of_sequences: Optional[int] = field( 116 | default=None, 117 | metadata={ 118 | "help": "This parameter is deprecated and will be removed in version 0.18.0. Use `max_length` instead, " 119 | "which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which referred " 120 | "to string sequences." 121 | }, 122 | ) 123 | chars_per_token: Optional[float] = field( 124 | default=None, 125 | metadata={ 126 | "help": "This parameter is deprecated and will be removed in version 0.18.0. If you want to customize the " 127 | "packing length, use `max_length`." 128 | }, 129 | ) 130 | max_seq_length: Optional[int] = field( 131 | default=None, 132 | metadata={ 133 | "help": "This parameter is deprecated and will be removed in version 0.20.0. Use `max_length` instead." 134 | }, 135 | ) 136 | use_liger: Optional[bool] = field( 137 | default=None, 138 | metadata={ 139 | "help": "This parameter is deprecated and will be removed in version 0.18.0. Use `use_liger_kernel` " 140 | "instead." 141 | }, 142 | ) 143 | 144 | def __post_init__(self): 145 | super().__post_init__() 146 | 147 | if self.dataset_batch_size is not None: 148 | warnings.warn( 149 | "`dataset_batch_size` is deprecated and will be removed in version 0.18.0. You can safely remove this " 150 | "parameter from your configuration.", 151 | DeprecationWarning, 152 | ) 153 | 154 | if self.num_of_sequences is not None: 155 | warnings.warn( 156 | "`num_of_sequences` is deprecated and will be removed in version 0.18.0. Use `max_length` instead, " 157 | "which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which " 158 | "referred to string sequences.", 159 | DeprecationWarning, 160 | ) 161 | 162 | if self.chars_per_token is not None: 163 | warnings.warn( 164 | "`chars_per_token` is deprecated and will be removed in version 0.18.0. If you want to customize the " 165 | "packing length, use `max_length`.", 166 | DeprecationWarning, 167 | ) 168 | 169 | if self.max_seq_length is not None: 170 | warnings.warn( 171 | "`max_seq_length` is deprecated and will be removed in version 0.20.0. Use `max_length` instead.", 172 | DeprecationWarning, 173 | ) 174 | self.max_length = self.max_seq_length 175 | 176 | if self.use_liger is not None: 177 | warnings.warn( 178 | "`use_liger` is deprecated and will be removed in version 0.18.0. Use `use_liger_kernel` instead.", 179 | DeprecationWarning, 180 | ) 181 | self.use_liger_kernel = self.use_liger 182 | 183 | 184 | class TrlParser(transformers.HfArgumentParser): 185 | 186 | def __init__( 187 | self, 188 | dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None, 189 | **kwargs, 190 | ): 191 | # Make sure dataclass_types is an iterable 192 | if dataclass_types is None: 193 | dataclass_types = [] 194 | elif not isinstance(dataclass_types, Iterable): 195 | dataclass_types = [dataclass_types] 196 | 197 | # Check that none of the dataclasses have the "config" field 198 | for dataclass_type in dataclass_types: 199 | if "config" in dataclass_type.__dataclass_fields__: 200 | raise ValueError( 201 | f"Dataclass {dataclass_type.__name__} has a field named 'config'. 
This field is reserved for the " 202 | f"config file path and should not be used in the dataclass." 203 | ) 204 | 205 | super().__init__(dataclass_types=dataclass_types, **kwargs) 206 | 207 | def parse_args_and_config( 208 | self, args: Optional[Iterable[str]] = None, return_remaining_strings: bool = False 209 | ) -> Tuple[DataClass, ...]: 210 | """ 211 | Parse command-line args and config file into instances of the specified dataclass types. 212 | 213 | This method wraps [`transformers.HfArgumentParser.parse_args_into_dataclasses`] and also parses the config file 214 | specified with the `--config` flag. The config file (in YAML format) provides argument values that replace the 215 | default values in the dataclasses. Command line arguments can override values set by the config file. The 216 | method also sets any environment variables specified in the `env` field of the config file. 217 | """ 218 | args = list(args) if args is not None else sys.argv[1:] 219 | if "--config" in args: 220 | # Get the config file path from the command-line arguments 221 | config_index = args.index("--config") 222 | args.pop(config_index) # remove the --config flag 223 | config_path = args.pop(config_index) # get the path to the config file 224 | with open(config_path) as yaml_file: 225 | config = yaml.safe_load(yaml_file) 226 | 227 | # Set the environment variables specified in the config file 228 | if "env" in config: 229 | env_vars = config.pop("env", {}) 230 | if not isinstance(env_vars, dict): 231 | raise ValueError("`env` field should be a dict in the YAML file.") 232 | for key, value in env_vars.items(): 233 | os.environ[key] = str(value) 234 | 235 | # Set the defaults from the config values 236 | config_remaining_strings = self.set_defaults_with_config(**config) 237 | else: 238 | config_remaining_strings = [] 239 | 240 | # Parse the arguments from the command line 241 | output = self.parse_args_into_dataclasses(args=args, return_remaining_strings=return_remaining_strings) 242 | 243 | # Merge remaining strings from the config file with the remaining strings from the command line 244 | if return_remaining_strings: 245 | args_remaining_strings = output[-1] 246 | return output[:-1] + (config_remaining_strings + args_remaining_strings,) 247 | else: 248 | return output 249 | 250 | def set_defaults_with_config(self, **kwargs) -> List[str]: 251 | """ 252 | Overrides the parser's default values with those provided via keyword arguments. 253 | 254 | Any argument with an updated default will also be marked as not required 255 | if it was previously required. 256 | 257 | Returns a list of strings that were not consumed by the parser. 258 | """ 259 | # If an argument is in the kwargs, update its default and set it as not required 260 | for action in self._actions: 261 | if action.dest in kwargs: 262 | action.default = kwargs.pop(action.dest) 263 | action.required = False 264 | remaining_strings = [item for key, value in kwargs.items() for item in [f"--{key}", str(value)]] 265 | return remaining_strings 266 | 267 | 
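# Usage sketch (illustrative): the script calls torch.distributed.init_process_group at import time,
# so launch it with a distributed launcher, e.g.
#   torchrun --nproc_per_node=8 train.py --config configs/promptcot_mamba_7b_config.json
# (the GPU count here is an assumption). Values from the config file become parser defaults;
# flags passed explicitly on the command line still override them, and any `env` entries in the
# config are exported as environment variables first.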
258 | """ 259 | # If an argument is in the kwargs, update its default and set it as not required 260 | for action in self._actions: 261 | if action.dest in kwargs: 262 | action.default = kwargs.pop(action.dest) 263 | action.required = False 264 | remaining_strings = [item for key, value in kwargs.items() for item in [f"--{key}", str(value)]] 265 | return remaining_strings 266 | 267 | 268 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 269 | tokenized_list = [ 270 | tokenizer( 271 | text, 272 | return_tensors="pt", 273 | padding="longest", 274 | max_length=tokenizer.model_max_length, 275 | truncation=True, 276 | ) 277 | for text in strings 278 | ] 279 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 280 | input_ids_lens = labels_lens = [ 281 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 282 | ] 283 | return dict( 284 | input_ids=input_ids, 285 | labels=labels, 286 | input_ids_lens=input_ids_lens, 287 | labels_lens=labels_lens, 288 | ) 289 | 290 | 291 | def preprocess( 292 | sources: Sequence[str], 293 | targets: Sequence[str], 294 | tokenizer: transformers.PreTrainedTokenizer, 295 | ) -> Dict: 296 | examples = [s + t for s, t in zip(sources, targets)] 297 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 298 | input_ids = examples_tokenized["input_ids"] 299 | labels = copy.deepcopy(input_ids) 300 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 301 | label[:source_len] = IGNORE_INDEX 302 | return dict(input_ids=input_ids, labels=labels) 303 | 304 | 305 | @dataclass 306 | class DataCollatorForSuperviseDataset(object): 307 | tokenizer: transformers.PreTrainedTokenizer 308 | 309 | def __call__(self, items: Sequence[Dict]) -> Dict[str, torch.Tensor]: 310 | input_ids, labels = tuple([item[key] for item in items] for key in ("input_ids", "labels")) 311 | input_ids = [torch.tensor(x) for x in input_ids] 312 | input_ids = torch.nn.utils.rnn.pad_sequence( 313 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 314 | ) 315 | labels = [torch.tensor(x) for x in labels] 316 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 317 | return dict( 318 | input_ids=input_ids, 319 | labels=labels, 320 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id) 321 | ) 322 | 323 | 324 | class SkipBadBF16BatchTrainer(SFTTrainer): 325 | def __init__(self, *args, max_bad_skips: int = 10000, **kwargs): 326 | super().__init__(*args, **kwargs) 327 | self.bad_skips = 0 328 | self.max_bad_skips = max_bad_skips 329 | 330 | def training_step(self, model, inputs, num_items_in_batch=None): 331 | try: 332 | # forward + backward + (on accumulation boundary) engine.step() 333 | return super().training_step(model, inputs, num_items_in_batch) 334 | except RuntimeError as e: 335 | msg = str(e).lower() 336 | if "expected a floating point or complex tensor as input" in msg and "got long" in msg: 337 | # 1) count & warn 338 | self.bad_skips += 1 339 | warnings.warn( 340 | f"[SkipBF16] Encountered Long‐tensor in vector_norm (batch #{self.bad_skips}), skipping.", 341 | UserWarning, 342 | ) 343 | 344 | # 2) clear any gradients to avoid leakage 345 | model.zero_grad() 346 | if hasattr(self, "optimizer") and self.optimizer is not None: 347 | self.optimizer.zero_grad() 348 | 349 | # 3) safety bail if too many 350 | if self.bad_skips >= self.max_bad_skips: 351 
361 | def train_tokenize_function(examples, tokenizer): 362 | sources = [prompt for prompt in examples["prompt"]] 363 | targets = [f"{output}{tokenizer.eos_token}" for output in examples["completion"]] 364 | return preprocess(sources, targets, tokenizer) 365 | 366 | 367 | def train(): 368 | parser = TrlParser((DataArguments, SFTConfig, ModelConfig)) 369 | data_args, training_args, model_args = parser.parse_args_and_config() 370 | 371 | if "gemma" in model_args.model_name_or_path: 372 | model = transformers.AutoModelForCausalLM.from_pretrained( 373 | model_args.model_name_or_path, 374 | torch_dtype="auto", 375 | ) 376 | else: 377 | model = transformers.AutoModelForCausalLM.from_pretrained( 378 | model_args.model_name_or_path, 379 | # attn_implementation="flash_attention_2", 380 | # torch_dtype=torch.float16, # works for first 2000 steps 381 | torch_dtype="auto", 382 | use_cache=not training_args.gradient_checkpointing, 383 | ) 384 | 385 | tokenizer = transformers.AutoTokenizer.from_pretrained( 386 | model_args.tokenizer_path, 387 | model_max_length=training_args.max_length, 388 | ) 389 | if tokenizer.pad_token is None: 390 | tokenizer.pad_token = tokenizer.eos_token 391 | tokenizer.pad_token_id = tokenizer.eos_token_id 392 | 393 | raw_train_dataset = load_dataset( 394 | "json", 395 | data_files=data_args.data_path, 396 | split="train", 397 | ) 398 | 399 | if training_args.local_rank > 0: 400 | torch.distributed.barrier() 401 | 402 | train_dataset = raw_train_dataset.map( 403 | train_tokenize_function, 404 | batched=True, 405 | batch_size=4096, 406 | num_proc=1, 407 | remove_columns=raw_train_dataset.column_names, 408 | desc="Running tokenizer on train dataset", 409 | fn_kwargs={ 410 | "tokenizer": tokenizer, 411 | } 412 | ) 413 | 414 | if training_args.local_rank == 0: 415 | torch.distributed.barrier() 416 | 417 | if training_args.local_rank == 0: 418 | print(len(train_dataset)) 419 | for index in random.sample(range(len(train_dataset)), 3): 420 | print(f"Sample {index} of the training set: {train_dataset[index]}.") 421 | 422 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 423 | trainer = SkipBadBF16BatchTrainer( 424 | model=model, 425 | processing_class=tokenizer, 426 | args=training_args, 427 | train_dataset=train_dataset, 428 | data_collator=data_collator 429 | ) 430 | 431 | trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 432 | trainer.save_model(training_args.output_dir) 433 | 434 | 435 | if __name__ == "__main__": 436 | train() 437 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **PromptCoT & PromptCoT-Mamba: Advancing the Frontiers of Reasoning** 2 | 3 | --- 4 | 5 | ## **News** 6 | 7 | * **May 30, 2025**: PromptCoT-Mamba released! Introducing an attention-free foundation model for reasoning tasks. 8 | * **Apr 11, 2025**: PromptCoT-QwQ-32B model and its training data released, achieving new state-of-the-art results. 9 | * **Mar 7, 2025**: PromptCoT project launched, including the problem generation model, distilled models (PromptCoT-DS series), and associated datasets. 
10 | 11 | --- 12 | 13 | ## **Overview** 14 | 15 | This repository unifies two synergistic projects aimed at advancing the frontiers of mathematical and code reasoning in Large Language Models (LLMs): **PromptCoT** and **PromptCoT-Mamba**. 16 | 17 | **PromptCoT (Synthesizing Olympiad-Level Problems for Mathematical Reasoning in Large Language Models)** addresses the critical challenge of acquiring high-quality, complex problems for training advanced LLMs. It introduces a novel methodology to systematically generate Olympiad-level mathematical problems by modeling the rationale behind expert problem design. This approach not only enhances problem diversity and difficulty but also ensures logical consistency in problem construction, providing a scalable solution for creating robust training datasets. 18 | 19 | **PromptCoT-Mamba (Scaling Reasoning without Attention)** leverages the problem generation capabilities of the PromptCoT pipeline to train **PromptCoT-Mamba-7B**, the first attention-free foundation model based on the Mamba-2 architecture. This model demonstrates that structured training curricula can enable attention-free models to surpass strong Transformer baselines on a wide array of competition-level math and code reasoning tasks, all while maintaining constant-memory inference without KV caching. 20 | 21 | Together, these projects offer a powerful suite of tools, models, and datasets for researchers and developers working on the cutting edge of AI reasoning. 22 | 23 | --- 24 | 25 | ## **Highlights & Key Results** 26 | 27 | ### **1. PromptCoT: Problem Generation & Distilled Models** 28 | 29 | * **✨ The Missing Piece for Test-Time Scaling**: A lightweight yet powerful problem generation model enabling the construction of prompt sets at any scale with sufficient quality, perfect for SFT or RL post-training. 30 | * **📖 A Fully Open Project**: All models (generation, distilled LLMs) and datasets (generation inputs, SFT data) are open-sourced. 31 | * **🏆 Superior Performance of Distilled Models**: 32 | * **PromptCoT-DS-7B** consistently surpasses its base model, DeepSeek-R1-Distill-Qwen-7B, with significant gains: 33 | * **+0.9%** on MATH-500 (**93.7%**) 34 | * **+3.2%** on AIME2024 (**58.7%**) 35 | * **+9.2%** on AIME2025 (**49.2%**) 36 | * **PromptCoT-DS-7B** (7B parameters) achieves results comparable to larger 32B models like S1-32B and LIMO-32B. 37 | * **PromptCoT-QwQ-32B** sets a new standard, outperforming other 32B models by a significant margin: 38 | * MATH-500: **96.7% ± 0.5%** 39 | * AIME2024: **83.8% ± 2.8%** 40 | * AIME2025: **75.4% ± 4.7%** 41 | * **PromptCoT-DS-1.5B** demonstrates competitive performance against RL-based models purely through distillation. 42 | * **⚡ Efficiency Without Compromise**: **PromptCoT-DS-1.5B** achieves **40+% AIME scores** using **over 15× fewer A100 GPU hours** compared to models like DeepScaleR-1.5B-Preview. 43 | 44 | ### **2. PromptCoT-Mamba: Attention-Free Reasoning** 45 | 46 | * 🚀 **First Attention-Free SOTA**: PromptCoT-Mamba-7B is the first attention-free model (Mamba-2 architecture) to outperform strong Transformer baselines in math and code reasoning. 47 | * 🧠 **Trained with PromptCoT Pipeline**: Utilizes a structured, two-stage curriculum with data generated by PromptCoT. 48 | * 💪 **Strong General Performance**: PromptCoT-Mamba-7B consistently outperforms 7B-scale Transformer and hybrid Mamba-Transformer baselines. 
49 | * MATH-500: **84.6%** 50 | * AIME 2024: **35.2%** 51 | * AIME 2025: **24.6%** 52 | * Livecodebench: **29.9%** 53 | * 🎯 **Math Specialization**: The math-specialized variant, **PromptCoT-Mamba-Math-7B**, further boosts math performance: 54 | * MATH-500: **88.0%** 55 | * AIME 2024: **42.9%** (+7.7% over generalist) 56 | * AIME 2025: **30.8%** (+6.2% over generalist) 57 | * ⚡ **Inference Efficiency**: Offers substantial speedups (e.g., **3.66× faster** on 24GB GPU for long sequences) and constant-memory inference, ideal for cost-sensitive or long-context workloads. 58 | 59 | --- 60 | 61 | ## **Performance Details** 62 | 63 | ### **PromptCoT Series Performance** 64 | 65 | | **Model** | **GSM8K** | **MATH-500** | **AIME2024** | **AIME2025** | 66 | |----------------------------------------------|------------------|---------------------|---------------------|---------------------| 67 | | **🔹 1.5B Models** | | | | | 68 | | DeepSeek-R1-Distill-Qwen-1.5B | - | 83.9% | 28.9% | 28.1% | 69 | | STILL-3-1.5B-preview | - | 85.5% | 39.3% | - | 70 | | DeepScaleR-1.5B-Preview | - | 🟢 **87.8%** | 🟢 **43.1%** | 🟢 **37.1%** | 71 | | **PromptCoT-DS-1.5B** (**ours**) | 🟢 **87.6% ± 0.5%** | **85.3% ± 1.1%** | **41.2% ± 6.9%** | **36.7% ± 6.2%** | 72 | | **🔹 7B Models** | | | | | 73 | | DeepSeek-R1-Distill-Qwen-7B | - | 92.8% | 55.5% | 40.0% | 74 | | Qwen2.5-7B-SimpleRL | - | 82.4% | 26.7% | - | 75 | | OpenThinker-7B | - | 89.6% | 30.0% | 33.3% | 76 | | OpenR1-Qwen-7B | - | 90.6% | 36.7% | 40.0% | 77 | | **PromptCoT-DS-7B** (**ours**) | 🔥 **92.8% ± 0.5%** | 🔥 **93.7% ± 0.7%** | 🔥 **58.7% ± 3.1%** | 🔥 **49.2% ± 7.9%** | 78 | | **🔹 32B Models** | | | | | 79 | | DeepSeek-R1-Distill-Qwen-32B | - | 94.3% | 72.6% | - | 80 | | S1-32B | - | 93.0% | 56.7% | 26.6% | 81 | | LIMO-32B | - | 94.8% | 57.1% | 46.6% | 82 | | QwQ-32B | - | - | 82.1% | 70.8% | 83 | | **PromptCoT-QwQ-32B** (**ours**) | 🔥🔥 **96.4% ± 0.2%** | 🔥🔥 **96.7% ± 0.5%** | 🔥🔥 **83.8% ± 2.8%** | 🔥🔥 **75.4% ± 4.7%** | 84 | 85 | ### **PromptCoT-Mamba Performance** 86 | 87 | **General Performance:** 88 | 89 | | Model | MATH-500 | AIME 24 | AIME 25 | OlympiadBench | HumanEval | HumanEval+ | Livecodebench | 90 | | ---------------------- | -------- | -------- | -------- | ------------- | --------- | ---------- | ------------- | 91 | | **PromptCoT-Mamba-7B** | **84.6** | 🔥🔥**35.2** | 🔥🔥**24.6** | **50.7** | **81.7** | **75.0** | 🔥🔥**29.9** | 92 | | Gemma3-27B | **89.0** | 32.6 | 24.0 | **54.2** | **86.0** | **78.0** | 26.9 | 93 | | Gemma3-12B | 83.8 | 22.9 | 19.2 | 49.9 | 81.1 | 73.2 | 22.2 | 94 | | Sky-T1-7B | 85.0 | 19.2 | 19.2 | 49.2 | 41.5 | 37.2 | 18.3 | 95 | | S1.1-7B | 82.0 | 19.2 | 17.5 | 43.1 | 64.0 | 56.7 | 13.3 | 96 | | Bespoke-Stratos-7B | 81.2 | 18.3 | 16.3 | 45.0 | 73.2 | 68.3 | 8.6 | 97 | | Nemotron-H-8B | 77.6 | -- | -- | -- | 79.3 | 74.4 | -- | 98 | | M1-3B | 81.7 | 23.0 | 22.0 | 43.6 | -- | -- | -- | 99 | 100 | **Math Specialization vs. 
Generalist:** 101 | 102 | | Model | MATH-500 | AIME 24 | AIME 25 | OlympiadBench | HumanEval | HumanEval+ | Livecodebench | 103 | | --------------------------- | -------- | -------- | -------- | ------------- | --------- | ---------- | ------------- | 104 | | **PromptCoT-Mamba-Math-7B** | 🔥🔥**88.0** | 🔥🔥**42.9** | 🔥🔥**30.8** | 🔥🔥**52.1** | 71.3 | 66.5 | 20.3 | 105 | | PromptCoT-Mamba-7B | 84.6 | 35.2 | 24.6 | 50.7 | **81.7** | **75.0** | **29.9** | 106 | 107 | --- 108 | 109 | 110 | ## **Citation** 111 | 112 | If you find **PromptCoT** or **PromptCoT-Mamba** useful in your research, please consider citing the respective papers: 113 | 114 | **For PromptCoT:** 115 | ```bibtex 116 | @article{zhao2025promptcot, 117 | author = {Zhao, Xueliang and Wu, Wei and Guan, Jian and Kong, Lingpeng}, 118 | title = {PromptCoT: Synthesizing Olympiad-Level Problems for Mathematical Reasoning in Large Language Models}, 119 | year = {2025}, 120 | journal = {arXiv preprint arXiv:2503.02324}, 121 | url = {http://arxiv.org/abs/2503.02324} 122 | } 123 | ``` 124 | 125 | **For PromptCoT-Mamba:** 126 | ```bibtex 127 | @article{zhao2025scaling, 128 | author = {Xueliang Zhao and Wei Wu and Lingpeng Kong}, 129 | title = {Scaling Reasoning without Attention}, 130 | journal = {arXiv preprint arXiv:2505.22425}, 131 | year = {2025}, 132 | url = {https://arxiv.org/abs/2505.22425} 133 | } 134 | ``` --------------------------------------------------------------------------------