├── .gitignore ├── LICENSE ├── README.md ├── big_math_front_page.png ├── reformulation ├── README.md ├── __init__.py ├── big_math_reformulation.png ├── main.py ├── modules.py ├── requirements.txt └── signatures.py └── signals ├── README.md ├── __init__.py ├── add_empty_boxed_signal.py ├── add_hyperlink_signal.py ├── add_language_signal.py ├── add_multipartq_signal.py ├── add_multiple_choice_signal.py ├── add_proof_signal.py ├── add_semdedup_signal.py ├── add_true_false_signal.py ├── add_yes_no_signal.py ├── model_based_signals.py ├── requirements.txt ├── rollouts_based_signals ├── __init__.py ├── evaluate_responses.py ├── example_solve_rate_script.sh ├── math_eval.py ├── sample_from_model.py └── utils │ ├── __init__.py │ ├── openai_server.py │ └── sglang_util.py └── semdedup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 SynthLabs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Big-Math: A Large-Scale, High-Quality Math Dataset for Reinforcement Learning in Language Models 2 | 3 |
4 | Click to collapse paper preview 5 |

6 | Big Math Front Page 7 |

8 |
9 | 10 |

11 | 12 | Static Badge 13 | 14 | 15 | Static Badge 16 | 17 |

18 | 19 |

20 | 21 | Request Early Access 22 | 23 | 24 | Private Reasoning Evals 25 | 26 |

27 | 28 | ## Introduction 29 | 30 | Welcome to the official repository for **Big-Math**, a large-scale, high-quality dataset designed specifically for RL training ([PPO](https://arxiv.org/abs/1707.06347), [GRPO](https://arxiv.org/abs/2402.03300), etc.) with large language models (LLMs). 31 | 32 | This repository provides tools for reformulating multiple-choice questions and implementing rule-based and model-based filtering as described in the [Big-Math paper](https://alon-albalak.github.io/images/Big_MATH.pdf). 33 | 34 | Find the dataset on HuggingFace at [https://huggingface.co/datasets/SynthLabsAI/Big-Math-RL-Verified](https://huggingface.co/datasets/SynthLabsAI/Big-Math-RL-Verified) 35 | 36 | 37 | > [!WARNING] 38 | > This repo is intended for research purposes, and is thus under constant development. Please expect major changes to the design. 39 | > The primary goal of the big-math repo is to share the filtering and reformulation code for creating the Big-MATH dataset and to speed the development of future datasets. 40 | > The Big-Math dataset is intended only for RL training of LLMs, it does not contain rollouts 41 | 42 | ## Repository Structure 43 | 44 | This repo consists of 2 main directories: signals and reformulation. 45 | 46 | ### Signals 47 | 48 | This folder contains code used to generate signals on a dataset. The below signals can be generated either using rule-based methods or model-based methods: 49 | 50 | | Signal | Rule-Based | Model-Based | 51 | |-----------------------|------------|-------------| 52 | | Hyperlink Detection | ✅ | | 53 | | Language Identification | | ✅ | 54 | | Semantic Duplicate | | ✅ | 55 | | Multiple Choice Question | ✅ | ✅ | 56 | | Multi-Part Question | ✅ | ✅ | 57 | | True/False Question | ✅ | ✅ | 58 | | Yes/No Question | ✅ | ✅ | 59 | | Proof Detection | ✅ | ✅ | 60 | | Model Solve Rate | | ✅ | 61 | 62 | ### Reformulation 63 | 64 | This folder contains code used to reformulate multiple choice problems to open-ended questions. 
65 | 66 | ## 🚀 Getting Started 67 | 68 | ### Prerequisites 69 | - python 3.10+ 70 | - install with packages in `signals/requirements.txt` to generate signals on a dataset 71 | - install with packages in `reformulation/requirements.txt` to reformulate multiple choice questions into open-ended questions 72 | 73 | ### Installation 74 | 1. Clone the repository: 75 | ```bash 76 | git clone https://github.com/SynthLabsAI/big-math.git 77 | cd big-math 78 | ``` 79 | 2. Install dependencies: 80 | ```bash 81 | pip install -r signals/requirements.txt -r reformulation/requirements.txt 82 | ``` 83 | 84 | ### 🛠 Usage 85 | 86 | #### Reformulation 87 | See [the reformulation readme](https://github.com/SynthLabsAI/big-math/blob/master/reformulation/README.md) for an explanation of files and usage. 88 | 89 | #### Signals 90 | See [the signals readme](https://github.com/SynthLabsAI/big-math/blob/master/signals/README.md) for an explanation of files and usage. 91 | 92 | ## 📄 Citation 93 | 94 | ```bibtex 95 | @misc{albalak2025bigmathlargescalehighqualitymath, 96 | title={Big-Math: A Large-Scale, High-Quality Math Dataset for Reinforcement Learning in Language Models}, 97 | author={Alon Albalak and Duy Phung and Nathan Lile and Rafael Rafailov and Kanishk Gandhi and Louis Castricato and Anikait Singh and Chase Blagden and Violet Xiang and Dakota Mahan and Nick Haber}, 98 | year={2025}, 99 | eprint={2502.17387}, 100 | archivePrefix={arXiv}, 101 | primaryClass={cs.LG}, 102 | url={https://arxiv.org/abs/2502.17387}, 103 | } 104 | ``` 105 | 106 | ## License 107 | 108 | This project is licensed under the MIT License. See the [LICENSE](https://github.com/SynthLabsAI/big-math/blob/master/LICENSE) for details. 
109 | -------------------------------------------------------------------------------- /big_math_front_page.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SynthLabsAI/big-math/420b9a771a7e97a85b81cbdcbd573b1b0d56f522/big_math_front_page.png -------------------------------------------------------------------------------- /reformulation/README.md: -------------------------------------------------------------------------------- 1 | # Multiple Choice Reformulation 2 | 3 | The **Multiple Choice Reformulator** is designed to transform multiple choice questions into open-ended questions using dspy. The tool is designed for research purposes. 4 | 5 | Our Reformulation Strategy 6 | 7 | 8 | ## Folder Structure 9 | 10 | - `main.py`: The main entry point for the reformulation process. 11 | - `signatures.py`: Contains the function signatures and definitions for the dspy method used within the reformulator. 12 | - `modules.py`: Houses the Reformulator class, which implements the core logic for converting MCQs to open-ended problems. 13 | - `requirements.txt`: Make sure to install the dependencies as specified here. 14 | 15 | ## Setup and Configuration 16 | 17 | ### Step 1: Implement the `custom_load_dataset()` function 18 | The entry point for running the reformulator is located in `main.py`. You need to implement the `custom_load_dataset()` function within this file to load your dataset of multiple-choice questions. This function should return the dataset in a format that the reformulator can process. 19 | 20 | Here is a simple template using HuggingFace Datasets to get you started: 21 | ```python 22 | from datasets import load_dataset 23 | def custom_load_dataset(dataset_name): 24 | dataset = load_dataset(dataset_name, split="train") 25 | return dataset 26 | ``` 27 | 28 | ### Step 2: Define the `API_CONFIGS` variable 29 | Within `main.py`, you will also need to define the `API_CONFIGS` variable. 
This variable will store configuration information related to the model API you are using. Modify this section to suit your API configuration setup. 30 | 31 | Example: 32 | ```python 33 | # Server configurations 34 | API_CONFIGS = { 35 | """Example""" 36 | "": { 37 | "api_base": "http://:/v1", 38 | "api_key": "PROVIDER_API_KEY", 39 | }, 40 | "123.456.789.012": { 41 | "api_base": "http://123.456.789.012:8000/v1", 42 | "api_key": "PROVIDER_API_KEY", 43 | }, 44 | } 45 | ``` 46 | 47 | ### Step 3: Run the Reformulator 48 | Once the dataset loading and API configurations are set up, you can execute the reformulator by running the following command: 49 | ```bash 50 | python main.py 51 | ``` 52 | -------------------------------------------------------------------------------- /reformulation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SynthLabsAI/big-math/420b9a771a7e97a85b81cbdcbd573b1b0d56f522/reformulation/__init__.py -------------------------------------------------------------------------------- /reformulation/big_math_reformulation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SynthLabsAI/big-math/420b9a771a7e97a85b81cbdcbd573b1b0d56f522/reformulation/big_math_reformulation.png -------------------------------------------------------------------------------- /reformulation/main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import math 4 | import pprint 5 | import random 6 | from datetime import datetime 7 | from pathlib import Path 8 | from typing import Any, Dict, List, Optional 9 | 10 | from dotenv import load_dotenv 11 | from loguru import logger 12 | from tenacity import retry, stop_after_attempt, wait_exponential 13 | from tqdm import tqdm 14 | 15 | import dspy 16 | 17 | load_dotenv() 18 | 19 | from modules import 
MathProblemReformulator 20 | 21 | # Global configs 22 | # MODEL = "openai/gpt-4-turbo-preview" # default model 23 | MODEL = "openai/meta-llama/Llama-3.1-8B-Instruct" 24 | CONCURRENT_TASKS_PER_API = 25 25 | 26 | # Server configurations 27 | API_CONFIGS = { 28 | """Example""" 29 | "": { 30 | "api_base": "http://:/v1", 31 | "api_key": "PROVIDER_API_KEY", 32 | }, 33 | "123.456.789.012": { 34 | "api_base": "http://123.456.789.012:8000/v1", 35 | "api_key": "PROVIDER_API_KEY", 36 | }, 37 | } 38 | 39 | 40 | def custom_load_dataset(): 41 | """ 42 | This function is a placeholder for loading a custom dataset. 43 | 44 | You need to define this function based on the specific requirements of your dataset. 45 | It should include the logic to load and return the dataset from its source. 46 | 47 | Returns: 48 | list: An empty list as a placeholder. Replace this with the actual dataset. 49 | """ 50 | """Define your own dataset loader here, depending on where the data comes from""" 51 | return [] 52 | 53 | def setup_dspy_with_endpoint(api_base: str, api_key: str, temperature: float = 0.0001) -> None: 54 | """Configure DSPy with specified endpoint and settings""" 55 | lm = dspy.LM( 56 | MODEL, 57 | api_key=api_key, 58 | api_base=api_base, 59 | temperature=round(temperature, 6), 60 | # model_type='chat' 61 | ) 62 | # Configure async worker pool size 63 | dspy.settings.configure(lm=lm, async_max_workers=CONCURRENT_TASKS_PER_API) 64 | logger.debug(f"DSPy configured with model: {MODEL} at {api_base}") 65 | 66 | def save_batch_results(batch_data: List[Dict[str, Any]], output_dir: str = "outputs") -> None: 67 | """Append batch results to a single JSONL file""" 68 | # Create output directory if it doesn't exist 69 | output_dir = Path(output_dir) 70 | output_dir.mkdir(parents=True, exist_ok=True) 71 | 72 | # Use a single JSONL file for all results 73 | output_file = output_dir / f"reformulation_results.jsonl" 74 | 75 | # Append results to file 76 | with open(output_file, "a") as f: 77 | for 
item in batch_data: 78 | f.write(json.dumps(item) + "\n") 79 | 80 | logger.debug(f"Batch results appended to {output_file}") 81 | 82 | 83 | def load_processed_problems(output_dir: str = "outputs") -> set: 84 | """Load set of already processed problem IDs that passed judge verification""" 85 | output_files = list(Path(output_dir).glob(f"reformulation_results_*.jsonl")) 86 | processed = set() 87 | 88 | if output_files: 89 | for output_file in output_files: 90 | with open(output_file) as f: 91 | for line in f: 92 | try: 93 | data = json.loads(line) 94 | # Only consider it processed if it was successful AND got a pass verdict 95 | if (data.get('success') and 96 | 'data' in data and 97 | data.get('data', {}).get('judge_verdict', '').lower() == 'pass'): 98 | processed.add(data['data']['uuid']) 99 | except: 100 | continue 101 | 102 | logger.info(f"Found {len(processed)} previously processed problems with PASS verdict") 103 | return processed 104 | 105 | 106 | def log_processing_details(problem: str, result) -> None: 107 | """Log the processing details for a single item""" 108 | # Print and log results with clear section headers 109 | logger.info("=== Original Problem ===") 110 | print("\n=== Original Problem ===") 111 | logger.debug(problem) 112 | print(problem) 113 | 114 | # Display reformulation process details 115 | key_display_map = { 116 | "core_mathematical_concept": "", 117 | "key_information_extraction": "", 118 | "problem_structure_analysis": "", 119 | "multiple_choice_removal_strategy": "", 120 | "rephrasing_approach": "", 121 | "problem_integrity_preservation": "", 122 | "answer_format_specification": "" 123 | } 124 | 125 | logger.info("=== Reformulation Process Details ===") 126 | print("\n=== Reformulation Process Details ===") 127 | for key, value in result.reformulation_process.items(): 128 | display_key = key_display_map.get(key, key.replace('_', ' ').title()) 129 | logger.debug(f"{display_key}: {value}") 130 | print(f"{display_key}:") 131 | print(f" 
{value}") 132 | 133 | logger.info("=== Reformulated Problem ===") 134 | print("\n=== Reformulated Problem ===") 135 | logger.debug(result.reformulated_problem) 136 | print(result.reformulated_problem) 137 | 138 | # Log judge results 139 | logger.info("=== Judge Evaluation ===") 140 | print("\n=== Judge Evaluation ===") 141 | print(f"Verdict: {result.judge_verdict}") 142 | 143 | if result.judge_issues: 144 | print("\nIssues Found:") 145 | for issue in result.judge_issues: 146 | print(f"- {issue}") 147 | 148 | if result.judge_suggestions: 149 | print("\nSuggestions:") 150 | for suggestion in result.judge_suggestions: 151 | print(f"- {suggestion}") 152 | 153 | if result.judge_corrected_reformulated_problem: 154 | logger.info("=== Judge's Corrected Version ===") 155 | print("\n=== Judge's Corrected Version ===") 156 | print(result.judge_corrected_reformulated_problem) 157 | 158 | # Log validation results 159 | logger.info("=== Basic Validation Results ===") 160 | print("\n=== Basic Validation Results ===") 161 | print("Checking the following criteria:") 162 | validation_descriptions = { 163 | "no_mc_options": "Multiple choice options removed", 164 | "process_complete": "All reformulation steps completed" 165 | } 166 | 167 | for check, passed in result.validation_details.items(): 168 | status = "✓ PASSED" if passed else "✗ FAILED" 169 | description = validation_descriptions.get(check, check) 170 | logger.debug(f"{status} | {description}") 171 | print(f"{status} | {description}") 172 | 173 | overall = "✓ PASSED" if result.validation_result else "✗ FAILED" 174 | logger.info(f"Overall Basic Validation: {overall}") 175 | print(f"\nOverall Basic Validation: {overall}") 176 | print("\n" + "="*50 + "\n") 177 | 178 | 179 | @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) 180 | async def process_single_item(item: Dict[str, Any], reformulator, api_base: str = None) -> Dict[str, Any]: 181 | """Process a single item and return its result""" 182 
| try: 183 | problem_text = item['problem'] 184 | if not problem_text: 185 | raise ValueError("Empty problem text") 186 | try: 187 | correct_answer = item['final_answer'] 188 | except: 189 | logger.warning("No final answer found") 190 | correct_answer = None 191 | if not correct_answer: 192 | try: 193 | correct_answer = item['answer'] 194 | except: 195 | logger.warning("No answer found") 196 | correct_answer = None 197 | if not correct_answer: 198 | try: 199 | correct_answer = item['source_solution'] 200 | except: 201 | logger.warning("No final answer found") 202 | correct_answer = None 203 | if not correct_answer: 204 | try: 205 | correct_answer = item['solution'] 206 | except: 207 | logger.warning("No solution found") 208 | pprint.pprint(item) 209 | # input("NO SOLUTION/ANSWER FOUND!!! YIKKKKEESSSS...... Press Enter to continue...") 210 | correct_answer = None 211 | print(f"Correct answer: {correct_answer}") 212 | # Add uuid if not present for single problems 213 | if 'uuid' not in item: 214 | item['uuid'] = str(datetime.now().timestamp()) 215 | 216 | temperature = round(random.uniform(0.00001, 0.001), 6) 217 | 218 | # Call reformulator asynchronously with randomized temperature 219 | with dspy.context( 220 | lm=dspy.LM( 221 | MODEL, 222 | temperature=temperature, 223 | api_base=api_base, 224 | api_key="PROVIDER_API_KEY", 225 | # model_type='chat' 226 | )): 227 | result = await reformulator(problem_text, correct_answer) 228 | 229 | processed_data = { 230 | 'reformulation_process': result.reformulation_process, 231 | 'reformulated_problem': result.reformulated_problem, 232 | 'reformulation_solution': result.reformulation_solution, 233 | 'validation_details': result.validation_details, 234 | 'validation_result': result.validation_result, 235 | 'reasoning': result.reasoning if result.reasoning else None, 236 | 'judge_verdict': result.judge_verdict, 237 | 'judge_issues': result.judge_issues, 238 | 'judge_suggestions': result.judge_suggestions, 239 | 
'judge_reasoning': result.judge_reasoning, 240 | 'judge_corrected_reformulated_problem': result.judge_corrected_reformulated_problem, 241 | 'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"), 242 | 'temperature': temperature, 243 | 'api_base': api_base, 244 | **item, 245 | } 246 | 247 | logger.info("=== Processing Results ===") 248 | log_processing_details(problem_text, result) 249 | 250 | return {"success": True, "data": processed_data} 251 | except Exception as e: 252 | logger.error(f"Failed to process problem: {e}") 253 | return {"success": False, "error": str(e), "item": item} 254 | 255 | async def process_batch_distributed(items, reformulator, api_resources: Dict) -> List[Dict[str, Any]]: 256 | """Process a batch of items distributed across multiple API endpoints""" 257 | new_items = [item for item in items if item['uuid']] 258 | 259 | if not new_items: 260 | return [] 261 | 262 | api_endpoints = list(api_resources.keys()) 263 | num_endpoints = len(api_endpoints) 264 | 265 | tasks = [] 266 | for idx, item in enumerate(new_items): 267 | api_base = api_endpoints[idx % num_endpoints] 268 | resources = api_resources[api_base] 269 | pprint.pprint(item) 270 | setup_dspy_with_endpoint( 271 | api_base=resources['api_base'], 272 | api_key=resources['api_key'], 273 | temperature=round(random.uniform(0.00001, 0.001), 6) 274 | ) 275 | 276 | task = process_single_item(item, reformulator, api_base=resources['api_base'],) 277 | tasks.append(task) 278 | 279 | results = await asyncio.gather(*tasks) 280 | 281 | successful_results = [r for r in results if r["success"]] 282 | if successful_results: 283 | save_batch_results(successful_results) 284 | 285 | return results 286 | 287 | async def main(problem: Optional[str] = None) -> None: 288 | """Main function to run math problem reformulation""" 289 | # Initialize API resources 290 | api_resources = { 291 | api_base: { 292 | 'api_base': config['api_base'], 293 | 'api_key': config['api_key'] 294 | } for api_base, config in 
API_CONFIGS.items() 295 | } 296 | 297 | # Create reformulator instance and make it async 298 | reformulator = dspy.asyncify(MathProblemReformulator()) 299 | logger.debug("MathProblemReformulator instance created") 300 | 301 | if problem is None: 302 | # Load and process dataset in parallel 303 | dataset = custom_load_dataset() 304 | if not dataset: 305 | logger.error("No dataset loaded. Exiting.") 306 | return 307 | 308 | total_problems = len(dataset) 309 | logger.info(f"Processing {total_problems} problems from dataset") 310 | 311 | print("\nSample row from dataset:") 312 | first_item = dataset[0] 313 | print(json.dumps(first_item, indent=2)) 314 | # input("\nPress Enter to continue...") 315 | 316 | batch_size = len(api_resources) * CONCURRENT_TASKS_PER_API 317 | with tqdm(total=total_problems, desc="Processing problems") as pbar: 318 | for i in range(0, total_problems, batch_size): 319 | batch = dataset[i:i + batch_size] 320 | 321 | logger.info(f"Processing batch {i//batch_size + 1} of {math.ceil(total_problems/batch_size)}") 322 | results = await process_batch_distributed(batch, reformulator, api_resources) 323 | pbar.update(len(results)) 324 | else: 325 | # Process single problem 326 | api_base = next(iter(API_CONFIGS.keys())) 327 | config = API_CONFIGS[api_base] 328 | setup_dspy_with_endpoint(config['api_base'], config['api_key']) 329 | 330 | try: 331 | # Create item dictionary matching dataset format 332 | item = { 333 | "problem": problem, 334 | "source": "manual_input" 335 | } 336 | result = await process_single_item(item, reformulator) 337 | if result["success"]: 338 | save_batch_results([result]) 339 | logger.info("Successfully saved single problem result") 340 | except Exception as e: 341 | logger.error(f"Failed to process single problem: {e}") 342 | 343 | if __name__ == "__main__": 344 | asyncio.run(main()) -------------------------------------------------------------------------------- /reformulation/modules.py: 
-------------------------------------------------------------------------------- 1 | import dspy 2 | from loguru import logger 3 | 4 | from signatures import MathReformulationSignature, ReformulationJudgeSignature 5 | 6 | 7 | class MathProblemReformulator(dspy.Module): 8 | """ 9 | Reformulates multiple-choice math problems into open-ended format. 10 | 11 | Methods 12 | ------- 13 | __init__(): 14 | Initializes the MathProblemReformulator with prediction models for reformulation and judgment. 15 | 16 | forward(problem: str, correct_answer: str) -> dspy.Prediction: 17 | Reformulates a given multiple-choice math problem and judges the reformulation. 18 | 19 | Parameters: 20 | problem (str): The original multiple-choice math problem. 21 | correct_answer (str): The correct answer to the original problem. 22 | 23 | Returns: 24 | dspy.Prediction: The prediction result containing the reformulated problem, validation details, and judgment results. 25 | 26 | _validate(reformulation_process: dict, reformulated_problem: str) -> tuple[bool, dict]: 27 | Validates the reformulation output. 28 | 29 | Parameters: 30 | reformulation_process (dict): The process details of the reformulation. 31 | reformulated_problem (str): The reformulated math problem. 32 | 33 | Returns: 34 | tuple[bool, dict]: A tuple containing a boolean indicating if the validation passed and a dictionary with validation details. 
35 | """ 36 | """Reformulates multiple-choice math problems into open-ended format""" 37 | 38 | def __init__(self): 39 | super().__init__() 40 | self.reformulate = dspy.Predict(MathReformulationSignature) 41 | self.judge = dspy.Predict(ReformulationJudgeSignature) 42 | 43 | def forward(self, problem: str, correct_answer: str): 44 | # First pass - reformulate the problem 45 | result = self.reformulate(original_problem=problem, correct_answer=correct_answer) 46 | 47 | # Skip judge if problem wasn't actually multiple choice 48 | if result.reformulated_problem.strip() == "N/A": 49 | # Return early with minimal response 50 | return dspy.Prediction( 51 | reformulation_process=result.reformulation_process, 52 | reformulated_problem="N/A", 53 | validation_result=True, # Consider it valid since it's properly identified 54 | validation_details={"not_multiple_choice": True}, 55 | reasoning=result.reasoning, 56 | reformulation_solution=result.solution, 57 | judge_corrected_reformulated_problem=None, 58 | judge_verdict=None, 59 | judge_issues=None, 60 | judge_suggestions=None, 61 | judge_reasoning=None 62 | ) 63 | 64 | # Second pass - judge the reformulation 65 | judgment = self.judge( 66 | original_problem=problem, 67 | correct_answer=correct_answer, 68 | reformulated_problem=result.reformulated_problem 69 | ) 70 | 71 | # Run basic validation 72 | is_valid, validation_details = self._validate( 73 | result.reformulation_process, 74 | result.reformulated_problem 75 | ) 76 | 77 | # Add corrected_reformulated_problem without overwriting 78 | corrected_reformulation = judgment.corrected_version if judgment.corrected_version else None 79 | 80 | return dspy.Prediction( 81 | reformulation_process=result.reformulation_process, 82 | reformulated_problem=result.reformulated_problem, 83 | validation_result=is_valid, 84 | validation_details=validation_details, 85 | reasoning=result.reasoning, 86 | reformulation_solution=result.solution, 87 | 
judge_corrected_reformulated_problem=corrected_reformulation, 88 | judge_verdict=judgment.verdict, 89 | judge_issues=judgment.issues, 90 | judge_suggestions=judgment.suggestions, 91 | judge_reasoning=judgment.reasoning 92 | ) 93 | 94 | def _validate(self, reformulation_process: dict, reformulated_problem: str) -> tuple[bool, dict]: 95 | """Validate reformulation output""" 96 | # expected keys in snake_case to match model output 97 | expected_keys = { 98 | 'core_mathematical_concept', 99 | 'key_information_extraction', 100 | 'problem_structure_analysis', 101 | 'multiple_choice_removal_strategy', 102 | 'rephrasing_approach', 103 | 'problem_integrity_preservation', 104 | 'answer_format_specification', 105 | 'is_multiple_choice' 106 | } 107 | 108 | # Validate reformulation process keys 109 | provided_keys = set(reformulation_process.keys()) 110 | missing_keys = expected_keys - provided_keys 111 | extra_keys = provided_keys - expected_keys 112 | 113 | # Log any key mismatches for debugging 114 | if missing_keys: 115 | logger.debug(f"Missing keys: {missing_keys}") 116 | if extra_keys: 117 | logger.debug(f"Unexpected extra keys: {extra_keys}") 118 | 119 | # # Check if reformulated problem contains boxed instruction 120 | # has_box_instruction = "\\boxed{" in reformulated_problem 121 | 122 | # Check if multiple choice options were removed 123 | no_mc_options = not any(x in reformulated_problem for x in ["(A)", "(B)", "(C)", "(D)", "(E)"]) 124 | 125 | process_complete = not missing_keys and not extra_keys 126 | 127 | validation_details = { 128 | # "has_box_instruction": has_box_instruction, 129 | "no_mc_options": no_mc_options, 130 | "process_complete": process_complete 131 | } 132 | 133 | return all( 134 | [ 135 | # has_box_instruction, 136 | no_mc_options, 137 | process_complete 138 | ] 139 | ), validation_details -------------------------------------------------------------------------------- /reformulation/requirements.txt: 
-------------------------------------------------------------------------------- 1 | tenacity==9.0.0 2 | dspy==2.5.40 3 | loguru==0.7.2 4 | python-dotenv==1.0.1 5 | hf-transfer==0.1.8 6 | duckdb==1.1.3 7 | httpx==0.25.2 -------------------------------------------------------------------------------- /reformulation/signatures.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import dspy 3 | 4 | 5 | class MathReformulationSignature(dspy.Signature): 6 | """You are an AI assistant specializing in transforming multiple-choice math problems into open-ended, solvable questions suitable for an automatic grading system that relies on regex pattern matching within `\boxed{}`. Your task is to reformulate the given problem while adhering to specific guidelines. 7 | 8 | **Rules:** 9 | - Never turn problems into proofs, all problems should have an answer that is machine-verifiable 10 | - Reformulated problems should have the same final answer as the original problem (just not in the multiple-choice format) 11 | - It must be reasonably clear from the problem what the student is expected to write inside of \boxed{final_answer}. E.g. "What are its key properties and characteristics?" is almost always a bad reformulation because it is ambiguous what the student should write inside of the `\boxed{}`. "Express your answer in centimeters." is a good answer format specification because it does NOT give away the answer but does specify a clear format. 12 | - Do NOT reformulate problems that are not actually multiple choice or are impossible to reformulate 13 | - If a problem is not multiple choice, return "N/A" for the reformulated problem and other "N/A" fields 14 | 15 | **Follow these steps:** 16 | 17 | 1. **Analyze the original problem:** 18 | - Identify the core mathematical concept. 19 | - Note any crucial information, including numerical values, equations, and key terms. 
20 | - Determine the structure of the problem (e.g., scenario-based, direct question, multi-step). 21 | - Preserve all mathematical notations, symbols, and formatting as they appear in the original problem. 22 | 23 | 2. **Remove multiple-choice options:** 24 | - Eliminate all answer choices and their labels (A, B, C, etc.). 25 | - If any options contain information essential to solving the problem, incorporate that information into the main problem statement. 26 | 27 | 3. **Rephrase the question:** 28 | - Transform the question into an open-ended format that requires a direct numerical or algebraic answer. 29 | - Ensure the rephrased question is clear, unambiguous, and uses language appropriate for the student's level. 30 | 31 | 4. **Maintain problem integrity:** 32 | - Keep all original numerical values, equations, figures, and key terms intact. 33 | - Preserve any scenarios, dialogues, or conditional information crucial for solving the problem. 34 | - Do not introduce new information, alter units of measurement, or change the mathematical intent of the problem. 35 | - If the problem references diagrams or figures, ensure that any necessary descriptions are included. 36 | 37 | 5. **Specify the answer format:** 38 | - Instruct the student to provide their answer using the `\boxed{}` format. 39 | - Do not include placeholders like `\boxed{N}`; instead, guide the student to input their calculated answer within the boxed format. 40 | - For example, "Provide your answer in the form `\boxed{\text{your answer here}}`." 41 | 42 | 6. **Final check:** 43 | - Ensure the reformulated problem contains all necessary information for independent solving. 44 | - Verify that the problem hasn't become easier or harder than the original. 45 | - Check for any common errors, such as unit inconsistencies or typographical mistakes. 46 | - Confirm that no hints, solution methods, or additional explanations have been inadvertently added. 47 | 48 | 7. 
**Is actually multiple choice:** 49 | - Some problems are not actually multiple choice and do NOT actually need to be reformulated. 50 | - If the problem is **NOT** multiple choice, do NOT reformulate it! Note here if it is not multiple choice and return "N/A" for the reformulated problem. 51 | 52 | **Before providing the final reformulated problem, please create a `reformulation_process` dictionary (for internal use; do not include this dictionary or its content in the final problem). The dictionary should have the following exact keys:** 53 | 54 | - `"core_mathematical_concept"`: Summarize the core mathematical concept. 55 | - `"key_information_extraction"`: List key information (numerical values, equations, terms). 56 | - `"problem_structure_analysis"`: Identify the problem structure. 57 | - `"multiple_choice_removal_strategy"`: Plan how to remove multiple-choice options. 58 | - `"rephrasing_approach"`: Outline the rephrasing strategy. 59 | - `"problem_integrity_preservation"`: Note how to maintain problem integrity. 60 | - `"answer_format_specification"`: Specify the answer format. 61 | - `"is_multiple_choice"`: Whether the problem is actually multiple choice. 62 | 63 | **This structured approach will help ensure accuracy and completeness in your final output.** 64 | 65 | **After your reformulation process, present the reformulated problem as plain text, making sure it meets all the criteria outlined above. Do not include the `` tags or any of their content in the final problem presented to the student.** 66 | 67 | **Examples:** 68 | ``` 69 | 70 | 71 | Three squares with the same center and corresponding parallel sides are drawn. The distance between the sides of successive squares is 3 units, and the side length of the largest square is 22 units. What is the perimeter of the smallest square? 
72 | (A) 40 73 | (B) 64 74 | (C) 100 75 | (D) 10 76 | (E) 20 77 | 78 | 79 | 80 | { 81 | "core_mathematical_concept": "Geometry - concentric squares and their dimensions", 82 | "key_information_extraction": [ 83 | "Three concentric squares", 84 | "Distance between sides of successive squares: 3 units", 85 | "Side length of the largest square: 22 units", 86 | "Need to find the perimeter of the smallest square" 87 | ], 88 | "problem_structure_analysis": "Direct question based on given information", 89 | "multiple_choice_removal_strategy": [ 90 | "Eliminate all answer choices and their labels", 91 | "No essential information is lost by removing the options" 92 | ], 93 | "rephrasing_approach": [ 94 | "Keep the description of the squares intact", 95 | "Rephrase the question to ask for the perimeter in an open-ended format", 96 | "Specify units and number format requirements" 97 | ], 98 | "problem_integrity_preservation": [ 99 | "Preserve all original numerical values and key terms", 100 | "Do not add any new information or hints" 101 | ], 102 | "answer_format_specification": [ 103 | "Answer should be expressed in units", 104 | "Answer should be given as a whole number", 105 | "No decimal places required based on original options" 106 | ], 107 | "is_multiple_choice": true 108 | } 109 | 110 | 111 | 112 | Let me think about this step by step: 113 | 114 | 1. Analyzing the multiple choice aspect: 115 | - The options are all whole numbers: 40, 64, 100, 10, 20 116 | - No units are specified in the options but units are used in the problem 117 | - The options suggest an exact answer is expected, not an approximation 118 | 119 | 2. Looking at the essential information: 120 | - The geometric relationships are clearly stated 121 | - All measurements are given in units 122 | - No complex transformations or unit conversions needed 123 | 124 | 3. 
What needs to be preserved: 125 | - The concentric squares description 126 | - The 3-unit distance between squares 127 | - The 22-unit largest square side length 128 | - The focus on perimeter rather than area 129 | 130 | 4. What needs to be clarified: 131 | - The answer should include units 132 | - A whole number is expected 133 | - No decimal places needed 134 | 135 | 5. How to make it open-ended: 136 | - Remove all multiple choice options 137 | - Keep the geometric scenario exactly as is 138 | - Add specific answer format requirements 139 | - Maintain the same level of difficulty 140 | 141 | 6. Final considerations: 142 | - The reformulated question should be solvable without reference to the original options 143 | - The specifications should ensure an unambiguous answer format 144 | - The mathematical challenge should remain unchanged 145 | 146 | 147 | 148 | Three squares with the same center and corresponding parallel sides are drawn. The distance between the sides of each successive square is 3 units, and the side length of the largest square is 22 units. Find the perimeter of the smallest square. Express your answer in units as a whole number. 149 | 150 | 151 | ``` 152 | 153 | ``` 154 | 155 | 156 | The teacher picked 9 out of 13 cards labeled with numbers 1 to 13 and stuck them on the foreheads of 9 students. Each student can see the numbers on the other 8 students' foreheads but not their own. The teacher asked, "Those who know the number of divisors of their own number, please raise your hand." Two students raised their hands. After they put their hands down, three students had the following conversation. 157 | - A: "I know my number now." 158 | - B: "I still don't know what my number is, but I have figured out whether it is odd or even." 159 | - C: "My number is 2 less than B's number and 1 more than A's number." 160 | What is the sum of the numbers on the four cards that were not picked? 
161 | 162 | 163 | 164 | { 165 | "core_mathematical_concept": "Number theory and logical deduction with combinatorics", 166 | "key_information_extraction": [ 167 | "13 cards numbered from 1 to 13", 168 | "9 cards picked and assigned to students", 169 | "Students can see others' numbers but not their own", 170 | "Two students know their number based on number of divisors", 171 | "Conversation between students A, B, and C revealing relationships", 172 | "Need to find sum of the 4 unpicked cards" 173 | ], 174 | "problem_structure_analysis": "Complex logical deduction problem combining number theory with information from multiple stages", 175 | "multiple_choice_removal_strategy": "Not applicable as this is not a multiple choice question", 176 | "rephrasing_approach": "Not needed as problem is already in appropriate format", 177 | "problem_integrity_preservation": "No modifications needed", 178 | "answer_format_specification": "Answer should be expressed as a single integer (sum of the unpicked cards)", 179 | "is_multiple_choice": false 180 | } 181 | 182 | 183 | 184 | Let me think about this step by step: 185 | 186 | 1. First, assess if reformulation is needed: 187 | - This is not a multiple choice question 188 | - It already asks for a specific numerical answer 189 | - The format is clear and unambiguous 190 | 191 | 2. Analyzing the problem structure: 192 | - It's a complex logic puzzle 193 | - All information provided is necessary 194 | - The sequence of events matters 195 | - The dialogue provides crucial information 196 | 197 | 3. Consider what might need clarification: 198 | - Answer format is implicitly clear (a sum) 199 | - No unit specification needed 200 | - No decimal places involved (working with integers 1-13) 201 | - No ambiguity in what constitutes an answer 202 | 203 | 4. 
Check for potential improvements: 204 | - Current format is optimal for the problem type 205 | - Adding structure might make it harder to follow 206 | - The narrative flow is important for understanding 207 | 208 | 5. Final decision: 209 | - No reformulation needed 210 | - Problem is already in ideal open-ended format 211 | - Answer requirements are clear (single integer sum) 212 | 213 | 214 | 215 | N/A 216 | 217 | 218 | ``` 219 | 220 | ``` 221 | 222 | 223 | In the cube shown, \( L \) and \( K \) are midpoints of adjacent edges \( AD \) and \( AB \), respectively. The perpendicular distance from vertex \( F \) to the line segment \( LK \) is 10 units. Which of the following measurements is closest to the volume of the cube after subtracting 3 units? 224 | 225 | (A) 200 in³ 226 | (B) 500 cm³ 227 | (C) 0.01039 m³ 228 | (D) Tie 229 | 230 | 231 | 232 | { 233 | "core_mathematical_concept": "3D geometry involving cube volume calculation", 234 | 235 | "key_information_extraction": [ 236 | "Points L and K are midpoints of adjacent edges AD and AB", 237 | "Perpendicular distance from F to LK is 10 units", 238 | "Need to subtract 3 units from final volume" 239 | ], 240 | 241 | "problem_structure_analysis": "Geometry problem requiring volume calculation", 242 | 243 | "multiple_choice_removal_strategy": [ 244 | "Remove all comparisons", 245 | "Ask for direct volume calculation" 246 | ], 247 | 248 | "rephrasing_approach": [ 249 | "Keep geometric relationships intact", 250 | "Request exact volume calculation", 251 | "Specify units and precision" 252 | ], 253 | 254 | "problem_integrity_preservation": [ 255 | "Maintain all numerical values and geometric relationships", 256 | "Remove comparative aspect entirely" 257 | ], 258 | 259 | "answer_format_specification": [ 260 | "Answer should be expressed in cubic meters (m³)", 261 | "Round to 5 decimal places", 262 | "Include units in answer" 263 | ], 264 | 265 | "is_multiple_choice": true 266 | } 267 | 268 | 269 | Let me think about 
this step by step: 270 | 271 | 1. First, I need to identify what makes this multiple choice: 272 | 273 | - The core task is actually calculating a specific volume 274 | - The multiple choice aspect is just about comparing to different unit measurements 275 | - The comparison isn't fundamental to the mathematical concept being tested 276 | 277 | 2. Looking at the original answer choices: 278 | 279 | - They use different units (in³, cm³, m³) 280 | - The m³ value has 5 decimal places (0.01039) 281 | - This suggests m³ with 5 decimal places is an appropriate precision level 282 | 283 | 3. How to remove the multiple choice aspect: 284 | 285 | - Instead of asking which measurement is closest 286 | - Ask for the exact calculation 287 | - Need to specify one consistent unit to avoid ambiguity 288 | 289 | 4. What to preserve: 290 | 291 | - All geometric relationships 292 | - The 3-unit subtraction 293 | - The original scenario and values 294 | 295 | 5. What needs to be added: 296 | 297 | - Specific unit requirement (m³) 298 | - Decimal place requirement (5 places) 299 | - Clear instruction about including units 300 | 301 | 6. The reformulated question should: 302 | 303 | - Be completely open-ended 304 | - Have a unique, calculable answer 305 | - Remove all traces of comparison 306 | - Maintain the same difficulty level 307 | 308 | 309 | 310 | In the cube shown, \( L \) and \( K \) are midpoints of adjacent edges \( AD \) and \( AB \), respectively. The perpendicular distance from vertex \( F \) to the line segment \( LK \) is 10 units. Calculate the volume of the cube after subtracting 3 cubic units. Express your answer in cubic meters (m³), rounded to 5 decimal places. 311 | 312 | 313 | ``` 314 | """ 315 | 316 | original_problem: str = dspy.InputField( 317 | desc="""The original multiple-choice math problem text, including all options and formatting.""" 318 | ) 319 | 320 | correct_answer: str = dspy.InputField( 321 | desc="The correct answer to the original problem." 
322 | ) 323 | 324 | reformulation_process: dict = dspy.OutputField( 325 | desc="""Structured analysis of the problem reformulation process containing the following exact keys: 326 | { 327 | "core_mathematical_concept": "Identify & summarize state the main mathematical concept", 328 | "key_information_extraction": "List of numerical values, equations, and key terms", 329 | "problem_structure_analysis": "Description of problem type and structure", 330 | "multiple_choice_removal_strategy": "How to handle removing MC options", 331 | "rephrasing_approach": "Strategy for rephrasing as open-ended", 332 | "problem_integrity_preservation": "How to maintain original difficulty and context", 333 | "answer_format_specification": "How the answer should be formatted with \\boxed{}" 334 | }""", 335 | prefix="\n\n{" 336 | ) 337 | 338 | reasoning: str = dspy.OutputField( 339 | desc="""Think step by step about how to reformulate the problem while adhering to the guidelines provided. Or, if the problem is NOT a multiple choice problem, justify why it is not.""", 340 | prefix="\n\n" 341 | ) 342 | 343 | reformulated_problem: str = dspy.OutputField( 344 | desc=r"""The reformulated open-ended problem that: 345 | - Preserves all mathematical notations and symbols 346 | - Maintains the original difficulty level 347 | - Specifies, and if necessary has clear instructions, for the expected answer format. The student should not be confused about the type of answer that should be presented. (E.g. round your answer to the nearest hundredth, or express your answer as a percentage, list your answer as a coordinate, etc.). If the answer cannot be easily expressed in boxed latex format (for auto-grading), then the problem was not well formulated! 
348 | - Contains no multiple-choice options 349 | - Does not add any hints or solution methods 350 | - Retains all necessary context to solve the problem 351 | (If the problem is not a multiple choice problem or cannot be reformulated, then just return N/A)""", 352 | prefix="\n\n" 353 | ) 354 | 355 | solution: str = dspy.OutputField( 356 | desc="""The solution to the reformulated problem. What the student should write inside of the `\boxed{}`.""", 357 | prefix="\n\n" 358 | ) 359 | 360 | 361 | class ReformulationJudgeSignature(dspy.Signature): 362 | """You are a IMO judge, specialized in evaluating the quality of IMO math problem reformulations. Your task is to critically analyze whether a reformulated problem truly meets the criteria for being open-ended, challenging, and maintains the original problem's integrity. 363 | 364 | **Important:** 365 | - Problems you review are being transformed from multiple choice problems into open-ended, solvable questions suitable for an automatic grading system that relies on regex pattern matching within `\boxed{}`. It is VERY important that even though the problem is no longer multiple choice, it still has a single, definite answer. 366 | - It must be reasonably clear from the problem what the student is expected to write inside of \boxed{final_answer}. E.g. "What are its key properties and characteristics?" is almost always a bad reformulation because it is ambiguous what the student should write inside of the `\boxed{}`. "Express your answer in centimeters." is a good answer format specification because it does NOT give away the answer but does specify a clear format. 367 | 368 | Key aspects to evaluate: 369 | 370 | 1. Hidden Multiple Choice 371 | - Check if the reformulation still effectively presents multiple choice options by: 372 | - Embedding a limited/fixed set of choices (e.g. 
4 or 5 options) within the problem text that limits the answer choices to one of those options 373 | - Asking to compare with specific values 374 | - Limiting answers to specific options 375 | - Flag any reformulations that are just disguised multiple choice questions 376 | 377 | Example: 378 | 379 | 380 | For real numbers $t \neq 0,$ the point \[(x,y) = \left( \frac{t + 1}{t}, \frac{t - 1}{t} \right)\]is plotted. All the plotted points lie on what kind of curve? (A) Line (B) Circle (C) Parabola (D) Ellipse (E) Hyperbola Enter the letter of the correct option. 381 | 382 | 383 | Reformulated failed example (all multiple choice aspect still present): 384 | 385 | For real numbers $t \neq 0,$ the point \[(x,y) = \left( \frac{t + 1}{t}, \frac{t - 1}{t} \right)\]is plotted. What type of curve do all the plotted points lie on? Provide your answer as a specific curve type (e.g., line, circle, parabola, ellipse, hyperbola). 386 | 387 | 388 | Reformulated successful example (multiple choice aspect removed): 389 | 390 | For real numbers \( t \neq 0 \), the point 391 | \[ 392 | (x, y) = \left( \frac{t + 1}{t}, \frac{t - 1}{t} \right) 393 | \] 394 | is plotted. Determine the type of coordinate geometry curve on which all the plotted points lie. 395 | 396 | 397 | 2. Mathematical Integrity 398 | - Verify that the mathematical difficulty remains unchanged 399 | - Ensure no accidental hints or simplifications were introduced 400 | - Check that all necessary information was preserved 401 | - Confirm no extraneous information was added 402 | 403 | Example 1: 404 | This problem was not actually multiple choice, but it does _imply_ options that are clearly missing/omitted. 405 | 406 | A resident wants to renovate their house and buys several strips of wood, each with a length of 0.7 meters and 0.8 meters. By connecting some of these wood strips, many different lengths of wood can be obtained. For example, \(0.7 + 0.7 = 1.4\) meters, \(0.7 + 0.8 = 1.5\) meters, etc. 
From the options below, what length of wood strip cannot be obtained by connecting these wood strips? 407 | 408 | 409 | 410 | Example 2: 411 | 412 | $\triangle ABC$ is inscribed in a semicircle of radius $r$ so that its base $AB$ coincides with diameter $AB$. Point $C$ does not coincide with either $A$ or $B$. Let $s=AC+BC$. Then, for all permissible positions of $C$: $\textbf{(A)}\ s^2\le8r^2\qquad \textbf{(B)}\ s^2=8r^2 \qquad \textbf{(C)}\ s^2 \ge 8r^2 \qquad\\ \textbf{(D)}\ s^2\le4r^2 \qquad \textbf{(E)}\ s^2=4r^2$ 413 | 414 | 415 | The reformulation is flawed because it prematurely focuses on a single inequality (\(s^2 \leq 8r^2\)) rather than inviting exploration of the entire range of \(s^2\). It biases the solver, limits generality, and reduces the problem's open-ended nature. 416 | 417 | $\triangle ABC$ is inscribed in a semicircle of radius $r$ so that its base $AB$ coincides with diameter $AB$. Point $C$ does not coincide with either $A$ or $B$. Let $s=AC+BC$. Prove or disprove the inequality $s^2\le8r^2$ for all permissible positions of $C$. 418 | 419 | 420 | An open-ended revision that requires the same analysis and leads to the same conclusion 421 | 422 | Let triangle \( ABC \) be inscribed in a semicircle of radius \( r \), with its base \( AB \) coinciding with the diameter \( AB \). Point \( C \) lies on the semicircle but does not coincide with \( A \) or \( B \). Let \( s = AC + BC \). Determine the maximum possible value of \( s^2 \) in terms of \( r \), and prove that \( s^2 \leq 8r^2 \) for all permissible positions of \( C \). 423 | 424 | 425 | Example 3: 426 | 427 | For real numbers $t,$ the point \[(x,y) = \left( \frac{1 - t^2}{1 + t^2}, \frac{2t}{1 + t^2} \right)\]is plotted. All the plotted points lie on what kind of curve? (A) Line (B) Circle (C) Parabola (D) Ellipse (E) Hyperbola Enter the letter of the correct option. 428 | 429 | 430 | Although it removes the multiple-choice format, it does not truly open the problem to exploration. 
The answer asks for a description, which introduces too much variability in responses 431 | 432 | For real numbers $t,$ the point \[(x,y) = \left( \frac{1 - t^2}{1 + t^2}, \frac{2t}{1 + t^2} \right)\]is plotted. What kind of curve do all the plotted points lie on? Provide a brief description of the curve. 433 | 434 | 435 | This reformulation is clear, concise, and ensures consistent answers while remaining open-ended for exploration. 436 | 437 | For real numbers \( t \), consider the point 438 | 439 | \[ 440 | (x, y) = \left( \frac{1 - t^2}{1 + t^2},\ \frac{2t}{1 + t^2} \right). 441 | \] 442 | 443 | Determine the type of curve on which all such points \((x, y)\) lie. State your answer as the name of the curve. 444 | 445 | 446 | Example 4: 447 | 448 | 449 | Define * as an operation on ordered pairs of real numbers, such that $(a, b) *(c, d)=$ $(a c+b d, a d+b c)$. If $(a, b) *(x, y)=(a, b)$, then which of the following is $(x, y)$? (A) $(0,0)$. (B) $(0,1)$. (C) $(1,0)$. (D) $(1,1)$. (E) $(1,-1)$. 450 | 451 | 452 | Ambiguous about whether the equation holds for all \((a, b)\) or specific values, lacks context to ensure a unique solution, and removes guidance provided by the multiple-choice format. Becomes unbounded. 453 | 454 | Define * as an operation on ordered pairs of real numbers, such that $(a, b) *(c, d)=$ $(a c+b d, a d+b c)$. If $(a, b) *(x, y)=(a, b)$, then calculate the ordered pair $(x, y)$. Express your answer in the format $(x, y)$, including parentheses and a comma. 455 | 456 | 457 | Ideal because it clarifies the scope (applies to all \((a, b)\)), ensures a unique solution, and balances open-ended exploration with clear grading criteria 458 | 459 | Define \( * \) as an operation on ordered pairs of real numbers, such that 460 | \[ 461 | (a, b) * (c, d) = (ac + bd, ad + bc). 462 | \] 463 | Find the ordered pair \((x, y)\) that satisfies 464 | \[ 465 | (a, b) * (x, y) = (a, b) 466 | \] 467 | for **all** real numbers \(a\) and \(b\). 
Express your answer as \((x, y)\). 468 | 469 | 470 | 3. Answer Format Clarity 471 | - Evaluate if the answer format specification is: 472 | - Clear and unambiguous 473 | - Appropriate for the mathematical concept 474 | - Not overly unbounded or restrictive in a way that creates an open-ended problem out of it 475 | - Check if unit/precision requirements make sense for the problem 476 | 477 | 4. Problem Independence 478 | - Verify the reformulated problem can stand alone 479 | - Ensure it doesn't rely on knowledge of the original options 480 | - Check that answer requirements aren't derived solely from original choices 481 | 482 | Example: 483 | 484 | Which of the following is a root of the equation \( x^2 - x - 6 = 0 \)? 485 | (A) \( -3 \) (B) \( -2 \) (C) \( 2 \) (D) \( 3 \) (E) \( 6 \) 486 | 487 | 488 | Relies on original options, failing to stand independently and limiting exploration. 489 | 490 | Find a root of the equation \( x^2 - x - 6 = 0 \). Your answer must be one of the following: \( -3, -2, 2, 3, 6 \). 491 | 492 | 493 | Ideal because it ensures the solver identifies all roots without being constrained by the original options. 494 | 495 | Solve the quadratic equation \( x^2 - x - 6 = 0 \) and find all real roots. Provide your answers in increasing order. 496 | 497 | 498 | Remember: A truly open-ended reformulation should allow for calculation and expression of the answer 499 | without any reference to or knowledge of the original multiple choice options.""" 500 | 501 | original_problem: str = dspy.InputField( 502 | desc="The original multiple-choice math problem text, including all options and formatting." 503 | ) 504 | correct_answer: str = dspy.InputField( 505 | desc="The correct answer to the original problem." 506 | ) 507 | 508 | reformulated_problem: str = dspy.InputField( 509 | desc="The reformulated open-ended version of the problem that needs to be evaluated." 
510 | ) 511 | 512 | issues: list[str] = dspy.OutputField( 513 | desc="List of specific issues found with the reformulation, if any.", 514 | prefix="\n" 515 | ) 516 | 517 | suggestions: list[str] = dspy.OutputField( 518 | desc="Specific recommendations for improving the reformulation.", 519 | prefix="\n" 520 | ) 521 | 522 | reasoning: str = dspy.OutputField( 523 | desc="Step-by-step explanation of the evaluation process and rationale for the verdict.", 524 | prefix="\n" 525 | ) 526 | 527 | verdict: str = dspy.OutputField( 528 | desc="Either 'PASS' or 'FAIL' based on comprehensive evaluation of the reformulation.", 529 | prefix="\n" 530 | ) 531 | 532 | corrected_version: Optional[str] = dspy.OutputField( 533 | desc="An improved version of the reformulation if issues were found. Should be None if verdict is PASS.", 534 | prefix="\n" 535 | ) -------------------------------------------------------------------------------- /signals/README.md: -------------------------------------------------------------------------------- 1 | # Signals 2 | 3 | The `signals` folder contains scripts and resources for implementing rule-based and model-based filters for dataset curation, as described in the [Big-Math paper](https://alon-albalak.github.io/images/Big_MATH.pdf). These filtering techniques are critical for ensuring the quality, diversity, and relevance of the dataset for reinforcement learning. 4 | 5 | --- 6 | 7 | ## 📂 Contents 8 | 9 | ### Rule-based filters 10 | - `add_empty_boxed_signal.py` (only useful for extracting answers from full solutions (e.g. 
from [NuminaMath](https://huggingface.co/datasets/AI-MO/NuminaMath-CoT))) 11 | - `add_hyperlink_signal.py` 12 | - `add_language_signal.py` 13 | - `add_multipartq_signal.py` 14 | - `add_multiple_choice_signal.py` 15 | - `add_proof_signal.py` 16 | - `add_true_false_signal.py` 17 | - `add_yes_no_signal.py` 18 | 19 | ### Deduplication 20 | - `add_semdedup_signal.py` 21 | 22 | ### Model-based signals 23 | - `model_based_signals.py` 24 | 25 | ### Solve rate 26 | - `rollouts_based_signals/example_solve_rate_script.sh` 27 | 28 | 29 | ## 🚀 Getting Started 30 | 31 | ### Prerequisites 32 | - Python3.10+ 33 | - Install dependencies: 34 | ```bash 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | ## 🛠 Usage 39 | 40 | ### Rule-based signals 41 | Rule-based signals are written in separate files, but share the same use pattern. Each signal requires only a single input variable, `dataset_path`. The script will automatically update the source dataset with a new column. 42 | 43 | This is a simple example that will detect whether the `problem` or `solution` columns have a hyperlink, adding 2 new columns, `has_hyperlink_problem` and `has_hyperlink_solution`: 44 | ```bash 45 | DATASET_PATH= 46 | python3 add_hyperlink_signal.py --dataset_path ${DATASET_PATH} 47 | ``` 48 | 49 | ### Deduplication 50 | We use SemDeDup to mark duplicated problems. Just like the rule-based signals, this requires only a single input variable, `dataset_path`. The script, by default, uses a similarity threshold of 0.5, but that can easily be adjusted by adding/removing values from the `EPSILONS` variable. 51 | ```bash 52 | DATASET_PATH= 53 | python3 add_semdedup_signal.py --dataset_path ${DATASET_PATH} 54 | ``` 55 | 56 | ### Model-based signals 57 | The model-based signals are all included in a single file for efficiency when running multiple filters. Our scripts make use of [SGLang](https://github.com/sgl-project/sglang) for fast inference. 
58 | 59 | Model-based signals require a bit more specification than rule-based signals. Here, we need to specify: 60 | - `model_name`: The name of a local or HF model 61 | - `output_model_name`: The name of model you want to use in the signal's column name. For example, the below code uses `--output_model_name "llama3-70b"` and will add the multiple choice signal in a new column: "multiple_choice_llama3-70b" 62 | - `dataset_name`: A path to HF or local dataset 63 | - `dataset_name_outputs`: Optional - The path to save the updated dataset to. Will save to `dataset_name` when not specified 64 | - `save_folder`: Local directory to save intermediate outputs to 65 | - `save_name`: The name of the file to save intermediate outputs to 66 | - `tp`: tensor parallelism 67 | - Signal options: `--multiple_choice`, `--proof`, `--yes_no`, `--true_false`, `--multiple_part` 68 | 69 | The below example runs all model-based signals using Llama-3.1-70B: 70 | ```bash 71 | OUTPUT_DATASET_NAME= 72 | 73 | python3 model_based_filters.py \ 74 | --model_name "meta-llama/Meta-Llama-3.1-70B-Instruct"\ 75 | --output_model_name "llama3-70b"\ 76 | --dataset_name "SynthLabsAI/Big-Math-RL-Verified"\ 77 | --dataset_name_outputs "${OUTPUT_DATASET_NAME}"\ 78 | --save_folder "signal_outputs"\ 79 | --save_name "model_based_signals"\ 80 | --tp 8\ 81 | --multiple_choice\ 82 | --proof\ 83 | --yes_no\ 84 | --true_false\ 85 | --multiple_part\ 86 | > model_based_signal_outputs.txt 2>&1 87 | ``` 88 | 89 | ### Solve rate 90 | See `rollouts_based_signals/example_solve_rate_script.sh` for example usage. 
def is_boxed_empty(row, patterns):
    """Flag rows whose solution contains an empty \\boxed{} expression.

    Args:
        row: Dataset row with a 'solution' text field.
        patterns: Iterable of regex patterns that match empty boxed answers.

    Returns:
        The row, with a boolean 'is_boxed_empty' column added.
    """
    solution = row['solution']
    row['is_boxed_empty'] = any(re.search(pattern, solution) for pattern in patterns)
    return row
def detect_latex_hyperlinks(row):
    """Flag whether the problem/solution text contains hyperlinks.

    Looks for LaTeX \\url{} / \\href{}{} commands as well as bare
    http:// and https:// URLs, recording the result in the boolean
    columns 'has_hyperlink_problem' and 'has_hyperlink_solution'.
    """
    # Same patterns applied to both text columns.
    link_patterns = (r'\\url', r'\\href', r'http://', r'https://')

    for field in ('problem', 'solution'):
        text = row[field]
        row[f'has_hyperlink_{field}'] = any(
            re.search(pattern, text) for pattern in link_patterns
        )

    return row
def remove_special_chars(text):
    """Strip digits, punctuation, and math symbols, then collapse whitespace.

    Used to clean problem text before language identification so symbol-heavy
    math notation does not skew the fastText prediction.
    """
    # Remove numbers, special characters, and math symbols
    pattern = r'[0-9*,\.+\-=\(\)\/\^\[\]{}|<>~`!@#$%&?_]'
    text = re.sub(pattern, '', text)

    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text

def remove_latex_commands(text):
    """Best-effort removal of LaTeX markup before language detection.

    Strips \\frac{..}{..}, single-argument commands, standalone commands,
    and inline/display math environments, then collapses whitespace.
    """
    # \frac{}{} first: the generic one-argument pattern below would only
    # consume the first brace group and leave the second behind.
    text = re.sub(r'\\frac\{[^{}]*\}\{[^{}]*\}', '', text)

    # Remove commands with arguments, e.g. \textbf{...}
    text = re.sub(r'\\[a-zA-Z]+\{[^{}]*\}', '', text)

    # Remove standalone commands, e.g. \alpha
    text = re.sub(r'\\[a-zA-Z]+', '', text)

    # Remove math environments. re.DOTALL is required so display/inline math
    # spanning multiple lines is removed too — callers pass raw text that
    # still contains newlines at this point.
    text = re.sub(r'\$\$(.*?)\$\$', '', text, flags=re.DOTALL)
    text = re.sub(r'\$(.*?)\$', '', text, flags=re.DOTALL)

    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text
def has_multi_part_q(row):
    """Flag problems containing several enumerated sub-questions.

    Sets 'is_multi_part_q_regex' to True when the problem shows at least two
    enumerated parts in any common labelling style: Roman numerals in
    parentheses, "(1) ... (2)", "1. ... 2.", circled digits, or mixed forms.

    Bug fixed: the original lowercased the problem but kept upper-case
    Roman-numeral character classes (I, Ⅰ, Ⅱ), so those patterns could never
    match. We now search the raw text with re.IGNORECASE instead (digit-only
    patterns are unaffected by the flag).
    """
    problem = row['problem']
    patterns = [
        r'\([IⅠ\s,\.]\).*\([IⅡ\s,\.]\)',      # Roman numerals
        r'\([1-9][^\)]*\).*\([1-9][^\)]*\)',   # Numbered parts in parentheses
        r'(?:\d+\.).*(?:\d+\.)',               # Numbered parts with period
        r'\(I+\).*\(I+\)',                     # Traditional Roman numerals
        r'(?:(?:^|\s)(?:1[.\)]|[Ii][.\)]|①).*(?:\s|$)(?:2[.\)]|[Ii]{2}[.\)]|②))',  # Mixed types
    ]
    row['is_multi_part_q_regex'] = any(
        re.search(pattern, problem, re.IGNORECASE) for pattern in patterns
    )
    return row

def main(dataset_path):
    """Run multi-part-question detection over the train split and push back."""
    # load the dataset
    dataset = load_dataset(dataset_path, split="train")

    # run multi-part question detection over the full dataset
    dataset = dataset.map(has_multi_part_q, num_proc=mp.cpu_count())

    # push the updated dataset
    dataset.push_to_hub(dataset_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Detect multi-part questions in a dataset.")
    parser.add_argument('dataset_path', type=str, help='Path to the dataset')
    args = parser.parse_args()
    main(args.dataset_path)
def is_multiple_choice(question, answer_options):
    """Heuristically decide whether *question* lists *answer_options*.

    True when every option occurs in the question and their last occurrences
    appear in the options' natural order (e.g. A before B before C before D).
    """
    # A literal run of all options (e.g. "ABCD") is question text, not an
    # option list — strip it before searching.
    stripped = question.replace("".join(answer_options), "")

    # Last occurrence of each option, scanned from the final option backwards.
    positions = [stripped.rfind(opt) for opt in reversed(answer_options)]

    # Every option must be present.
    if -1 in positions:
        return False

    # Over the reversed options the positions must be non-increasing, i.e.
    # the options appear in their natural order within the question.
    return all(later >= earlier for later, earlier in zip(positions, positions[1:]))

def detect_multiple_choice(row):
    """Set 'is_multiple_choice' when the problem lists A-D or 1-4 options."""
    row['is_multiple_choice'] = (
        is_multiple_choice(row['problem'], LETTER_OPTIONS)
        or is_multiple_choice(row['problem'], NUMERICAL_OPTIONS)
    )
    return row

def main(dataset_path):
    """Run multiple-choice detection over the train split and push back."""
    # pull the dataset
    dataset = load_dataset(dataset_path, split="train")

    # run detection over the full dataset
    dataset = dataset.map(detect_multiple_choice, num_proc=mp.cpu_count())

    # push the updated dataset
    dataset.push_to_hub(dataset_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Detect multiple choice questions in a dataset.")
    parser.add_argument('dataset_path', type=str, help='Path to the dataset')
    args = parser.parse_args()
    main(args.dataset_path)
def has_proof(row):
    """Flag rows whose problem statement asks for a mathematical proof.

    A row is marked when the problem contains "prove that" or "a proof".
    For the "olympiads" source we additionally look for a sentence that
    opens with "show" (e.g. "Show that ...").
    """
    text = row['problem']
    lowered = text.lower()
    found = 'prove that' in lowered or 'a proof' in lowered

    # Olympiad problems often phrase proofs as "Show that ..."; only run the
    # comparatively expensive sentence tokenizer when still undecided.
    if not found and row['source'] == "olympiads":
        found = any(
            sentence.lstrip().lower().startswith('show')
            for sentence in nltk.sent_tokenize(text)
        )

    row['is_math_proof'] = found
    return row
{dataset}") 21 | 22 | # create a new column that merges the problem and the answer 23 | dataset = dataset.map(merge_problem_and_answer, num_proc=mp.cpu_count()) 24 | 25 | # convert to pandas and run semdedup 26 | df = dataset.data.to_pandas() 27 | for epsilon in EPSILONS: 28 | print(f"Running semdedup with epsilon {epsilon}") 29 | df, indices_to_remove, cluster_duplicates_dfs = semantic_deduplication( 30 | df = df, 31 | required_columns = ["problem_answer"], 32 | num_kmeans_clusters = len(df) // 100, 33 | embedding_batch_size = 2500, 34 | use_gpu=True 35 | ) 36 | 37 | # add the semdedup outcome to the dataset 38 | partial_is_semdedup_duplicate = partial(is_semdedup_duplicate, indices_to_remove=indices_to_remove, epsilon=epsilon) 39 | dataset = dataset.map(partial_is_semdedup_duplicate, with_indices=True, num_proc=mp.cpu_count()) 40 | 41 | print(f"Epsilon {epsilon} Filtered Dataset: {dataset}") 42 | print(f"Epsilon {epsilon} Number of duplicates: {len(indices_to_remove)}") 43 | 44 | # remove the problem_answer column 45 | dataset = dataset.remove_columns("problem_answer") 46 | 47 | # push the merged dataset to the hub 48 | dataset.push_to_hub(dataset_path) 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser(description="Detect semantic duplicates in a dataset.") 52 | parser.add_argument('dataset_path', type=str, help='Path to the dataset') 53 | args = parser.parse_args() 54 | main(args.dataset_path) -------------------------------------------------------------------------------- /signals/add_true_false_signal.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import re 3 | import multiprocessing as mp 4 | import argparse 5 | 6 | def is_true_false(row): 7 | row['is_true_false_question'] = False 8 | 9 | # check if solution has flexible "true or false" in it 10 | 11 | # if answer is not empty, check that first: 12 | if row['final_answer']: 13 | if 
re.search(r'(true|false)', row['final_answer'], re.IGNORECASE): 14 | row['is_true_false_question'] = True 15 | else: 16 | # check the final line of the solutions 17 | for solution in row['hard_math_solutions']: 18 | lines = solution.split('\n') 19 | if re.search(r'(true|false)', lines[-1], re.IGNORECASE): 20 | row['is_true_false_question'] = True 21 | 22 | return row 23 | 24 | def main(dataset_path): 25 | # load the dataset 26 | dataset = load_dataset(dataset_path, split="train") 27 | 28 | # run proof detection over the full dataset 29 | dataset = dataset.map(is_true_false, num_proc=mp.cpu_count()) 30 | 31 | # push the dataset 32 | dataset.push_to_hub(dataset_path) 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser(description="Detect true/false questions in a dataset.") 36 | parser.add_argument('dataset_path', type=str, help='Path to the dataset') 37 | args = parser.parse_args() 38 | main(args.dataset_path) -------------------------------------------------------------------------------- /signals/add_yes_no_signal.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import re 3 | import argparse 4 | 5 | def add_yes_no_signal(row): 6 | row['is_yes_no_question'] = False 7 | 8 | # check if the solution responds with yes/no 9 | # if final answer is not empty, check that first 10 | if row['final_answer']: 11 | if re.search(r'(\s|\{|\b)(yes|no)(\s|\}|\b)', row['final_answer'], re.IGNORECASE): 12 | row['is_yes_no_question'] = True 13 | else: 14 | # check the final line of the solution 15 | solution_lines = row['solution'].split('\n') 16 | if re.search(r'(\s|\{|\b)(yes|no)(\s|\}|\b)', solution_lines[-1], re.IGNORECASE): 17 | row['is_yes_no_question'] = True 18 | 19 | # fall back to checking if the problem sounds like a yes/no question 20 | problem_lines = row['problem'].split('\n') 21 | 22 | # search for a series of yes/no question prefixes 23 | yes_no_question_prefixes = [ 24 
| "is ", "are ", "do ", "does ", "can " 25 | ] 26 | # check if there is a previous line, if so then the previous line must end with punctuation ".,:;!?". 27 | if (len(problem_lines) > 1 and re.search(r'[.,:;!?\s]$', problem_lines[-2])) or \ 28 | len(problem_lines) == 1: 29 | if any(problem_lines[-1].lower().startswith(prefix) for prefix in yes_no_question_prefixes): 30 | row['is_yes_no_question'] = True 31 | 32 | return row 33 | 34 | def main(dataset_path): 35 | # load the dataset 36 | dataset = load_dataset(dataset_path, split="train") 37 | 38 | # run the detection over the full dataset 39 | dataset = dataset.map(add_yes_no_signal) 40 | 41 | # push the dataset 42 | dataset.push_to_hub(dataset_path) 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description="Detect yes/no questions in a dataset.") 46 | parser.add_argument('dataset_path', type=str, help='Path to the dataset') 47 | args = parser.parse_args() 48 | main(args.dataset_path) -------------------------------------------------------------------------------- /signals/model_based_signals.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | from datasets import Dataset, load_dataset 4 | from enum import Enum 5 | import json 6 | import os 7 | from pydantic import BaseModel, Field, field_validator 8 | import traceback 9 | from tqdm import tqdm 10 | 11 | from rollouts_based_signals.utils.sglang_util import SGLangServerManager 12 | 13 | class ModelType(str, Enum): 14 | """supported llm model types""" 15 | Llama3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct" 16 | Llama3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct" 17 | Llama3_1_405B = "meta-llama/Llama-3.1-405B-Instruct-FP8" 18 | CLAUDE_3_5 = "claude-3-5-sonnet-latest" 19 | O1_PREVIEW = "o1-preview" 20 | O1_MINI = "o1-mini" 21 | 22 | class LLMHyperparams(BaseModel): 23 | system_prompt: str = """You are a math expert. 
Given the following math problem, provide your solution in Latex format. Always format your final answer in perfect LaTeX \\boxed{{final_answer}} format.""" 24 | prompt: str = "{problem}" 25 | temperature: float = Field( 26 | default=0.8, 27 | ge=0.0, 28 | le=2.0, 29 | description='Float that controls the randomness of the sampling. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling.' 30 | ) 31 | top_k: int = Field( 32 | default=-1, 33 | description='Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.' 34 | ) 35 | top_p: float = Field( 36 | default=0.95, 37 | ge=0.0, 38 | le=1.0, 39 | description='Float that controls the cumulative probability of the top tokens to consider.' 40 | ) 41 | max_tokens: int = 2048 42 | model_name: ModelType = Field( 43 | description='The model to use for generation.', 44 | ) 45 | n: int = Field( 46 | description='Number of samples to generate.' 47 | ) 48 | 49 | @field_validator('temperature') 50 | def validate_temperature(cls, v): 51 | if v < 0: 52 | raise ValueError("temp can't be negative") 53 | return v 54 | 55 | def to_json(self) -> str: 56 | return json.dumps(self.model_dump(), indent=2) 57 | 58 | @classmethod 59 | def load_json(cls, json_str: str) -> 'LLMHyperparams': 60 | data = json.loads(json_str) 61 | return cls(**data) 62 | 63 | no_multiple_parts = [ 64 | """\ 65 | Ancient Greek mathematicians from the Pythagorean school studied various polygonal numbers, such as triangular numbers 1, 3, 6, 10, ..., with the $n$-th triangular number being $\\frac{n(n+1)}{2} = \\frac{1}{2}n^{2} + \\frac{1}{2}n$. Let the $n$-th $k$-sided polygon number be denoted as $N(n, k)$ ($k \\geq 3$). 
Below are the expressions for the $n$-th number of some $k$-sided polygon numbers: \nTriangular numbers $N(n, 3) = \\frac{1}{2}n^{2} + \\frac{1}{2}n$ \nSquare numbers $N(n, 4) = n^{2}$ \nPentagonal numbers $N(n, 5) = \\frac{3}{2}n^{2} - \\frac{1}{2}n$ \nHexagonal numbers $N(n, 6) = 2n^{2} - n$ \n... \nFrom this, we can deduce the expression for $N(n, k)$ and calculate $N(8, 12) = \\_\\_\\_\\_\\_\\_$.' 66 | """, 67 | """\ 68 | Let x be the dividend, y be the divisor, z be the quotient, and r be the remainder. If y = 3(y1 + y2) + 4, z = 2z1^2 - z2, r = 3r1 + 2, and x = 2x1y1 - x2 + 10, find the values of x, y, z, and r, given that y1 = 2, y2 = 3, z1 = 3, z2 = 5, r1 = 1, x1 = 4, and x2 = 6. 69 | """, 70 | """\ 71 | If $x_{1}, x_{2},\ldots ,x_{n}$ are positive real numbers with $x_{1}^2+x_2^{2}+\ldots +x_{n}^{2}=1$ , find the minimum value of $\sum_{i=1}^{n}\frac{x_{i}^{5}}{x_{1}+x_{2}+\ldots +x_{n}-x_{i}}$ . 72 | """, 73 | """ 74 | Given that the value of the function \( f(x) = \frac{1}{(x-2)^{2}} - 2x + \cos 2\theta - 3 \sin \theta + 2 \) is always positive for \( x \in (-\infty, 2) \), determine the range of the parameter \( \theta \) within the interval \( (0, \pi) \). 75 | """, 76 | """ 77 | A positive integer \( n \) is written on the blackboard. Two players, A and B, take turns performing operations, starting with A. On A's turn, if the number on the board is \( k \), A replaces it with one of \( \frac{k}{2} \), \( \frac{k}{4} \), or \( 3k \) (with the first two choices being allowed only if \( \frac{k}{2} \) or \( \frac{k}{4} \) are integers). On B's turn, if the number on the board is \( k \), B replaces it with either \( k+1 \) or \( k-1 \). Player A wins the game if the number 3 is written on the blackboard at any point. For which positive integers \( n \) does player A have a winning strategy? 
78 | """, 79 | """\ 80 | A student was studying the properties of the function $f(x) = x^2e^x$ and came to the following conclusions: 81 | - ① The interval where $f(x)$ is monotonically decreasing is $(-2, 0)$; 82 | - ② $f(x)$ has neither a minimum nor a maximum value; 83 | - ③ The graph of $f(x)$ and its tangent line at $(0,0)$ intersect at two points; 84 | - ④ The graph of $f(x)$ and the line $x - y + 2012 = 0$ intersect at two points. 85 | 86 | Among these conclusions, the correct ones are __________. 87 | """, 88 | """\ 89 | Prove that the set \(A=\left\{2,2^{2}, \cdots\right., \left.2^{n}, \cdots\right\}\) satisfies the following: 90 | 91 | (1) For every \(a \in A\) and \(b \in \mathbf{N}^{*}\), if \(b < 2a - 1\), then \(b(b+1)\) is not a multiple of \(2a\). 92 | 93 | (2) For every \(a \in \bar{A}\) (where \(\bar{A}\) denotes the complement of \(A\) in \(\mathbf{N}^{*}\)) and \(a \neq 1\), there exists \(b \in \mathbf{N}^{*}\) such that \(b < 2a - 1\) and \(b(b+1)\) is a multiple of \(2a\). 94 | """, 95 | """\ 96 | You have a 200 liters mixture of four chemicals: W, X, Y, and Z which are in the ratio of 3:4:6:7. You add 40 liters of chemical W, 35 liters of chemical X, 50 liters of chemical Y and 25 liters of chemical Z to this mixture. What is the new percentage of each chemical W, X, Y and Z in the resulting mixture? 97 | """, 98 | """\ 99 | Which of the following statements are correct? 100 | 101 | A: If a sample of size $5$ is drawn from a population of $50$ individuals using simple random sampling, the probability of individual $m$ being selected is $0.1$. 102 | 103 | B: Given a set of data $1$, $2$, $m$, $6$, $7$ with an average of $4$, then the variance of this data set is $\frac{{26}}{5}$. 104 | 105 | C: The $70$th percentile of the data $13$, $27$, $24$, $12$, $14$, $30$, $15$, $17$, $19$, $23$ is $23$. 
106 | 107 | D: If the standard deviation of a sample data $x_{1}$, $x_{2}$, $\cdots$, $x_{10}$ is $8$, then the standard deviation of the data $2x_{1}-1$, $2x_{2}-1$, $\cdots$, $2x_{10}-1$ is $32$. 108 | """, 109 | """\ 110 | A transparent, sealed cubic container is exactly half filled with water. When this cube is rotated arbitrarily, the shape of the water surface inside the container can be: (1) triangle; (2) rectangle; (3) square; (4) regular hexagon. Among these, the correct conclusions are __________. 111 | """, 112 | """\ 113 | Given a 100-sided polygon \( P \) in the Cartesian coordinate plane, satisfying: 114 | (i) The coordinates of the vertices of \( P \) are all integers; 115 | (ii) The sides of \( P \) are parallel to the coordinate axes; 116 | (iii) The side lengths of \( P \) are all odd numbers. 117 | 118 | Prove that the area of \( P \) is an odd number. 119 | """, 120 | """\ 121 | Given a positive integer $n \ge 2$. Find all $n$-tuples of positive integers $(a_1,a_2,\ldots,a_n)$, such that $1 0)$ intersects curve $C_1$ at point $A$ (other than the pole) and intersects curve $C_2$ at point $B$. Find $|AB|$. 138 | """, 139 | """\ 140 | Given the function $f(x)=|x+2|-|2x-a|$, $(a\in\mathbb{R})$. (I) When $a=3$, solve the inequality $f(x) > 0$; (II) When $x \in [0, +\infty)$, $f(x) < 3$ always holds, find the range of $a$. 141 | """, 142 | """\ 143 | Given an ellipse $C_1$: $\frac{x^2}{a^2} + \frac{y^2}{b^2} = 1$ ($a > b > 0$) with a major axis length of 4 and an eccentricity of $\frac{1}{2}$, where $F_1$ and $F_2$ are its left and right foci, respectively. A moving circle passes through point $F_2$ and is tangent to the line $x = -1$. 
(Ⅰ) (i) Find the equation of the ellipse $C_1$; (ii) Find the equation of the trajectory of the center $C$ of the moving circle; (Ⅱ) On the curve $C$, there are two points $M$ and $N$, and on the ellipse $C_1$, there are two points $P$ and $Q$, satisfying that $MF_2$ and $\overrightarrow{NF_2}$ are collinear, $\overrightarrow{PF_2}$ and $\overrightarrow{QF_2}$ are collinear, and $\overrightarrow{PF_2} \cdot \overrightarrow{MF_2} = 0$, find the minimum value of the area of quadrilateral $PMQN$. 144 | """, 145 | """\ 146 | In the rectangular coordinate system $xOy$, a polar coordinate system is established with the coordinate origin as the pole and the positive semi-axis of the $x$-axis as the polar axis. The polar coordinate equation of circle $C$ is $\rho^2 - 2m\rho\cos\theta + 4\rho\sin\theta = 1 - 2m$. (1) Find the rectangular coordinate equation of $C$ and its radius. (2) When the radius of $C$ is the smallest, the curve $y = \sqrt{3}|x - 1| - 2$ intersects $C$ at points $A$ and $B$, and point $M(1, -4)$. Find the area of $\triangle MAB$. 147 | """, 148 | """\ 149 | In recent years, the emergence of "shared bicycles" has greatly facilitated the "green travel" for citizens. A shared bicycle company "Mobie" plans to invest a total of 1.2 million yuan in two cities, A and B. According to industry regulations, each city must receive an investment of at least 400,000 yuan. Preliminary market research shows that the profit $P$ in city A and the investment $a$ (in units of 10,000 yuan) satisfy $P=3 \sqrt{2a}-6$, and the profit $Q$ in city B and the investment $a$ (in units of 10,000 yuan) satisfy $Q= \frac{1}{4}a+2$. Let the investment in city A be $x$ (in units of 10,000 yuan), and the total profit of the two cities be $f(x)$ (in units of 10,000 yuan). $(1)$ When the investment in city A is 500,000 yuan, calculate the total profit of the company at this time; $(2)$ How should the investments in cities A and B be arranged to maximize the total profit? 
150 | """, 151 | """ 152 | From points \( M \) and \( K \), which are 70 km apart, a bus and a cyclist set out towards each other simultaneously. They met after 1 hour and 24 minutes. Continuing at the same speed, the bus arrived at \( K \) and left for the return journey after a 20-minute stop. Find the speeds of the bus and the cyclist if the bus overtook the cyclist 2 hours and 41 minutes after their first meeting. 153 | """, 154 | """\ 155 | Let $ L$ denote the set of all lattice points of the plane (points with integral coordinates). Show that for any three points $ A,B,C$ of $ L$ there is a fourth point $ D,$ different from $ A,B,C,$ such that the interiors of the segments $ AD,BD,CD$ contain no points of $ L.$ Is the statement true if one considers four points of $ L$ instead of three? 156 | """, 157 | """\ 158 | For n real numbers $a_{1},\, a_{2},\, \ldots\, , a_{n},$ let $d$ denote the difference between the greatest and smallest of them and $S = \sum_{i a > 1\), and \(a\) does not divide \(b\), and a given sequence of positive integers \(\{b_n\}_{n=1}^{\infty}\) satisfying \(b_{n+1} \geq 2b_n\) for all positive integers \(n\). Does there always exist a sequence of positive integers \(\{a_n\}_{n=1}^{\infty}\) such that for all positive integers \(n\), \(a_{n+1} - a_n \in \{a, b\}\), and for all positive integers \(m\) and \(l\) (which can be the same), \(a_m + a_l \notin \{b_n\}\) for all \(n\)? 165 | """, 166 | """\ 167 | Let \( f(x) = x^n, x \in D, n \in \mathbf{N}^{+} \). Determine whether \( f(x) \) is a solution to the functional inequality 168 | \[ 169 | f(x) + f(1-x) > 1 170 | \] 171 | If so, find the domain \( D \); if not, provide an explanation. 172 | """, 173 | """\ 174 | In a right angled-triangle $ABC$, $\angle{ACB} = 90^o$. Its incircle $O$ meets $BC$, $AC$, $AB$ at $D$,$E$,$F$ respectively. $AD$ cuts $O$ at $P$. If $\angle{BPC} = 90^o$, prove $AE + AP = PD$. 
175 | """, 176 | """\ 177 | A(x,y), B(x,y), and C(x,y) are three homogeneous real-coefficient polynomials of x and y with degree 2, 3, and 4 respectively. we know that there is a real-coefficient polinimial R(x,y) such that $B(x,y)^2-4A(x,y)C(x,y)=-R(x,y)^2$. Show that there exist 2 polynomials F(x,y,z) and G(x,y,z) such that $F(x,y,z)^2+G(x,y,z)^2=A(x,y)z^2+B(x,y)z+C(x,y)$ if for any x, y, z real numbers $A(x,y)z^2+B(x,y)z+C(x,y)\ge 0$ 178 | """, 179 | """\ 180 | Prove \[\frac{1}{\cos 0^\circ \cos 1^\circ} + \frac{1}{\cos 1^\circ \cos 2^\circ} + \cdots + \frac{1}{\cos 88^\circ \cos 89^\circ} = \frac{\cos 1^\circ}{\sin^2 1^\circ}.\] 181 | """, 182 | ] 183 | 184 | no_proof_questions = [ 185 | """\ 186 | In a $100 \times 25$ rectangular table, each cell is filled with a non-negative real number. The number in the $i$-th row and $j$-th column is denoted by $x_{i, j}$ $(i=1,2,\ldots, 100; j=1,2,\ldots, 25)$ (Table 1). The numbers in each column of Table 1 are then reordered in descending order to create Table 2 such that $x_{1, j}^{\prime} \geq x_{2, j}^{\prime} \geq \cdots \geq x_{100, j}^{\prime}$ $(j=1,2,\ldots, 25)$. Find the smallest natural number $k$ such that if the numbers in Table 1 satisfy $\sum_{j=1}^{25} x_{i, j} \leq 1$ $(i=1,2,\ldots, 100)$, then for $i \geq k$, Table 2 satisfies $\sum_{j=1}^{25} x_{i, j}^{\prime} \leq 1$ $(i=1,2,\ldots, 100)$. 187 | """, 188 | """\ 189 | We are given $2n$ natural numbers 190 | \[1, 1, 2, 2, 3, 3, \ldots, n - 1, n - 1, n, n.\] 191 | Find all $n$ for which these numbers can be arranged in a row such that for each $k \leq n$, there are exactly $k$ numbers between the two numbers $k$. 192 | """, 193 | """\ 194 | Determine all positive integers $n$, $n\ge2$, such that the following statement is true: If $(a_1,a_2,...,a_n)$ is a sequence of positive integers with $a_1+a_2+\cdots+a_n=2n-1$, then there is block of (at least two) consecutive terms in the sequence with their (arithmetic) mean being an integer. 
195 | """, 196 | """\ 197 | Turbo the snail sits on a point on a circle with circumference $1$. Given an infinite sequence of positive real numbers $c_1, c_2, c_3, \dots$, Turbo successively crawls distances $c_1, c_2, c_3, \dots$ around the circle, each time choosing to crawl either clockwise or counterclockwise. 198 | Determine the largest constant $C > 0$ with the following property: for every sequence of positive real numbers $c_1, c_2, c_3, \dots$ with $c_i < C$ for all $i$, Turbo can (after studying the sequence) ensure that there is some point on the circle that it will never visit or crawl across. 199 | """, 200 | """\ 201 | For an even integer positive integer $n$ Kevin has a tape of length $4 n$ with marks at $-2 n,-2 n+1, \ldots, 2 n-1,2 n$. He then randomly picks $n$ points in the set $-n,-n+1,-n+2, \ldots, n-1, n$, and places a stone on each of these points. We call a stone 'stuck' if it is on $2 n$ or $-2 n$, or either all the points to the right, or all the points to the left, all contain stones. Then, every minute, Kevin shifts the unstuck stones in the following manner: He picks an unstuck stone uniformly at random and then flips a fair coin. If the coin came up heads, he then moves that stone and every stone in the largest contiguous set containing that stone one point to the left. If the coin came up tails, he moves every stone in that set one point right instead. He repeats until all the stones are stuck. Let $p_{k}$ be the probability that at the end of the process there are exactly $k$ stones in the right half. Evaluate $$\frac{p_{n-1}-p_{n-2}+p_{n-3}-\ldots+p_{3}-p_{2}+p_{1}}{p_{n-1}+p_{n-2}+p_{n-3}+\ldots+p_{3}+p_{2}+p_{1}}$$ in terms of $n$. 202 | """, 203 | """\ 204 | A 0-1 sequence of length $2^k$ is given. Alice can pick a member from the sequence, and reveal it (its place and its value) to Bob. Find the largest number $s$ for which Bob can always pick $s$ members of the sequence, and guess all their values correctly. 
def formatted_items(problem, filter_type):
    """Build a one-turn chat prompt asking an LLM to classify `problem`.

    Args:
        problem: The question text to classify.
        filter_type: One of "multiple_choice", "proof", "yes_no",
            "true_false" or "multiple_part".

    Returns:
        A single-message chat list: [{"role": "user", "content": <prompt>}].

    Raises:
        ValueError: If `filter_type` is not a supported filter name.
    """
    def as_chat(content):
        # OpenAI-style chat format: exactly one user turn.
        return [{"role": "user", "content": content}]

    if filter_type == "multiple_choice":
        return as_chat(f"""\
Given this question: {problem}

Is this a multiple choice question (a question that provides specific options to choose from, typically labeled as A, B, C, D or 1, 2, 3, 4)?

Return only "yes" or "no" without any additional explanation.
""")
    if filter_type == "proof":
        # Few-shot prompt: relies on the module-level proof_questions /
        # no_proof_questions example lists.
        return as_chat(f"""\
Given this question: {problem}

Is this a mathematical proof question (a question that asks to prove a statement, theorem, or property...)?

Examples of proof indicators:
- "Prove that..."
- "Show that..."
- "Demonstrate why..."
- "Justify your answer..."
- "Explain why..."
etc.
Here are examples of proof questions:
Example 1: {proof_questions[0]}
Example 2: {proof_questions[1]}
Example 3: {proof_questions[2]}
Example 4: {proof_questions[3]}
Example 5: {proof_questions[4]}
Here are examples of non-proof questions:
Example 1: {no_proof_questions[0]}
Example 2: {no_proof_questions[1]}
Example 3: {no_proof_questions[2]}
Example 4: {no_proof_questions[3]}
Example 5: {no_proof_questions[4]}
Example 6: {no_proof_questions[5]}

Return only "yes" or "no" without any additional explanation.
""")
    if filter_type == "yes_no":
        return as_chat(f"""\
Given this question: {problem}

Is this a yes/no question (a question that asks to choose between two options, typically labeled as yes or no)?

Return only "yes" or "no" without any additional explanation.
""")
    if filter_type == "true_false":
        return as_chat(f"""\
Given this question: {problem}

Is this a true/false question (a question that asks to choose between two options, typically labeled as true or false)?

Return only "true" or "false" without any additional explanation.
""")
    if filter_type == "multiple_part":
        # Few-shot prompt: relies on the module-level multiple_parts /
        # no_multiple_parts example lists.
        return as_chat(f"""\
Your task is to determine if the given question contains multiple sub-questions, sub-parts, or sub-tasks.
A multi-part question requires separate answers for different components, rather than a single comprehensive answer.
Besides that, if the question is multiple choice and only requires to select one option, it is not a multi-part question.

Here are examples of multi-part questions that require multiple distinct answers:
Example 1: {multiple_parts[0]}
Example 2: {multiple_parts[1]}
Example 3: {multiple_parts[2]}
Example 4: {multiple_parts[3]}
Example 5: {multiple_parts[4]}
Example 6: {multiple_parts[5]}
Example 7: {multiple_parts[6]}
Example 8: {multiple_parts[7]}
Example 9: {multiple_parts[8]}

Here are examples of single-part questions that require only one answer:
Example 1: {no_multiple_parts[0]}
Example 2: {no_multiple_parts[1]}
Example 3: {no_multiple_parts[2]}
Example 4: {no_multiple_parts[3]}
Example 5: {no_multiple_parts[4]}
Example 6: {no_multiple_parts[5]}
Example 7: {no_multiple_parts[6]}
Please analyze this question: {problem}

Does this question contain multiple parts requiring separate answers?
Return only "yes" or "no" without any additional explanation.
""")
    raise ValueError(f"Invalid type: {filter_type}")
async def main():
    """Run the selected model-based filters over a HF dataset and save results.

    Reads all configuration from the module-level `args` namespace, queries the
    SGLang-served model once per (batch, enabled filter), stores the responses
    in new per-filter dataframe columns, periodically checkpoints to parquet
    and the HF Hub, and pushes a final version at the end.
    """
    # Pull the run configuration out of the global CLI namespace.
    model_name = args.model_name
    dataset_name = args.dataset_name
    save_name = args.save_name
    save_folder = args.save_folder
    problem_column_name = args.problem_column_name
    # Outputs go to a separate Hub dataset only when explicitly requested.
    hf_save_dataset_name = args.dataset_name_outputs if args.dataset_name_outputs else dataset_name
    os.makedirs(save_folder, exist_ok=True)

    batch_size = args.batch_size
    save_interval = args.save_interval

    # Sampling hyperparameters handed to the server on every request.
    llm_params = LLMHyperparams(
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        model_name=ModelType(model_name),
        n=args.n,
        max_tokens=args.max_tokens,
    )

    print("Hyperparameters:")
    print(llm_params.to_json())

    # Load the dataset and (optionally) slice it to [start, end).
    df = load_dataset(dataset_name, split=args.dataset_split).to_pandas()
    if args.end == -1:
        args.end = len(df)
    df = df.iloc[args.start:args.end].reset_index(drop=True)

    print(f"Total dataset: {len(df)}")

    # One prompt list per enabled filter, row-aligned with the dataframe.
    items = {}
    for filter_name, enabled in [
        ("multiple_choice", args.multiple_choice),
        ("proof", args.proof),
        ("yes_no", args.yes_no),
        ("true_false", args.true_false),
        ("multiple_part", args.multiple_part),
    ]:
        if enabled:
            items[filter_name] = [
                formatted_items(row[problem_column_name], filter_name)
                for _, row in df.iterrows()
            ]

    # Query the model batch-by-batch for every enabled filter.
    total_items = len(df)
    count = 0
    with SGLangServerManager(model_name, tp=args.tp) as server_handler:
        for idx in tqdm(range(0, total_items, batch_size)):
            print(f"Processing indices {idx}:{idx+batch_size}...")
            for filter_type, prompts in items.items():
                batch_outputs = await server_handler.get_chat_responses(
                    prompts[idx:idx + batch_size],
                    n=llm_params.n,
                    top_p=llm_params.top_p,
                    temperature=llm_params.temperature,
                    max_tokens=llm_params.max_tokens,
                )
                batch_responses = []
                for output in batch_outputs:
                    try:
                        batch_responses.append(output[-1]["responses"])
                    except Exception:
                        # Malformed server reply: log it and keep the row
                        # aligned with an empty response list.
                        print(f"Response: {output}")
                        print(f"Error processing response: {traceback.format_exc()}")
                        batch_responses.append([""])

                count += 1
                # Write this batch's responses into a per-filter column.
                for offset, response_list in enumerate(batch_responses):
                    if isinstance(response_list, list) and len(response_list) > args.n:
                        print(f"Response List Error: Length > n: {response_list}")
                        response_list = response_list[:args.n]
                    if args.n == 1:
                        # Single-sample runs store the bare string, not a list.
                        response_list = response_list[0]
                    df.at[idx + offset, f"{args.output_model_name}_{filter_type}"] = response_list

                # Periodic checkpoint. NOTE(review): `count` advances once per
                # (batch, filter) pair, so with several filters enabled this
                # fires more often than "every save_interval batches" — confirm
                # whether that is intended.
                if count % save_interval == 0:
                    try:
                        df.iloc[:idx + batch_size].to_parquet(
                            os.path.join(save_folder, f"{save_name}_{count}_batch.parquet")
                        )
                        ds = Dataset.from_pandas(df)
                        ds.push_to_hub(hf_save_dataset_name, private=True)
                    except Exception as e:
                        print(f"Error saving checkpoint: {e}")

    # Final save after all batches are processed.
    try:
        df.to_parquet(os.path.join(save_folder, f"{save_name}.parquet"))
        ds = Dataset.from_pandas(df)
        ds.push_to_hub(hf_save_dataset_name, private=True)
        print(f"Saved to {os.path.join(save_folder, f'{save_name}.parquet')}")
    except Exception as e:
        print(f"Error saving final results: {e}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Model and dataset configuration.
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--output_model_name", type=str, required=True,
                        help="Name to be prepended to new column names.")
    parser.add_argument("--dataset_name", type=str, required=True)
    parser.add_argument("--dataset_name_outputs", type=str,
                        help="To save the outputs to a different HF dataset, specify here.")
    parser.add_argument("--dataset_split", type=str, default="train")
    parser.add_argument("--problem_column_name", type=str, default="problem")

    # Which model-based filters to run (each enabled flag adds one column).
    parser.add_argument("--multiple_choice", action="store_true")
    parser.add_argument("--proof", action="store_true")
    parser.add_argument("--yes_no", action="store_true")
    parser.add_argument("--true_false", action="store_true")
    parser.add_argument("--multiple_part", action="store_true")

    # Save configuration.
    parser.add_argument("--save_folder", type=str, required=True)
    parser.add_argument("--save_name", type=str, required=True)
    parser.add_argument("--save_interval", type=int, default=10000,
                        help="Save every n batches.")

    # SGLang server configuration.
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=10000,
                        help="Total batch size will be args.batch_size * args.n.")

    # LLM sampling hyperparameters.
    parser.add_argument("--top_k", type=int, default=-1)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max_tokens", type=int, default=5)
    parser.add_argument("--n", type=int, default=1)

    # Dataset slicing (process only rows [start, end)).
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=-1)

    # `args` is intentionally module-global: main() and helpers read it.
    args = parser.parse_args()

    asyncio.run(main())

# Example Usage:
# python3 model_based_filters.py --model_name "meta-llama/Meta-Llama-3.1-70B-Instruct" --output_model_name "llama3-70b" --dataset_name "RLAIF/INTERNAL-ONLY-Big-Math-RL-Verified-MC-Rewrites" --dataset_name_outputs "RLAIF/model_filter_testing" --save_folder "outputs" --save_name "model_based_filters" --multiple_choice --proof --yes_no --true_false --multiple_part --tp 8 > model_based_filter_outputs.txt 2>&1
async def rate_limit_is_correct(evaluator: "MathEvaluator", answer: str, pred: str, sem: asyncio.Semaphore):
    """
    Asynchronously check whether `pred` matches `answer`, optionally rate limited.

    Args:
        evaluator: MathEvaluator instance used to judge correctness.
        answer: Ground-truth answer.
        pred: Predicted answer to check.
        sem: Semaphore bounding concurrent evaluator calls, or None for no limit.

    Returns:
        bool: The result of `evaluator.is_correct(answer, pred)`.
    """
    # BUG FIX: `async with None` raises TypeError. evaluate_preds_async
    # documents sem as optional (default None), so honor that here.
    if sem is None:
        return await evaluator.is_correct(answer, pred)
    async with sem:
        return await evaluator.is_correct(answer, pred)

async def evaluate_preds_async(preds: Union[list[list], list], ground_truth_answers: list, evaluator: "MathEvaluator", sem: asyncio.Semaphore = None):
    """
    Asynchronously evaluate predictions against ground truth answers.

    Args:
        preds (Union[list[list], list]): A flat list of predictions, or a list
            of per-question prediction lists (e.g. multiple rollouts per question).
        ground_truth_answers (list): Ground truth answers aligned with `preds`.
        evaluator (MathEvaluator): Evaluator used to judge the predictions.
        sem (asyncio.Semaphore, optional): Limits concurrent evaluations.
            None means unlimited. Defaults to None.

    Returns:
        list: One bool per prediction; for nested `preds` the results are
        flattened in input order (callers regroup them afterwards).
    """
    print("Evaluating predictions...")
    if isinstance(preds[0], list):
        # Nested case: fan out every rollout of every question.
        tasks = []
        for pred, answer in zip(preds, ground_truth_answers):
            tasks.extend([asyncio.create_task(rate_limit_is_correct(evaluator, answer, p, sem)) for p in pred])
    else:
        tasks = [asyncio.create_task(rate_limit_is_correct(evaluator, answer, pred, sem)) for pred, answer in zip(preds, ground_truth_answers)]
    # tqdm_asyncio.gather preserves task order while showing progress.
    result = await tqdm_asyncio.gather(*tasks)
    return result

def extract_answers(row, evaluator: "MathEvaluator", response_column_name: str, predicted_answer_column_name: str):
    """
    Extract the final answer(s) from a row's response column into a new column.

    Handles both a single response string and a list of response strings;
    the output column mirrors the input's shape.
    """
    if isinstance(row[response_column_name], list):
        row[predicted_answer_column_name] = [evaluator.get_answer_expr(p) for p in row[response_column_name]]
    else:
        row[predicted_answer_column_name] = evaluator.get_answer_expr(row[response_column_name])
    return row

def extract_and_evaluate_answers(
    row,
    response_column_name: str,
    predicted_answer_column_name: str,
    predicted_answer_correctness_column_name: str,
    ground_truth_answer_column_name: str,
):
    """
    Extract answers from the response column and score them in one pass.

    Uses the module-level (picklable) get_answer_expr / is_correct_no_judge,
    so datasets.map can parallelize this with num_proc > 1 — unlike the
    MathEvaluator-bound path, which holds an unpicklable SSLContext.
    """
    if isinstance(row[response_column_name], list):
        row[predicted_answer_column_name] = [get_answer_expr(p) for p in row[response_column_name]]
        row[predicted_answer_correctness_column_name] = [is_correct_no_judge(row[ground_truth_answer_column_name], p) for p in row[predicted_answer_column_name]]
    else:
        row[predicted_answer_column_name] = get_answer_expr(row[response_column_name])
        row[predicted_answer_correctness_column_name] = is_correct_no_judge(row[ground_truth_answer_column_name], row[predicted_answer_column_name])
    return row
async def async_main(args):
    """
    Evaluation pipeline with the LLM-judge backup (async path).

    Loads the predictions dataset, extracts a final answer from every model
    response, scores each extracted answer against the ground truth via
    MathEvaluator (which can fall back to an LLM judge), and pushes the
    augmented dataset back to the HF Hub.
    """
    response_column_name = args.response_column_name
    ground_truth_answer_column_name = args.ground_truth_answer_column_name

    # Output column names are derived from the response column's name.
    predicted_answer_column_name = f"{args.response_column_name}_extracted_answers"
    predicted_answer_correctness_column_name = f"{args.response_column_name}_correctness"

    preds_dataset = load_dataset(args.predictions_dataset_name, split=args.dataset_split)

    # Make sure we can fulfill the evaluation requirements.
    if args.calculate_greedy_accuracy:
        assert "greedy_correct" in preds_dataset.column_names, "Predictions dataset does not contain greedy_correct column"

    evaluator = MathEvaluator()

    already_processed = (
        predicted_answer_column_name in preds_dataset.column_names
        or predicted_answer_correctness_column_name in preds_dataset.column_names
    )
    if already_processed:
        # Answers (or even correctness) were produced by an earlier run.
        print("Answers already extracted. Skipping extraction")
    else:
        print(f"Extracting answers to {predicted_answer_column_name}...")
        extract_fn = partial(
            extract_answers,
            evaluator=evaluator,
            response_column_name=response_column_name,
            predicted_answer_column_name=predicted_answer_column_name
        )
        # TODO: factor is_correct() out of MathEvaluator so this map can use
        # num_proc > 1 — HF datasets cannot pickle the evaluator (it holds an
        # SSLContext: "TypeError: cannot pickle 'SSLContext' object").
        preds_dataset = preds_dataset.map(extract_fn)

    if predicted_answer_correctness_column_name in preds_dataset.column_names:
        print("Correctness already calculated. Skipping evaluation")
    else:
        print(f"Evaluating correctness of {predicted_answer_column_name}...")
        sem = asyncio.Semaphore(args.semaphore_limit)
        answer_correctness = await evaluate_preds_async(
            preds=preds_dataset[predicted_answer_column_name],
            ground_truth_answers=preds_dataset[ground_truth_answer_column_name],
            evaluator=evaluator,
            sem=sem
        )

        # The gather above returns a flat list; regroup it to mirror the
        # nested shape of the predictions when there are multiple rollouts.
        extracted = preds_dataset[predicted_answer_column_name]
        if isinstance(extracted[0], list):
            regrouped = []
            cursor = 0
            for per_question in extracted:
                regrouped.append(answer_correctness[cursor:cursor + len(per_question)])
                cursor += len(per_question)
            answer_correctness = regrouped

        preds_dataset = preds_dataset.add_column(predicted_answer_correctness_column_name, answer_correctness)

    # Push the (possibly updated) dataset back to the hub.
    preds_dataset.push_to_hub(args.predictions_dataset_name, private=True)
def main(args):
    """
    Evaluation pipeline without the LLM-judge backup (pure rule-based checks).

    Extraction and correctness scoring are fused into one `datasets.map` call,
    which uses only picklable module-level functions and can therefore be
    parallelized across processes.
    """
    response_column_name = args.response_column_name
    ground_truth_answer_column_name = args.ground_truth_answer_column_name

    # Output column names are derived from the response column's name.
    predicted_answer_column_name = f"{args.response_column_name}_extracted_answers"
    predicted_answer_correctness_column_name = f"{args.response_column_name}_correctness"

    preds_dataset = load_dataset(args.predictions_dataset_name, split=args.dataset_split)

    # Make sure we can fulfill the evaluation requirements.
    if args.calculate_greedy_accuracy:
        assert "greedy_correct" in preds_dataset.column_names, "Predictions dataset does not contain greedy_correct column"

    # Extract answers and evaluate correctness in a single step.
    print(f"Extracting and evaluating correctness of {predicted_answer_column_name}...")
    map_fn = partial(
        extract_and_evaluate_answers,
        response_column_name=response_column_name,
        predicted_answer_column_name=predicted_answer_column_name,
        predicted_answer_correctness_column_name=predicted_answer_correctness_column_name,
        ground_truth_answer_column_name=ground_truth_answer_column_name
    )
    # BUG FIX: num_proc was hard-coded to 64, silently ignoring the --num_proc
    # CLI flag (default 1) that is documented for exactly this purpose.
    preds_dataset = preds_dataset.map(map_fn, num_proc=args.num_proc)

    # Push the dataset back to the hub.
    preds_dataset.push_to_hub(args.predictions_dataset_name, private=True)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate rollouts")
    # Model and dataset arguments
    parser.add_argument("--predictions_dataset_name", type=str, required=True,
                        help="The name of the dataset containing the predictions.")
    parser.add_argument("--response_column_name", type=str, required=True,
                        help="The name of the column containing the model's responses.")
    parser.add_argument("--ground_truth_answer_column_name", type=str, required=True,
                        help="The name of the column containing the ground truth answers.")
    parser.add_argument("--dataset_split", type=str, default="train")
    parser.add_argument("--calculate_greedy_accuracy", action="store_true")

    # Evaluation config
    parser.add_argument("--use_llm_judge_backup", action="store_true",
                        help="Whether to use the LLM judge as a backup for evaluation.")

    # System configuration
    parser.add_argument("--num_proc", type=int, default=1,
                        help="Number of processes to use for data processing.")
    parser.add_argument("--semaphore_limit", type=int, default=20,
                        help="The maximum number of concurrent requests to the evaluator, when using LLM judge as backup.")

    args = parser.parse_args()

    # The judge needs async I/O; the rule-based path is plain synchronous code.
    if args.use_llm_judge_backup:
        asyncio.run(async_main(args))
    else:
        main(args)

# Example usage:
# No LLM-judge
# python3 evaluate_responses.py --predictions_dataset_name RLAIF/Big-Math-needs-llama3-8b-rollouts --response_column_name responses --ground_truth_answer_column_name final_answer
# With LLM-judge
# python3 evaluate_responses.py --predictions_dataset_name RLAIF/Big-Math-needs-llama3-8b-rollouts --response_column_name responses --ground_truth_answer_column_name final_answer --use_llm_judge_backup
def last_boxed_only_string(string: str) -> Optional[str]:
    r"""
    Return the last ``\boxed{...}`` (or ``\fbox{...}``) expression in *string*.

    Also supports the brace-less ``\boxed <token>`` form (terminated by the
    next ``$``). Returns None when no boxed expression is present or the
    braces never close.
    """
    idx = string.rfind("\\boxed")
    if "\\boxed " in string:
        # Brace-less variant: take everything after the last "\boxed " up to
        # the next "$" (math-mode terminator).
        return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    # Scan forward to the brace that closes the boxed group.
    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1

    return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None

def remove_boxed(s: str) -> Optional[str]:
    r"""
    Strip the ``\boxed{...}`` / ``\boxed ...`` / ``\fbox{...}`` wrapper from *s*.

    The asserts intentionally raise on malformed input; callers such as
    get_answer_expr catch the exception and fall back to another heuristic.
    Input without any recognized wrapper is returned unchanged.
    """
    if "\\boxed " in s:
        left = "\\boxed "
        assert s[: len(left)] == left
        return s[len(left):]
    elif "\\boxed{" in s:
        left = "\\boxed{"
        assert s[: len(left)] == left
        assert s[-1] == "}"
        return s[len(left):-1]
    elif "\\fbox{" in s:
        left = "\\fbox{"
        assert s[: len(left)] == left
        assert s[-1] == "}"
        return s[len(left):-1]
    else:
        return s

def get_answer_expr(answer: str) -> str:
    """
    Extract the final mathematical expression from a model answer.

    Prefers the last boxed expression; if none can be parsed (including when
    last_boxed_only_string returns None), falls back to the answer's last line.
    """
    try:
        answer = remove_boxed(last_boxed_only_string(answer))
    except Exception:
        answer = answer.split("\n")[-1]
    return answer

def remove_right_units(string: str) -> str:
    r"""
    Drop a trailing ``\text{ ...}`` unit annotation, e.g. ``"5 \text{ cm}"`` -> ``"5 "``.

    Only applies when there is exactly one ``\text{ `` marker; otherwise the
    string is returned unchanged.
    """
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        if len(splits) == 2:
            return splits[0]
    return string

def fix_sqrt(string: str) -> str:
    r"""
    Brace bare square-root arguments: ``\sqrt2`` -> ``\sqrt{2}``.

    Already-braced roots are left alone; if a ``\sqrt`` has no argument at
    all the original string is returned unchanged.
    """
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if not split.startswith("{"):
            if len(split) < 1:
                # "\sqrt" with nothing after it: give up on the whole string.
                return string
            new_substr = f"\\sqrt{{{split[0]}}}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string

def fix_fracs(string: str) -> str:
    r"""
    Rewrite brace-less ``\frac`` arguments into fully braced form.

    ``\frac12``   -> ``\frac{1}{2}``
    ``\frac1{2}`` -> ``\frac{1}{2}``
    Already-braced fractions are left untouched; on anything unparseable the
    original string is returned unchanged.
    """
    substrs = string.split("\\frac")
    new_str = substrs[0]
    for substr in substrs[1:]:
        new_str += "\\frac"
        if substr.startswith("{"):
            # Numerator already braced: nothing to fix for this fraction.
            new_str += substr
            continue
        if len(substr) < 2:
            return string
        a, b = substr[0], substr[1]
        rest = substr[2:]
        if b == "{":
            # BUG FIX: the denominator is already braced, so only the numerator
            # needs wrapping. The previous code emitted "{a}{{}..." here,
            # turning "\frac1{2}" into the broken "\frac{1}{{}2}".
            new_str += f"{{{a}}}{b}{rest}"
        else:
            new_str += f"{{{a}}}{{{b}}}{rest}"
    return new_str

def fix_a_slash_b(string: str) -> str:
    r"""
    Convert a plain ``a/b`` (both bare integers, nothing else) into ``\frac{a}{b}``.

    Anything that is not exactly two integer parts is returned unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    a, b = parts
    try:
        a = int(a)
        b = int(b)
        # Round-trip check guards against inputs like "01/2" or "1_0/2".
        if string == f"{a}/{b}":
            return f"\\frac{{{a}}}{{{b}}}"
        return string
    except ValueError:
        return string

def strip_string(string: str) -> str:
    r"""
    Normalize a LaTeX answer string for literal comparison.

    Removes linebreaks, spacing macros, degree markers, trailing units,
    currency and percent signs; canonicalizes fractions and square roots.
    The order of the rewrites matters and mirrors the Hendrycks MATH
    normalization pipeline.
    """
    string = string.replace("\n", "")
    string = string.replace("\\!", "")
    string = string.replace("\\\\", "\\")
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    string = string.replace("\\$", "")
    string = remove_right_units(string)
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")  # same pattern as the line above; kept for parity
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    if len(string) == 0:
        return string
    if string[0] == ".":
        # ".5" -> "0.5"
        string = "0" + string

    # "x = <expr>" with a short left-hand side: compare only the RHS.
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    string = fix_sqrt(string)
    string = string.replace(" ", "")
    string = fix_fracs(string)

    # Special-case so a bare "0.5" matches the canonical "\frac{1}{2}".
    if string == "0.5":
        string = "\\frac{1}{2}"

    string = fix_a_slash_b(string)

    return string

def is_equiv(str1: Optional[str], str2: Optional[str], verbose: bool = False) -> bool:
    """
    Check whether two answer strings are equivalent after normalization.

    Args:
        str1: First string (may be None).
        str2: Second string (may be None).
        verbose: If True, print the normalized strings before comparing.

    Returns:
        bool: True when both are None or their normalized forms match; if
        normalization itself raises, falls back to raw string equality.
    """
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = strip_string(str1)
        ss2 = strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
    except Exception:
        return str1 == str2

def sympy_match(str1: str, str2: str) -> bool:
    """
    Check whether two LaTeX expressions are symbolically equal via SymPy.

    Parses both sides with parse_latex and tests simplify(expr1 - expr2) == 0.
    A SIGALRM-based 3-second timeout prevents hanging on pathological inputs.

    NOTE(review): SIGALRM only exists on POSIX and signal handlers can only be
    installed in the main thread — elsewhere this silently returns False.

    Returns:
        bool: True if symbolically equal; False on mismatch, parse failure,
        or timeout.
    """
    def timeout_handler(signum, frame):
        raise TimeoutError("Parsing took too long")

    try:
        # Arm a 3-second watchdog around parsing + simplification.
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(3)

        expr1 = parse_latex(str1)
        expr2 = parse_latex(str2)
        diff = sympy.simplify(expr1 - expr2)

        signal.alarm(0)  # disarm on success
        return diff == 0
    except (Exception, TimeoutError):
        signal.alarm(0)  # disarm on any failure path too
        return False

def is_correct_no_judge(correct_answer: str, proposed_answer: str) -> bool:
    """
    Check a proposed answer against the ground truth without an LLM judge.

    Extracts the final expression from *proposed_answer*, then applies the
    string-normalization check followed by the SymPy check (two of the three
    stages from the Llama 3.1 MATH eval recipe; the judge stage lives in
    MathEvaluator.is_correct).
    """
    extracted_answer = get_answer_expr(proposed_answer)

    # An empty extraction can never be correct.
    if extracted_answer.strip() == "":
        return False

    if is_equiv(extracted_answer, correct_answer):
        return True
    elif sympy_match(extracted_answer, correct_answer):
        return True
    else:
        return False
369 | """ 370 | EQUALITY_TEMPLATE = """ 371 | Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications 372 | 373 | Examples: 374 | 375 | Expression 1: $2x+3$ 376 | Expression 2: $3+2x$ 377 | 378 | Yes 379 | 380 | Expression 1: 3/2 381 | Expression 2: 1.5 382 | 383 | Yes 384 | 385 | Expression 1: $x^2+2x+1$ 386 | Expression 2: $y^2+2y+1$ 387 | 388 | No 389 | 390 | Expression 1: $x^2+2x+1$ 391 | Expression 2: $(x+1)^2$ 392 | 393 | Yes 394 | 395 | Expression 1: 3245/5 396 | Expression 2: 649 397 | 398 | No 399 | (these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications) 400 | 401 | Expression 1: 2/(-3) 402 | Expression 2: -2/3 403 | 404 | Yes 405 | (trivial simplifications are allowed) 406 | 407 | Expression 1: 72 degrees 408 | Expression 2: 72 409 | 410 | Yes 411 | (give benefit of the doubt to units) 412 | 413 | Expression 1: 64 414 | Expression 2: 64 square feet 415 | 416 | Yes 417 | (give benefit of the doubt to units) 418 | 419 | Expression 1: 2^{{n-1}} - n 420 | Expression 2: 4 - \\frac{{n+2}}{{2^{{n-1}}}} 421 | 422 | No 423 | 424 | Expression 1: 8 425 | Expression 2: Therefore, there are 8 cats. 426 | 427 | Yes 428 | (simple conclusion sentences giving an answer are allowed) 429 | 430 | Expression 1: 3n^2 431 | Expression 2: a_n = 3n^2 432 | 433 | Yes 434 | (variable names are allowed) 435 | 436 | Expression 1: a=3, b=4, e=7 437 | Expression 2: {{a=3, b=4, e=7}} 438 | 439 | Yes 440 | 441 | Expression 1: 453.6235 442 | Expression 2: 454.0231 443 | 444 | No 445 | (approximately equal is not equivalent) 446 | 447 | Expression 1: 1/3 448 | Expression 2: So we have that $A$ 449 | 450 | No 451 | (anything that appears cut off or nonsensical is not equivalent) 452 | 453 | --- 454 | 455 | YOUR TASK 456 | 457 | 458 | Respond with only "Yes" or "No" (without quotes). Do not include a rationale. 
459 | 460 | Expression 1: {expr1} 461 | Expression 2: {expr2} 462 | """.strip() 463 | prompt = EQUALITY_TEMPLATE.format(expr1=expr1, expr2=expr2) 464 | 465 | await self.rate_limiter.wait() 466 | try: 467 | response = await self.openai_client.chat.completions.create( 468 | model=self.model_name, 469 | messages=[{"role": "user", "content": prompt}], 470 | max_tokens=20, 471 | n=1, 472 | temperature=0.0, 473 | ) 474 | result = response.choices[0].message.content.strip() 475 | return result.lower().strip() == "yes" 476 | except Exception as e: 477 | logger.error(f"error in judge_equality: {e}") 478 | traceback.print_exc() 479 | return False 480 | 481 | async def is_correct(self, correct_answer: str, proposed_answer: str) -> bool: 482 | """ 483 | checks if the provided answer is correct. 484 | """ 485 | async for attempt in AsyncRetrying( 486 | stop=stop_after_attempt(3), 487 | wait=wait_exponential(multiplier=1, min=4, max=120), 488 | reraise=True, 489 | ): 490 | with attempt: 491 | extracted_answer = self.get_answer_expr(proposed_answer) 492 | 493 | # If the extracted answer is empty, it's not correct. 494 | if extracted_answer.strip() == "": 495 | return False 496 | 497 | # Three-stage process based off of https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/eval_details.md#math 498 | if self.is_equiv(extracted_answer, correct_answer): 499 | return True 500 | elif self.sympy_match(extracted_answer, correct_answer): 501 | return True 502 | else: 503 | return await self.judge_equality(extracted_answer, correct_answer) 504 | 505 | async def is_correct_anywhere(self, correct_answer: str, proposed_answer: str) -> bool: 506 | """ 507 | checks if the correct answer appears anywhere in the proposed answer. 
508 | """ 509 | if await self.is_correct(correct_answer, proposed_answer): 510 | return True 511 | 512 | boxed_expressions = self.extract_boxed_expressions(proposed_answer) 513 | 514 | for expr in boxed_expressions: 515 | extracted_answer = self.remove_boxed(expr) 516 | if self.is_equiv(extracted_answer, correct_answer): 517 | return True 518 | elif self.sympy_match(extracted_answer, correct_answer): 519 | return True 520 | elif await self.judge_equality(extracted_answer, correct_answer): 521 | return True 522 | 523 | return False 524 | 525 | async def __call__(self, split: str) -> AsyncIterator[Dict]: 526 | """ 527 | Allows the MathEvaluator to be called as an async generator. 528 | 529 | Args: 530 | split (str): The dataset split to use ('train' or 'test'). 531 | 532 | Yields: 533 | Dict: The next item in the dataset. 534 | 535 | Raises: 536 | ValueError: If an invalid split is provided. 537 | """ 538 | if split == "train": 539 | dataset = self.ds_train 540 | elif split == "test": 541 | dataset = self.ds_test 542 | else: 543 | raise ValueError("split must be 'train' or 'test'") 544 | 545 | for item in dataset: 546 | yield item 547 | 548 | @staticmethod 549 | def has_formatted_answer(answer: str) -> bool: 550 | """ 551 | Checks if the answer contains a formatted solution. 552 | 553 | Args: 554 | answer (str): The answer string. 555 | 556 | Returns: 557 | bool: True if formatted answer exists, False otherwise. 558 | """ 559 | try: 560 | if MathEvaluator.remove_boxed(MathEvaluator.last_boxed_only_string(answer)): 561 | return True 562 | return False 563 | except Exception: 564 | return False 565 | 566 | @staticmethod 567 | def get_answer_expr(answer: str) -> str: 568 | """ 569 | Extracts the mathematical expression from the answer. 570 | 571 | Args: 572 | answer (str): The answer string. 573 | 574 | Returns: 575 | str: Extracted expression. 
576 | """ 577 | try: 578 | answer = MathEvaluator.remove_boxed(MathEvaluator.last_boxed_only_string(answer)) 579 | except Exception: 580 | answer = answer.split("\n")[-1] 581 | return answer 582 | 583 | @staticmethod 584 | def extract_boxed_expressions(string: str) -> List[str]: 585 | """ 586 | extracts all \boxed{...} and \boxed ... expressions from the string. 587 | """ 588 | boxed_expressions = [] 589 | 590 | pattern_braces = r"\\boxed\s*\{([^}]*)\}" 591 | boxed_expressions += re.findall(pattern_braces, string) 592 | 593 | pattern_space = r"\\boxed\s+([^\s\$]+)" 594 | boxed_expressions += re.findall(pattern_space, string) 595 | 596 | return ["\\boxed{" + expr + "}" for expr in boxed_expressions] 597 | 598 | @staticmethod 599 | def remove_boxed(s: str) -> Optional[str]: 600 | """ 601 | Removes the \boxed or \fbox formatting from a string. 602 | 603 | Args: 604 | s (str): The input string. 605 | 606 | Returns: 607 | Optional[str]: String without boxed formatting or None. 608 | """ 609 | # pattern = r"\\boxed\s*{([^}]*)}" 610 | # return re.sub(pattern, r"\1", s, flags=re.DOTALL) 611 | 612 | if "\\boxed " in s: 613 | left = "\\boxed " 614 | assert s[: len(left)] == left 615 | return s[len(left) :] 616 | elif "\\boxed{" in s: 617 | left = "\\boxed{" 618 | assert s[: len(left)] == left 619 | assert s[-1] == "}" 620 | return s[len(left) : -1] 621 | elif "\\fbox{" in s: 622 | left = "\\fbox{" 623 | assert s[: len(left)] == left 624 | assert s[-1] == "}" 625 | return s[len(left) : -1] 626 | else: 627 | return s 628 | 629 | @staticmethod 630 | def last_boxed_only_string(string: str) -> Optional[str]: 631 | """ 632 | Extracts the last boxed expression from a string. 633 | 634 | Args: 635 | string (str): The input string. 636 | 637 | Returns: 638 | Optional[str]: The last boxed expression or None. 
639 | """ 640 | idx = string.rfind("\\boxed") 641 | if "\\boxed " in string: 642 | return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] 643 | if idx < 0: 644 | idx = string.rfind("\\fbox") 645 | if idx < 0: 646 | return None 647 | 648 | i = idx 649 | right_brace_idx = None 650 | num_left_braces_open = 0 651 | while i < len(string): 652 | if string[i] == "{": 653 | num_left_braces_open += 1 654 | if string[i] == "}": 655 | num_left_braces_open -= 1 656 | if num_left_braces_open == 0: 657 | right_brace_idx = i 658 | break 659 | i += 1 660 | 661 | return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None 662 | 663 | @staticmethod 664 | def all_boxed_strings(string: str) -> List[str]: 665 | """ 666 | Extracts all boxed expressions from a string in order of appearance. 667 | 668 | Args: 669 | string (str): The input string. 670 | 671 | Returns: 672 | List[str]: List of all boxed expressions in order of appearance. 673 | """ 674 | results = [] 675 | i = 0 676 | 677 | while i < len(string): 678 | # Find next occurrence of either \boxed or \fbox 679 | boxed_idx = string.find("\\boxed", i) 680 | fbox_idx = string.find("\\fbox", i) 681 | 682 | # Determine which comes first 683 | if boxed_idx == -1 and fbox_idx == -1: 684 | break 685 | elif boxed_idx == -1: 686 | idx = fbox_idx 687 | is_boxed = False 688 | elif fbox_idx == -1: 689 | idx = boxed_idx 690 | is_boxed = True 691 | else: 692 | if boxed_idx < fbox_idx: 693 | idx = boxed_idx 694 | is_boxed = True 695 | else: 696 | idx = fbox_idx 697 | is_boxed = False 698 | 699 | if is_boxed and idx + 6 < len(string) and string[idx:idx+6] == "\\boxed ": 700 | # Handle \boxed space case 701 | expr = "\\boxed " + string[idx+6:].split("$")[0].split()[0] 702 | results.append(expr) 703 | i = idx + len(expr) 704 | else: 705 | # Handle \boxed{...} or \fbox{...} case 706 | j = idx 707 | right_brace_idx = None 708 | num_left_braces_open = 0 709 | while j < len(string): 710 | if string[j] == "{": 711 | 
num_left_braces_open += 1 712 | if string[j] == "}": 713 | num_left_braces_open -= 1 714 | if num_left_braces_open == 0: 715 | right_brace_idx = j 716 | break 717 | j += 1 718 | 719 | if right_brace_idx is not None: 720 | results.append(string[idx:right_brace_idx + 1]) 721 | i = right_brace_idx + 1 722 | else: 723 | i = idx + 1 724 | 725 | return results 726 | 727 | @staticmethod 728 | def is_equiv(str1: Optional[str], str2: Optional[str], verbose: bool = False) -> bool: 729 | """ 730 | Checks if two strings are equivalent after normalization. 731 | 732 | Args: 733 | str1 (Optional[str]): First string. 734 | str2 (Optional[str]): Second string. 735 | verbose (bool): If True, prints the normalized strings. 736 | 737 | Returns: 738 | bool: True if equivalent, False otherwise. 739 | """ 740 | if str1 is None and str2 is None: 741 | print("WARNING: Both None") 742 | return True 743 | if str1 is None or str2 is None: 744 | return False 745 | 746 | try: 747 | ss1 = MathEvaluator.strip_string(str1) 748 | ss2 = MathEvaluator.strip_string(str2) 749 | if verbose: 750 | print(ss1, ss2) 751 | return ss1 == ss2 752 | except Exception: 753 | return str1 == str2 754 | 755 | @staticmethod 756 | def sympy_match(str1: str, str2: str) -> bool: 757 | """ 758 | Checks if two mathematical expressions are equivalent using SymPy. 759 | Times out after 3 seconds to prevent hanging on complex expressions. 760 | 761 | Args: 762 | str1 (str): First expression. 763 | str2 (str): Second expression. 764 | 765 | Returns: 766 | bool: True if equivalent, False otherwise. 
767 | """ 768 | def timeout_handler(signum, frame): 769 | raise TimeoutError("Parsing took too long") 770 | 771 | try: 772 | # Set timeout of 3 seconds 773 | signal.signal(signal.SIGALRM, timeout_handler) 774 | signal.alarm(3) 775 | 776 | expr1 = parse_latex(str1) 777 | expr2 = parse_latex(str2) 778 | diff = sympy.simplify(expr1 - expr2) 779 | 780 | # Disable alarm 781 | signal.alarm(0) 782 | return diff == 0 783 | except (Exception, TimeoutError): 784 | # Disable alarm 785 | signal.alarm(0) 786 | # print(f"Error in sympy_match. str1: {str1} ||| str2: {str2}") 787 | return False 788 | 789 | @staticmethod 790 | def strip_string(string: str) -> str: 791 | """ 792 | Normalizes a LaTeX string by removing unnecessary characters and formatting. 793 | 794 | Args: 795 | string (str): The input string. 796 | 797 | Returns: 798 | str: Normalized string. 799 | """ 800 | string = string.replace("\n", "") 801 | string = string.replace("\\!", "") 802 | string = string.replace("\\\\", "\\") 803 | string = string.replace("tfrac", "frac") 804 | string = string.replace("dfrac", "frac") 805 | string = string.replace("\\left", "") 806 | string = string.replace("\\right", "") 807 | string = string.replace("^{\\circ}", "") 808 | string = string.replace("^\\circ", "") 809 | string = string.replace("\\$", "") 810 | string = MathEvaluator.remove_right_units(string) 811 | string = string.replace("\\%", "") 812 | string = string.replace(r"\%", "") 813 | string = string.replace(" .", " 0.") 814 | string = string.replace("{.", "{0.") 815 | 816 | if len(string) == 0: 817 | return string 818 | if string[0] == ".": 819 | string = "0" + string 820 | 821 | if len(string.split("=")) == 2: 822 | if len(string.split("=")[0]) <= 2: 823 | string = string.split("=")[1] 824 | 825 | string = MathEvaluator.fix_sqrt(string) 826 | string = string.replace(" ", "") 827 | string = MathEvaluator.fix_fracs(string) 828 | 829 | if string == "0.5": 830 | string = "\\frac{1}{2}" 831 | 832 | string = 
MathEvaluator.fix_a_slash_b(string) 833 | 834 | return string 835 | 836 | @staticmethod 837 | def fix_fracs(string: str) -> str: 838 | """ 839 | Fixes improperly formatted fractions in a LaTeX string. 840 | 841 | Args: 842 | string (str): The input string. 843 | 844 | Returns: 845 | str: String with fixed fractions. 846 | """ 847 | substrs = string.split("\\frac") 848 | new_str = substrs[0] 849 | if len(substrs) > 1: 850 | substrs = substrs[1:] 851 | for substr in substrs: 852 | new_str += "\\frac" 853 | if substr.startswith("{"): 854 | new_str += substr 855 | else: 856 | if len(substr) < 2: 857 | return string 858 | a = substr[0] 859 | b = substr[1] 860 | if b != "{": 861 | if len(substr) > 2: 862 | post_substr = substr[2:] 863 | new_str += f"{{{a}}}{{{b}}}{post_substr}" 864 | else: 865 | new_str += f"{{{a}}}{{{b}}}" 866 | else: 867 | if len(substr) > 2: 868 | post_substr = substr[2:] 869 | new_str += f"{{{a}}}{{{b}}}{post_substr}" 870 | else: 871 | new_str += f"{{{a}}}{{{b}}}" 872 | return new_str 873 | 874 | @staticmethod 875 | def fix_a_slash_b(string: str) -> str: 876 | """ 877 | Converts a simple a/b format to LaTeX fraction if applicable. 878 | 879 | Args: 880 | string (str): The input string. 881 | 882 | Returns: 883 | str: Modified string with fractions fixed. 884 | """ 885 | parts = string.split("/") 886 | if len(parts) != 2: 887 | return string 888 | a, b = parts 889 | try: 890 | a = int(a) 891 | b = int(b) 892 | if string == f"{a}/{b}": 893 | return f"\\frac{{{a}}}{{{b}}}" 894 | else: 895 | return string 896 | except ValueError: 897 | return string 898 | 899 | @staticmethod 900 | def remove_right_units(string: str) -> str: 901 | """ 902 | Removes units described within \\text{ } at the end of the string. 903 | 904 | Args: 905 | string (str): The input string. 906 | 907 | Returns: 908 | str: String without units. 
909 | """ 910 | if "\\text{ " in string: 911 | splits = string.split("\\text{ ") 912 | if len(splits) == 2: 913 | return splits[0] 914 | return string 915 | 916 | @staticmethod 917 | def fix_sqrt(string: str) -> str: 918 | """ 919 | Ensures that square roots in the string are properly formatted with braces. 920 | 921 | Args: 922 | string (str): The input string. 923 | 924 | Returns: 925 | str: String with fixed square roots. 926 | """ 927 | if "\\sqrt" not in string: 928 | return string 929 | splits = string.split("\\sqrt") 930 | new_string = splits[0] 931 | for split in splits[1:]: 932 | if not split.startswith("{"): 933 | if len(split) < 1: 934 | return string 935 | a = split[0] 936 | new_substr = f"\\sqrt{{{a}}}" + split[1:] 937 | else: 938 | new_substr = "\\sqrt" + split 939 | new_string += new_substr 940 | return new_string -------------------------------------------------------------------------------- /signals/rollouts_based_signals/sample_from_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | from datasets import Dataset, load_dataset 4 | from enum import Enum 5 | import json 6 | import os 7 | from pydantic import BaseModel, Field, field_validator 8 | import tenacity 9 | import traceback 10 | from tqdm import tqdm 11 | 12 | from utils.sglang_util import SGLangServerManager 13 | 14 | class ModelType(str, Enum): 15 | """supported llm model types""" 16 | Llama3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct" 17 | Llama3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct" 18 | Llama3_1_405B = "meta-llama/Llama-3.1-405B-Instruct-FP8" 19 | CLAUDE_3_5 = "claude-3-5-sonnet-latest" 20 | O1_PREVIEW = "o1-preview" 21 | O1_MINI = "o1-mini" 22 | 23 | class LLMHyperparams(BaseModel): 24 | system_prompt: str = """You are a math expert. Given the following math problem, provide your solution in Latex format. 
Always format your final answer in perfect LaTeX \\boxed{{final_answer}} format.""" 25 | prompt: str = "{problem}" 26 | temperature: float = Field( 27 | default=0.8, 28 | ge=0.0, 29 | le=2.0, 30 | description='Float that controls the randomness of the sampling. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling.' 31 | ) 32 | top_k: int = Field( 33 | default=-1, 34 | description='Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.' 35 | ) 36 | top_p: float = Field( 37 | default=0.95, 38 | ge=0.0, 39 | le=1.0, 40 | description='Float that controls the cumulative probability of the top tokens to consider.' 41 | ) 42 | max_tokens: int = 2048 43 | model_name: ModelType = Field( 44 | description='The model to use for generation.', 45 | ) 46 | n: int = Field( 47 | description='Number of samples to generate.' 48 | ) 49 | 50 | @field_validator('temperature') 51 | def validate_temperature(cls, v): 52 | if v < 0: 53 | raise ValueError("temp can't be negative") 54 | return v 55 | 56 | def to_json(self) -> str: 57 | return json.dumps(self.model_dump(), indent=2) 58 | 59 | @classmethod 60 | def load_json(cls, json_str: str) -> 'LLMHyperparams': 61 | data = json.loads(json_str) 62 | return cls(**data) 63 | 64 | def construct_few_shot_prompt(problem): 65 | # Updated prompt as per the user's instructions 66 | system_prompt = r"""From this moment forward, adopt the persona of a tenured mathematics professor. You receive math problems and you solve them, step by step, following the formatting instructions below. 
67 | 68 | # Mathematical Solution Formatting Guidelines 69 | 70 | ## Mathematical Notation Rules 71 | 72 | **Use $$ $$ for:** 73 | - Standalone equations 74 | - Complex mathematical expressions 75 | - Multi-line equations 76 | - Final answers 77 | 78 | **Use single $ $ for:** 79 | - Simple inline expressions 80 | - Variables 81 | - Numbers with mathematical meaning 82 | - Parts of expressions being discussed 83 | 84 | **For equations with multiple lines:** 85 | `latex 86 | $$\\begin{aligned} 87 | equation1 &= expression1 \\\\ 88 | equation2 &= expression2 89 | \\end{aligned}$$ 90 | ` 91 | 92 | **Place the final answer inside \boxed{answer_here}** 93 | After solving and verifying your full solution, write the final answer in \\boxed{answer} notation. 94 | 95 | ## Thought Structure 96 | Each thought must: 97 | 1. Begin with active voice ("I" statements or direct observations) 98 | 2. Express exactly ONE logical step 99 | 3. Integrate mathematical notation naturally within sentences 100 | 4. Use appropriate notation based on context 101 | 5. No more than 1 sentence + interleaved math notation long 102 | 103 | ### Example: 104 | "I can see that when $x = 2$, the expression $$y = x^2 + 3x + 1$$ becomes $$y = 4 + 6 + 1 = 11$$." 105 | 106 | ## Common Mathematical Patterns 107 | - Use $$ $$ for equations being solved 108 | - Use $ $ for discussing components: "where $m$ is the slope" 109 | - Keep expressions inline when discussing specific parts 110 | - Use block notation ($$) for key steps and results 111 | 112 | ## Visual Formatting 113 | - Two blank lines between thoughts (\n\n) 114 | - No bullet points or numbered lists 115 | - No section headers within the solution 116 | - Mathematical expressions should flow naturally within sentences 117 | - No explicit step labeling 118 | - Human-style reasoning 119 | 120 | --- 121 | 122 | ## Examples 123 | 124 | ### Problem 1: 125 | 126 | A standard die is rolled six times. 
What is the probability that the product of all six rolls is odd? Express your answer as a common fraction. 127 | 128 | ### Solution 1: 129 | 130 | I need to find the probability by dividing favorable outcomes by total outcomes. 131 | 132 | I know that the total number of possible outcomes when rolling a die six times is $$6^6$$. 133 | 134 | For the product to be odd, each individual roll must be odd. 135 | 136 | Looking at a standard die, the odd numbers are 1, 3, and 5. 137 | 138 | Therefore, for each roll, I have only 3 choices to maintain an odd product. 139 | 140 | The total number of favorable outcomes is $$3^6$$ since I have 3 choices for each of the 6 rolls. 141 | 142 | The probability is thus $$\frac{3^6}{6^6}$$. 143 | 144 | This simplifies to $$\left( \frac{1}{2} \right)^6 = \\boxed{\frac{1}{64}}$$. 145 | 146 | ### Problem 2: 147 | 148 | Solve over the integers: $$2^a+4^b+8^c=328$$ 149 | 150 | ### Solution 2: 151 | 152 | I notice that all terms are powers of 2, so I can rewrite everything with base 2. 153 | 154 | I know that $$4 = 2^2$$ and $$8 = 2^3$$, so the equation becomes $$2^a + 2^{2b} + 2^{3c} = 328$$. 155 | 156 | To solve this, I can decompose 328 into powers of 2. 157 | 158 | The largest power of 2 less than 328 is 256 ($$2^8$$). 159 | 160 | Subtracting 256 from 328 leaves 72. 161 | 162 | The largest power of 2 less than 72 is 64 ($$2^6$$). 163 | 164 | Subtracting 64 leaves 8 ($$2^3$$). 165 | 166 | Therefore, $$328 = 2^8 + 2^6 + 2^3$$. 167 | 168 | Comparing terms: $$2^a = 2^8$$, $$2^{2b} = 2^6$$, and $$2^{3c} = 2^3$$. 169 | 170 | Solving these equations: $$a = 8$$, $$b = 3$$, and $$c = 1$$. 171 | 172 | My final answer is \\boxed{(a, b, c) = (8, 3, 1)} 173 | 174 | ### Problem 3: 175 | 176 | Find the equation of the circle which passes through $$(2, 3)$$ and $$(4, 5)$$ and whose center lies on the straight line $$y - 4x + 3 = 0$$. 
177 | 178 | ### Solution 3: 179 | 180 | I know that the center lies on the line, so I can write its coordinates as $$(x, 4x - 3)$$. 181 | 182 | The radius can be found using the distance from the center to $$(2, 3)$$. 183 | 184 | Using the distance formula, I get $$r^2 = (x - 2)^2 + (4x - 3 - 3)^2 = (x - 2)^2 + (4x - 6)^2$$. 185 | 186 | Expanding and simplifying: $$r^2 = 17x^2 - 52x + 40$$. 187 | 188 | Since the circle also passes through $$(4, 5)$$, I can write $$r^2 = (x - 4)^2 + (4x - 3 - 5)^2$$. 189 | 190 | This simplifies to $$r^2 = 17x^2 - 72x + 80$$. 191 | 192 | Since both expressions equal $$r^2$$, I can write $$17x^2 - 52x + 40 = 17x^2 - 72x + 80$$. 193 | 194 | I can now simplify: $$20x = 40$$, so $$x = 2$$. 195 | 196 | The center is therefore $$(2, 4(2) - 3) = (2, 5)$$. 197 | 198 | The radius squared is $$r^2 = 17(2)^2 - 52(2) + 40 = 16$$, so $$r = 4$$. 199 | 200 | The equation of the circle is $$(x - 2)^2 + (y - 5)^2 = 16$$. 201 | 202 | The final answer is the full equation of the circle, so the correct final answer is \\boxed{(x - 2)^2 + (y - 5)^2 = 16} 203 | 204 | """ 205 | 206 | # Append the real example as a user/assistant message 207 | user_problem = r"Given real numbers \( a, b, c \) and a positive number \( \lambda \) such that the polynomial \( f(x) = x^3 + a x^2 + b x + c \) has three real roots \( x_1, x_2, x_3 \), and the conditions \( x_2 - x_1 = \lambda \) and \( x_3 > \frac{1}{2}(x_1 + x_2) \) are satisfied, find the maximum value of \( \frac{2 a^3 + 27 c - 9 a b}{\lambda^3} \)." 208 | 209 | assistant_solution = r""" 210 | I need to find the maximum value of the given expression, which involves the coefficients of the polynomial and the difference between its roots. 211 | 212 | Since \( x_1, x_2, x_3 \) are the roots of the polynomial \( f(x) = x^3 + a x^2 + b x + c \), by Vieta's formulas, I know that $$x_1 + x_2 + x_3 = -a$$. 213 | 214 | Also from Vieta's formulas, $$x_1 x_2 + x_2 x_3 + x_3 x_1 = b$$ and $$x_1 x_2 x_3 = -c$$. 
215 | 216 | Given that \( x_2 - x_1 = \lambda \), I can express \( x_2 \) as \( x_2 = x_1 + \lambda \). 217 | 218 | The condition \( x_3 > \frac{1}{2}(x_1 + x_2) \) implies that \( x_3 > x_1 + \frac{\lambda}{2} \). 219 | 220 | However, to simplify the problem, I will consider the equality case, which should give me the maximum value. 221 | 222 | Thus, I set \( x_3 = x_1 + \frac{\lambda}{2} \). 223 | 224 | Now, I substitute \( x_2 \) and \( x_3 \) in terms of \( x_1 \) and \( \lambda \) into the expression for \( a \): $$-a = x_1 + (x_1 + \lambda) + \left( x_1 + \frac{\lambda}{2} \right ).$$ 225 | 226 | This simplifies to $$-a = 3x_1 + \frac{3\lambda}{2}.$$ 227 | 228 | I can also express \( b \) and \( c \) in terms of \( x_1 \) and \( \lambda \): 229 | $$b = x_1(x_1 + \lambda) + (x_1 + \lambda)\left( x_1 + \frac{\lambda}{2} \right ) + \left( x_1 + \frac{\lambda}{2} \right ) x_1,$$ 230 | and 231 | $$-c = x_1(x_1 + \lambda)\left( x_1 + \frac{\lambda}{2} \right ).$$ 232 | 233 | Substituting these expressions into the given expression, I get $$\frac{2 a^3 + 27 c - 9 a b}{\lambda^3}$$ in terms of \( x_1 \) and \( \lambda \). 234 | 235 | However, since I want the maximum value of this expression, which is independent of \( x_1 \), I can choose a specific value of \( x_1 \) to simplify the calculation. 236 | 237 | Let's set \( x_1 = 0 \) to get the maximum value. 
238 | 239 | Substituting \( x_1 = 0 \) into the expressions for \( a \), \( b \), and \( c \), I get 240 | $$-a = \frac{3\lambda}{2},$$ 241 | $$b = \frac{\lambda^2}{2},$$ 242 | and 243 | $$-c = 0.$$ 244 | 245 | Now, I substitute these values into the given expression: 246 | $$\frac{2 a^3 + 27 c - 9 a b}{\lambda^3} = \frac{2\left( -\frac{3\lambda}{2} \right )^3 + 27(0) - 9\left( -\frac{3\lambda}{2} \right )\left( \frac{\lambda^2}{2} \right )}{\lambda^3}.$$ 247 | 248 | Simplifying the numerator: 249 | $$2\left( -\frac{27\lambda^3}{8} \right ) + \frac{27\lambda^3}{4} = -\frac{27\lambda^3}{4} + \frac{27\lambda^3}{4} = 0.$$ 250 | 251 | Therefore, the maximum value of the expression is $$\frac{0}{\lambda^3} = 0.$$ 252 | 253 | \\boxed{0} 254 | """ 255 | 256 | # Prepare the messages 257 | messages = [ 258 | {"role": "system", "content": system_prompt}, 259 | {"role": "user", "content": user_problem}, 260 | {"role": "assistant", "content": assistant_solution}, 261 | {"role": "user", "content": f"Problem: {problem}"}, 262 | ] 263 | return messages 264 | 265 | def format_items(problem, system_prompt): 266 | return [ 267 | {"role": "system", "content": system_prompt}, 268 | {"role": "user", "content": problem} 269 | ] 270 | 271 | @tenacity.retry(stop=tenacity.stop_after_attempt(5), wait=tenacity.wait_exponential(multiplier=1, min=4, max=10)) 272 | async def get_chat_responses(server_handler, batch_items, n, top_p, temperature, max_tokens): 273 | return await server_handler.get_chat_responses( 274 | batch_items, 275 | n=n, 276 | top_p=top_p, 277 | temperature=temperature, 278 | max_tokens=max_tokens 279 | ) 280 | 281 | async def main(): 282 | # Configuration 283 | model_name = args.model_name 284 | dataset_name = args.dataset_name 285 | save_name = args.save_name 286 | save_folder = args.save_folder 287 | problem_column_name = args.problem_column_name 288 | response_column_name=args.response_column_name 289 | if args.dataset_name_outputs: 290 | hf_save_dataset_name = 
args.dataset_name_outputs 291 | else: 292 | hf_save_dataset_name = dataset_name 293 | os.makedirs(save_folder, exist_ok=True) 294 | 295 | batch_size = args.batch_size 296 | save_interval = args.save_interval 297 | 298 | # Sampling hyperparameters 299 | if args.greedy: 300 | # greedy decoding 301 | llm_params = LLMHyperparams( 302 | temperature=0.0, 303 | model_name=ModelType(model_name), 304 | n=1, 305 | max_tokens=args.max_tokens 306 | ) 307 | else: 308 | llm_params = LLMHyperparams( 309 | top_k=args.top_k, 310 | top_p=args.top_p, 311 | temperature=args.temperature, 312 | model_name=ModelType(model_name), 313 | n=args.n, 314 | max_tokens=args.max_tokens 315 | ) 316 | 317 | print("Hyperparameters:") 318 | print(llm_params.to_json()) 319 | 320 | # Load and preprocess dataset 321 | ds = load_dataset(dataset_name, split=args.dataset_split) 322 | df = ds.to_pandas() 323 | 324 | # if only using partial dataset, slice it 325 | if args.end == -1: 326 | args.end = len(df) 327 | df = df.iloc[args.start:args.end] 328 | df = df.reset_index(drop=True) 329 | 330 | print(f"Total dataset: {len(df)}") 331 | 332 | # Prepare items for processing 333 | if args.generate_training_data: 334 | items = [construct_few_shot_prompt(row[problem_column_name]) for _, row in df.iterrows()] 335 | else: 336 | items = [format_items(row[problem_column_name], llm_params.system_prompt) for _, row in df.iterrows()] 337 | 338 | df[response_column_name] = None 339 | # Process items in batches 340 | total_items = len(items) 341 | count = 0 342 | with SGLangServerManager(model_name, tp=args.tp) as server_handler: 343 | for idx in tqdm(range(0, total_items, batch_size)): 344 | batch_items = items[idx:idx+batch_size] 345 | batch_outputs = await get_chat_responses( 346 | server_handler=server_handler, 347 | batch_items=batch_items, 348 | n=llm_params.n, 349 | top_p=llm_params.top_p, 350 | temperature=llm_params.temperature, 351 | max_tokens=llm_params.max_tokens 352 | ) 353 | batch_responses = [] 354 | for 
resp in batch_outputs: 355 | try: 356 | batch_responses.append(resp[-1]["responses"]) 357 | except Exception as e: 358 | print(f"Response: {resp}") 359 | traceback_str = traceback.format_exc() 360 | print(f"Error processing response: {traceback_str}") 361 | batch_responses.append([""]) 362 | 363 | count += 1 364 | # Assign responses to dataframe 365 | for i, response_list in enumerate(batch_responses): 366 | df.at[idx+i, response_column_name] = response_list 367 | if count % save_interval == 0: 368 | try: 369 | df.iloc[:idx+batch_size].to_parquet( 370 | os.path.join(save_folder, f"{save_name}_{count}_batch.parquet") 371 | ) 372 | ds = Dataset.from_pandas(df) 373 | ds.push_to_hub(hf_save_dataset_name, private=True) 374 | except Exception as e: 375 | print(f"Error saving checkpoint: {e}") 376 | # Save final results 377 | try: 378 | df.to_parquet(os.path.join(save_folder, f"{save_name}.parquet")) 379 | ds = Dataset.from_pandas(df) 380 | ds.push_to_hub(hf_save_dataset_name, private=True) 381 | print(f"Saved to {os.path.join(save_folder, f'{save_name}.parquet')}") 382 | except Exception as e: 383 | print(f"Error saving final results: {e}") 384 | 385 | if __name__ == "__main__": 386 | parser = argparse.ArgumentParser() 387 | # Model and dataset configuration 388 | parser.add_argument("--model_name", type=str, required=True) 389 | parser.add_argument("--dataset_name", type=str, required=True) 390 | parser.add_argument("--dataset_name_outputs", type=str, 391 | help="To save the outputs to a different HF dataset, specify here.") 392 | parser.add_argument("--dataset_split", type=str, default="train") 393 | parser.add_argument("--problem_column_name", type=str, default="problem") 394 | parser.add_argument("--response_column_name", type=str, required=True) 395 | parser.add_argument("--generate_training_data", action="store_true", 396 | help="Use a few-shot prompt to encourage the model to follow a specific format.") 397 | 398 | # Save configuration 399 | 
parser.add_argument("--save_folder", type=str, required=True) 400 | parser.add_argument("--save_name", type=str, required=True) 401 | parser.add_argument("--save_interval", type=int, default=10000, 402 | help="Save every n batches.") 403 | 404 | # SGLang server configuration 405 | parser.add_argument("--tp", type=int, default=1) 406 | parser.add_argument("--batch_size", type=int, default=250, 407 | help="Total batch size will be args.batch_size * args.n.") 408 | 409 | # LLM Hyperparameters 410 | parser.add_argument("--top_k", type=int, default=-1) 411 | parser.add_argument("--top_p", type=float, default=0.95) 412 | parser.add_argument("--temperature", type=float, default=0.8) 413 | parser.add_argument("--max_tokens", type=int, default=4096) 414 | parser.add_argument("--n", type=int, default=64) 415 | parser.add_argument("--greedy", action="store_true", help="Use greedy decoding. Ignores temperature, top_p, and top_k.") 416 | 417 | # dataset slicing 418 | parser.add_argument("--start", type=int, default=0) 419 | parser.add_argument("--end", type=int, default=-1) 420 | args = parser.parse_args() 421 | 422 | asyncio.run(main()) 423 | 424 | # Example Usage: 425 | # python3 sample_from_model.py --model_name meta-llama/Llama-3.1-405B-Instruct-FP8 --dataset_name RLAIF/Big-Math-Competition-Problems --save_folder math_competition_problems --save_name llama405b_greedy --response_column_name llama_405b_greedy_response --greedy --tp 8 426 | 427 | # python3 sample_from_model.py --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --dataset_name RLAIF/Big-Math-Competition-Problems --save_folder math_competition_problems --save_name llama8b_greedy --response_column_name llama_8b_greedy_response --greedy -------------------------------------------------------------------------------- /signals/rollouts_based_signals/utils/__init__.py: -------------------------------------------------------------------------------- 
import asyncio
import os
import signal

import openai
from tqdm.asyncio import tqdm_asyncio


class OpenAIServerManager:
    """Context manager owning a pool of locally launched OpenAI-compatible
    inference servers, with one async client and one concurrency semaphore
    per server port.

    Subclasses must implement :meth:`launch_servers`; this base class handles
    client construction, request fan-out, and process-group teardown.
    """

    def __init__(
        self, model_name, start_port=2341, tp=1, max_time=600, trust_remote_code=False,
    ):
        # model_name: model path/name forwarded to the server launcher.
        # start_port: first port to bind servers on.
        # tp: tensor-parallelism degree per server.
        # max_time: seconds to wait for servers to become ready.
        # trust_remote_code: forwarded to the server launcher.
        self.model_name = model_name
        self.start_port = start_port
        self.tp = tp
        self.max_time = max_time
        self.trust_remote_code = trust_remote_code

    def __enter__(self):
        # Launch the backing processes, then build one async client and one
        # semaphore (64 in-flight requests) per port.
        self.ports, self.subprocs = self.launch_servers(
            self.model_name, self.start_port, self.tp, self.max_time,
        )
        self.async_clients = [
            openai.AsyncClient(base_url=f"http://localhost:{port}/v1", api_key="EMPTY")
            for port in self.ports
        ]
        self.sems = [asyncio.Semaphore(64) for _ in self.ports]
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Kill each server's entire process group so worker children die too.
        for proc in self.subprocs:
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)

    def launch_servers(
        self, model_name, start_port=8001, tp=1, max_time=600
    ):
        # Subclass hook: must return (ports, subprocesses).
        raise NotImplementedError

    async def get_chat_responses(self, chats, **kwargs) -> list[openai.ChatCompletion]:
        """
        Get responses from the sglang server with retry functionality
        :param chats: list of chats
        :return: list of chat completions
        """
        responses = list()

        async def response_wrapper(client: openai.AsyncClient, sem, **kwargs):
            async with sem:
                try:
                    out = await client.chat.completions.create(**kwargs)
                    completions = [choice.message.content for choice in out.choices]
                    messages = kwargs["messages"]
                    # Attach all sampled completions onto the final message in place.
                    messages[-1]["responses"] = completions
                    return messages
                except Exception as e:
                    # Best-effort: degrade to a single empty response so one bad
                    # request cannot fail the whole gather.
                    print(f"Error in response_wrapper: {str(e)}")
                    default_message = ""
                    messages = kwargs["messages"]
                    messages[-1]["responses"] = [default_message]
                    return messages

        for i, chat in enumerate(chats):
            curr_kwargs = kwargs.copy()
            curr_kwargs["model"] = "default"
            curr_kwargs["messages"] = chat
            if "max_tokens" not in curr_kwargs:
                curr_kwargs["max_tokens"] = (
                    2048  # They can go realllyy reallly long if you just yolo it.
                )
            responses.append(
                response_wrapper(
                    self.async_clients[i % len(self.ports)],
                    self.sems[i % len(self.ports)],
                    **curr_kwargs,
                )
            )
        return await tqdm_asyncio.gather(*responses)

    async def get_completion_responses(
        self, prompts, **kwargs
    ) -> list[openai.Completion]:
        """
        Get responses from the sglang server
        :param prompts: list of prompts
        :return:
        """
        responses = list()

        async def response_wrapper(client: openai.AsyncClient, sem, **kwargs):
            async with sem:
                try:
                    completions = await client.completions.create(**kwargs)
                    completions = [choice.text for choice in completions.choices]
                    kwargs["responses"] = completions
                    return kwargs
                except Exception as e:
                    # Consistency fix: mirror get_chat_responses and degrade to
                    # one empty completion instead of aborting the whole gather.
                    print(f"Error in response_wrapper: {str(e)}")
                    kwargs["responses"] = [""]
                    return kwargs

        for i, prompt in enumerate(prompts):
            curr_kwargs = kwargs.copy()
            curr_kwargs["model"] = "default"
            curr_kwargs["prompt"] = prompt
            if "max_tokens" not in curr_kwargs:
                curr_kwargs["max_tokens"] = (
                    2048  # They can go realllyy reallly long if you just yolo it.
                )
            responses.append(
                response_wrapper(
                    self.async_clients[i % len(self.ports)],
                    self.sems[i % len(self.ports)],
                    # BUG FIX: was **kwargs, which silently dropped the
                    # per-prompt "model"/"prompt"/"max_tokens" built above.
                    **curr_kwargs,
                )
            )
        return await tqdm_asyncio.gather(*responses)
import atexit
import os
import signal
import subprocess
import time

import openai

from .openai_server import OpenAIServerManager


def get_sglang_response(port):
    """
    tries to get a response from the sglang server
    :return:
    """
    client = openai.Client(base_url=f"http://localhost:{port}/v1", api_key="EMPTY")

    # One-token probe completion: raises a connection/timeout error from the
    # openai client if the server on this port is not up yet.
    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=1,
    )
    print(response)


def kill_process_group(pid):
    """Kill the entire process group for the given PID."""
    try:
        os.killpg(os.getpgid(pid), signal.SIGTERM)
    except ProcessLookupError:
        pass  # Process already terminated


class SGLangServerManager(OpenAIServerManager):
    def launch_servers(self, model_name, start_port=1234, tp=1, max_time=600):
        """
        Launches an sglang server on all available devices.

        Args:
            model_name (str): Path to the model.
            start_port (int): Port to start on.
            tp (int): Tensor parallelism.
            max_time (int): Maximum time (in seconds) to wait for the server to become ready.

        Returns:
            tuple: (ports, subprocesses) where ports is a list of ports and subprocesses is a list of Popen objects.
        """
        subprocesses = []
        # Get list of devices from env var (defaulting to 0-7 if not set)
        devices = os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(",")
        dp = len(devices) // tp

        # A single server process is launched on start_port with data
        # parallelism; the base class still expects a *list* of ports (one
        # client/semaphore each), so repeat the same port. The count,
        # ceil(len(devices) / tp), matches the original range()-based
        # expression, which repeated start_port and never used its loop var.
        num_clients = (len(devices) + tp - 1) // tp
        ports = [start_port] * num_clients

        # Build the command as a list to avoid using the shell
        cmd = [
            "python",
            "-m",
            "sglang.launch_server",
            "--model-path", model_name,
            "--port", str(start_port),
            "--tp", str(tp),
            "--dp", str(dp),
            "--log-level", "error",
        ]
        if self.trust_remote_code:
            cmd.append("--trust-remote-code")

        # Launch the server process in its own process group.
        process = subprocess.Popen(cmd, start_new_session=True)
        subprocesses.append(process)

        # Ensure that the child process group is killed when the parent exits.
        atexit.register(kill_process_group, process.pid)

        # Also install signal handlers for SIGINT and SIGTERM for the duration
        # of the wait loop; the originals are restored before returning.
        def _signal_handler(sig, frame):
            kill_process_group(process.pid)
            raise KeyboardInterrupt

        original_sigint = signal.getsignal(signal.SIGINT)
        original_sigterm = signal.getsignal(signal.SIGTERM)
        signal.signal(signal.SIGINT, _signal_handler)
        signal.signal(signal.SIGTERM, _signal_handler)

        # Wait until at least one port is responsive or timeout is reached.
        start_time = time.monotonic()
        ports_working = []

        while time.monotonic() - start_time < max_time:
            for port in ports:
                if port in ports_working:
                    continue
                try:
                    get_sglang_response(port)
                    ports_working.append(port)
                except (openai.APITimeoutError, openai.APIConnectionError):
                    print(f"Port {port} not ready yet.")
            if ports_working:
                break
            time.sleep(1)  # shorter sleep interval for faster feedback
        else:
            # while/else: only reached when the loop times out without break.
            # Ensure cleanup and then raise.
            kill_process_group(process.pid)
            raise TimeoutError("Server did not become ready within the allotted time.")

        # Restore original signal handlers.
        signal.signal(signal.SIGINT, original_sigint)
        signal.signal(signal.SIGTERM, original_sigterm)

        return ports, subprocesses
import pandas as pd
import torch


def concat_columns(df: pd.DataFrame, columns: list[str]):
    """Concatenate the columns into a new column.

    Args:
        df (pd.DataFrame): The dataframe.
        columns (List[str]): The columns to combine.

    Returns:
        pd.Series: Row-wise, space-joined concatenation of ``columns``.

    Raises:
        ValueError: If any requested column is not of dtype ``object``.
    """
    # validate that the columns are of type string
    for col in columns:
        if df[col].dtype != "object":
            raise ValueError(f"Column {col} is not of type object.")
    # NOTE(review): assumes every cell is a str; NaN/None would break " ".join
    # — confirm upstream cleaning.
    return df[columns].apply(lambda x: " ".join(x), axis=1)


def semantic_deduplication(
    df: pd.DataFrame,
    required_columns: list[str],
    num_kmeans_clusters: int,
    epsilon: float = 0.99,
    similarity_metric: str = "cosine",
    keep_central: bool = True,
    kmeans_with_cosine_distance: bool = False,
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    embedding_batch_size: int = 100,
    use_gpu: bool = False,
):
    """
    Perform semantic deduplication on a dataframe.

    Args:
        df (pd.DataFrame): The dataframe.
        required_columns (List[str]): The columns to use for deduplication. Will be concatenated in order.
        num_kmeans_clusters (int): The number of clusters to use in kmeans.
        epsilon (float): The epsilon value to use for semantic deduplication.
            Pairs of items with similarity above epsilon will be considered duplicates.
        similarity_metric (str): The similarity metric to use, "cosine" or "l2".
        keep_central (bool): Whether to keep the item closest (if True) or farther (if False)
            from the cluster centroid when determining which item to keep.
        kmeans_with_cosine_distance (bool): Whether to use cosine distance for kmeans,
            only False currently implemented.
        model_name (str): The model name to use for embedding.
        embedding_batch_size (int): The batch size to use for embedding.
        use_gpu (bool): Whether to run the embedding model and kmeans on GPU.

    Returns:
        tuple: (df without temporary columns, list of row indices to remove,
        dict mapping a kept row index to a DataFrame of its duplicates).

    Raises:
        ValueError: If ``similarity_metric`` is not "cosine" or "l2", or if a
            required column is not of dtype object.
    """
    # Heavy optional dependencies are imported lazily so importing this module
    # (e.g. only for concat_columns) does not require faiss/transformers/tqdm.
    import faiss
    from transformers import AutoModel, AutoTokenizer
    from tqdm import tqdm

    # Fail fast on an unsupported metric; previously an unknown value only
    # surfaced as a NameError deep inside the per-cluster loop.
    if similarity_metric not in ("cosine", "l2"):
        raise ValueError(f"Unsupported similarity_metric: {similarity_metric}")

    # Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[
            0
        ]  # First element of model_output contains all token embeddings
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def embed_sentences_batch(sentences):
        # Tokenize one batch of sentences.
        encoded_input = tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )

        if use_gpu:
            encoded_input = encoded_input.to("cuda")

        # Compute token embeddings without building autograd graphs.
        with torch.no_grad():
            model_output = model(**encoded_input)

        # Mask-aware mean pooling over tokens.
        sentence_embeddings = mean_pooling(
            model_output, encoded_input["attention_mask"]
        )

        # L2-normalize so dot products below equal cosine similarity.
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1
        )

        return sentence_embeddings

    def embed_sentences(sentences, batch_size=embedding_batch_size):
        """
        Embed a list of sentences using the sentence-transformers model.

        Parameters
        ----------
        sentences : List[str]
            The list of sentences to embed.
        batch_size : int
            The batch size for embedding.

        Returns
        -------
        torch.Tensor
            The sentence embeddings.
        """
        # iterate over the sentences in batches
        sentence_embeddings = []
        for i in tqdm(
            range(0, len(sentences), batch_size),
            dynamic_ncols=True,
            desc="Embedding sentences...",
        ):
            batch = sentences[i : i + batch_size]
            sentence_embeddings.append(embed_sentences_batch(batch))

        return torch.cat(sentence_embeddings)

    def semdedup(cluster, eps=0.95):
        # Pairwise cosine similarity (embeddings are L2-normalized, so the dot
        # product is the cosine similarity).
        pairwise_similarity_matrix = cluster @ cluster.T

        # Zero the diagonal so an item is never its own duplicate.
        pairwise_similarity_matrix.fill_diagonal_(0.0)

        # Keep only the upper triangle so each pair is counted once; the
        # later (tail/column) item of each near-duplicate pair is removed.
        upper_triangular = torch.triu(pairwise_similarity_matrix)

        # Binary mask: 1 where similarity exceeds eps.
        matrix_of_removals = torch.where(upper_triangular > eps, 1, 0)
        # head is the row, tail is the column
        head_duplicates, tail_duplicates = matrix_of_removals.nonzero(as_tuple=True)

        # get the indices to remove
        indices_to_remove = tail_duplicates.tolist()

        return (
            list(set(indices_to_remove)),
            head_duplicates.tolist(),
            tail_duplicates.tolist(),
        )

    content_col = "dedup_content"
    df[content_col] = concat_columns(df, required_columns)

    # Load model from HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    if use_gpu:
        model = model.to("cuda")

    # embed the content
    embedded_content = embed_sentences(
        df[content_col].tolist(), batch_size=embedding_batch_size
    )

    if use_gpu:
        embedded_content = embedded_content.to("cpu")

    kmeans = faiss.Kmeans(
        d=embedded_content.size(1),
        k=num_kmeans_clusters,
        niter=20,
        verbose=True,
        seed=42,
        spherical=kmeans_with_cosine_distance,  # only true if using cosine distance
        gpu=use_gpu,
    )

    # NOTE(review): faiss expects contiguous float32 input; this passes a torch
    # CPU tensor directly — confirm it round-trips correctly on faiss upgrades.
    kmeans.train(embedded_content)
    centroids = kmeans.centroids

    # get the nearest centroid for each data point
    dist_to_cent, nearest_cent = kmeans.index.search(embedded_content, 1)
    dist_to_cent, nearest_cent = dist_to_cent.squeeze(), nearest_cent.squeeze()

    # assign the distance and cluster to the dataframe
    df["distance_to_centroid"] = dist_to_cent
    df["kmeans_cluster"] = nearest_cent

    indices_to_remove = []
    cluster_duplicates_dfs = {}

    for cluster_id in tqdm(
        range(num_kmeans_clusters), desc="Iterating over clusters..."
    ):
        cluster_df = df[df["kmeans_cluster"] == cluster_id]

        # if cluster is empty, skip
        if len(cluster_df) == 0:
            continue

        # NOTE(review): assumes df has a default RangeIndex so that dataframe
        # indices double as row positions into embedded_content — confirm at
        # the call site.
        cluster_idxs = cluster_df.index.tolist()
        cluster_embeddings = embedded_content[cluster_idxs]

        # compute the similarity to the centroid
        if similarity_metric == "cosine":
            if kmeans_with_cosine_distance:
                # if cosine distance was used for kmeans clustering, don't recompute
                cluster_dists_to_cent = 1 - cluster_df["distance_to_centroid"]
            else:
                # compute the cosine similarity to the centroid
                cluster_centroid = torch.tensor(centroids[cluster_id])
                sim_to_cent = torch.nn.functional.cosine_similarity(
                    cluster_embeddings, cluster_centroid
                )
                cluster_dists_to_cent = 1 - sim_to_cent
        elif similarity_metric == "l2":
            cluster_dists_to_cent = cluster_df["distance_to_centroid"]

        # sort the cluster items by distance to centroid
        sort_descending = (
            keep_central  # if keep_central is True, sort in descending order
        )
        cluster_sorted = sorted(
            zip(cluster_idxs, cluster_embeddings, cluster_dists_to_cent),
            key=lambda x: x[2],
            reverse=sort_descending,
        )

        # get the sorted indices
        sorted_cluster_idxs = [x[0] for x in cluster_sorted]
        sorted_cluster_embeddings = torch.stack([x[1] for x in cluster_sorted])

        # use semdedup to determine which items to remove
        (
            cluster_indices_to_remove,
            cluster_head_duplicates,
            cluster_tail_duplicates,
        ) = semdedup(sorted_cluster_embeddings, eps=epsilon)

        # Group each removed (tail) row under the kept (head) row it
        # duplicates, for later inspection.
        while cluster_head_duplicates:
            assert len(cluster_head_duplicates) == len(
                cluster_tail_duplicates
            ), "Lengths of head and tail duplicates should be the same."

            # get the first pair of duplicates
            head_idx = cluster_head_duplicates.pop(0)
            tail_idx = cluster_tail_duplicates.pop(0)

            # if the head index is not in the duplicates, create a new dataframe for it
            if sorted_cluster_idxs[head_idx] not in cluster_duplicates_dfs:
                cluster_duplicates_dfs[sorted_cluster_idxs[head_idx]] = pd.DataFrame(
                    columns=df.columns
                )
                cluster_duplicates_dfs[sorted_cluster_idxs[head_idx]].loc[
                    sorted_cluster_idxs[head_idx]
                ] = df.loc[sorted_cluster_idxs[head_idx]]

            # add the tail index to the head duplicates dataframe
            cluster_duplicates_dfs[sorted_cluster_idxs[head_idx]].loc[
                sorted_cluster_idxs[tail_idx]
            ] = df.loc[sorted_cluster_idxs[tail_idx]]

            # remove the tail index if it appears in the head duplicates,
            # prevents duplicates from being counted multiple times
            tail_indxs_in_head = [
                i for i, x in enumerate(cluster_head_duplicates) if x == tail_idx
            ]
            # remove in reverse order so that the indices don't change
            for i in tail_indxs_in_head[::-1]:
                cluster_head_duplicates.pop(i)
                cluster_tail_duplicates.pop(i)

        # convert cluster indices to global indices
        global_indices_to_remove = [
            sorted_cluster_idxs[i] for i in cluster_indices_to_remove
        ]

        indices_to_remove.extend(global_indices_to_remove)

    # remove the temporary columns (callers drop `indices_to_remove` themselves;
    # redundant axis=1 alongside columns= removed)
    df = df.drop(
        columns=["dedup_content", "distance_to_centroid", "kmeans_cluster"]
    )

    return df, indices_to_remove, cluster_duplicates_dfs