├── .gitignore
├── LICENSE
├── README.md
├── big_math_front_page.png
├── reformulation
├── README.md
├── __init__.py
├── big_math_reformulation.png
├── main.py
├── modules.py
├── requirements.txt
└── signatures.py
└── signals
├── README.md
├── __init__.py
├── add_empty_boxed_signal.py
├── add_hyperlink_signal.py
├── add_language_signal.py
├── add_multipartq_signal.py
├── add_multiple_choice_signal.py
├── add_proof_signal.py
├── add_semdedup_signal.py
├── add_true_false_signal.py
├── add_yes_no_signal.py
├── model_based_signals.py
├── requirements.txt
├── rollouts_based_signals
├── __init__.py
├── evaluate_responses.py
├── example_solve_rate_script.sh
├── math_eval.py
├── sample_from_model.py
└── utils
│ ├── __init__.py
│ ├── openai_server.py
│ └── sglang_util.py
└── semdedup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 SynthLabs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Big-Math: A Large-Scale, High-Quality Math Dataset for Reinforcement Learning in Language Models
2 |
3 |
4 | Click to collapse paper preview
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 | ## Introduction
29 |
30 | Welcome to the official repository for **Big-Math**, a large-scale, high-quality dataset designed specifically for RL training ([PPO](https://arxiv.org/abs/1707.06347), [GRPO](https://arxiv.org/abs/2402.03300), etc.) with large language models (LLMs).
31 |
32 | This repository provides tools for reformulating multiple-choice questions and implementing rule-based and model-based filtering as described in the [Big-Math paper](https://alon-albalak.github.io/images/Big_MATH.pdf).
33 |
34 | Find the dataset on HuggingFace at [https://huggingface.co/datasets/SynthLabsAI/Big-Math-RL-Verified](https://huggingface.co/datasets/SynthLabsAI/Big-Math-RL-Verified)
35 |
36 |
37 | > [!WARNING]
38 | > This repo is intended for research purposes, and is thus under constant development. Please expect major changes to the design.
39 | > The primary goal of the big-math repo is to share the filtering and reformulation code used to create the Big-Math dataset and to speed up the development of future datasets.
40 | > The Big-Math dataset is intended only for RL training of LLMs; it does not contain rollouts.
41 |
42 | ## Repository Structure
43 |
44 | This repo consists of 2 main directories: signals and reformulation.
45 |
46 | ### Signals
47 |
48 | This folder contains code used to generate signals on a dataset. The below signals can be generated either using rule-based methods or model-based methods:
49 |
50 | | Signal | Rule-Based | Model-Based |
51 | |-----------------------|------------|-------------|
52 | | Hyperlink Detection | ✅ | |
53 | | Language Identification | | ✅ |
54 | | Semantic Duplicate | | ✅ |
55 | | Multiple Choice Question | ✅ | ✅ |
56 | | Multi-Part Question | ✅ | ✅ |
57 | | True/False Question | ✅ | ✅ |
58 | | Yes/No Question | ✅ | ✅ |
59 | | Proof Detection | ✅ | ✅ |
60 | | Model Solve Rate | | ✅ |
61 |
62 | ### Reformulation
63 |
64 | This folder contains code used to reformulate multiple choice problems to open-ended questions.
65 |
66 | ## 🚀 Getting Started
67 |
68 | ### Prerequisites
69 | - python 3.10+
70 | - Install the packages in `signals/requirements.txt` to generate signals on a dataset
71 | - Install the packages in `reformulation/requirements.txt` to reformulate multiple-choice questions into open-ended questions
72 |
73 | ### Installation
74 | 1. Clone the repository:
75 | ```bash
76 | git clone https://github.com/SynthLabsAI/big-math.git
77 | cd big-math
78 | ```
79 | 2. Install dependencies
80 | ```bash
81 | pip install -r signals/requirements.txt -r reformulation/requirements.txt
82 | ```
83 |
84 | ### 🛠 Usage
85 |
86 | #### Reformulation
87 | See [the reformulation readme](https://github.com/SynthLabsAI/big-math/blob/master/reformulation/README.md) for an explanation of files and usage.
88 |
89 | #### Signals
90 | See [the signals readme](https://github.com/SynthLabsAI/big-math/blob/master/signals/README.md) for an explanation of files and usage.
91 |
92 | ## 📄 Citation
93 |
94 | ```bibtex
95 | @misc{albalak2025bigmathlargescalehighqualitymath,
96 | title={Big-Math: A Large-Scale, High-Quality Math Dataset for Reinforcement Learning in Language Models},
97 | author={Alon Albalak and Duy Phung and Nathan Lile and Rafael Rafailov and Kanishk Gandhi and Louis Castricato and Anikait Singh and Chase Blagden and Violet Xiang and Dakota Mahan and Nick Haber},
98 | year={2025},
99 | eprint={2502.17387},
100 | archivePrefix={arXiv},
101 | primaryClass={cs.LG},
102 | url={https://arxiv.org/abs/2502.17387},
103 | }
104 | ```
105 |
106 | ## License
107 |
108 | This project is licensed under the MIT License. See the [LICENSE](https://github.com/SynthLabsAI/big-math/blob/master/LICENSE) for details.
109 |
--------------------------------------------------------------------------------
/big_math_front_page.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SynthLabsAI/big-math/420b9a771a7e97a85b81cbdcbd573b1b0d56f522/big_math_front_page.png
--------------------------------------------------------------------------------
/reformulation/README.md:
--------------------------------------------------------------------------------
1 | # Multiple Choice Reformulation
2 |
3 | The **Multiple Choice Reformulator** is designed to transform multiple choice questions into open-ended questions using dspy. The tool is designed for research purposes.
4 |
5 |
6 |
7 |
8 | ## Folder Structure
9 |
10 | - `main.py`: The main entry point for the reformulation process.
11 | - `signatures.py`: Contains the function signatures and definitions for the dspy method used within the reformulator.
12 | - `modules.py`: Houses the Reformulator class, which implements the core logic for converting MCQs to open-ended problems.
13 | - `requirements.txt`: Make sure to install the dependencies as specified here.
14 |
15 | ## Setup and Configuration
16 |
17 | ### Step 1: Implement the `custom_load_dataset()` function
18 | The entry point for running the reformulator is located in `main.py`. You need to implement the `custom_load_dataset()` function within this file to load your dataset of multiple-choice questions. This function should return the dataset in a format that the reformulator can process.
19 |
20 | Here is a simple template using HuggingFace Datasets to get you started:
21 | ```python
22 | from datasets import load_dataset
23 | def custom_load_dataset(dataset_name):
24 | dataset = load_dataset(dataset_name, split="train")
25 | return dataset
26 | ```
27 |
28 | ### Step 2: Define the `API_CONFIGS` variable
29 | Within `main.py`, you will also need to define the `API_CONFIGS` variable. This variable will store configuration information related to the model API you are using. Modify this section to suit your API configuration setup.
30 |
31 | Example:
32 | ```python
33 | # Server configurations
34 | API_CONFIGS = {
35 | """Example"""
36 | "": {
37 | "api_base": "http://:/v1",
38 | "api_key": "PROVIDER_API_KEY",
39 | },
40 | "123.456.789.012": {
41 | "api_base": "http://123.456.789.012:8000/v1",
42 | "api_key": "PROVIDER_API_KEY",
43 | },
44 | }
45 | ```
46 |
47 | ### Step 3: Run the Reformulator
48 | Once the dataset loading and API configurations are set up, you can execute the reformulator by running the following command:
49 | ```bash
50 | python main.py
51 | ```
52 |
--------------------------------------------------------------------------------
/reformulation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SynthLabsAI/big-math/420b9a771a7e97a85b81cbdcbd573b1b0d56f522/reformulation/__init__.py
--------------------------------------------------------------------------------
/reformulation/big_math_reformulation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SynthLabsAI/big-math/420b9a771a7e97a85b81cbdcbd573b1b0d56f522/reformulation/big_math_reformulation.png
--------------------------------------------------------------------------------
/reformulation/main.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import math
4 | import pprint
5 | import random
6 | from datetime import datetime
7 | from pathlib import Path
8 | from typing import Any, Dict, List, Optional
9 |
10 | from dotenv import load_dotenv
11 | from loguru import logger
12 | from tenacity import retry, stop_after_attempt, wait_exponential
13 | from tqdm import tqdm
14 |
15 | import dspy
16 |
17 | load_dotenv()
18 |
19 | from modules import MathProblemReformulator
20 |
# Global configs
# Alternative model: "openai/gpt-4-turbo-preview"
MODEL = "openai/meta-llama/Llama-3.1-8B-Instruct"
# Used both as DSPy's async worker pool size and, multiplied by the number of
# endpoints, as the batch size in main().
CONCURRENT_TASKS_PER_API = 25

# Server configurations: one entry per API endpoint. The entries below are
# examples only -- replace the keys, URLs and API keys with your own.
# NOTE: "Example" must stay a comment. Written as a bare string inside the dict
# literal, Python concatenates it with the following key, silently turning the
# first key into "Example".
API_CONFIGS = {
    "": {
        "api_base": "http://:/v1",
        "api_key": "PROVIDER_API_KEY",
    },
    "123.456.789.012": {
        "api_base": "http://123.456.789.012:8000/v1",
        "api_key": "PROVIDER_API_KEY",
    },
}
38 |
39 |
def custom_load_dataset():
    """Placeholder dataset loader -- replace the body with your own logic.

    Implement this to load the multiple-choice problems from wherever your
    data lives and return them.

    Returns:
        list: An empty list until a real loader is supplied.
    """
    placeholder_dataset = []
    return placeholder_dataset
52 |
def setup_dspy_with_endpoint(api_base: str, api_key: str, temperature: float = 0.0001) -> None:
    """Make an LM for ``MODEL`` at the given endpoint the DSPy default.

    Args:
        api_base: Base URL of the API endpoint.
        api_key: API key passed to the LM client.
        temperature: Sampling temperature; rounded to 6 decimal places before use.
    """
    rounded_temperature = round(temperature, 6)
    language_model = dspy.LM(
        MODEL,
        api_key=api_key,
        api_base=api_base,
        temperature=rounded_temperature,
    )
    # The async worker pool is sized to the per-endpoint concurrency limit.
    dspy.settings.configure(
        lm=language_model,
        async_max_workers=CONCURRENT_TASKS_PER_API,
    )
    logger.debug(f"DSPy configured with model: {MODEL} at {api_base}")
65 |
def save_batch_results(batch_data: List[Dict[str, Any]], output_dir: str = "outputs") -> None:
    """Append every entry of *batch_data* as one JSON line to the results file.

    All batches share a single file, ``<output_dir>/reformulation_results.jsonl``;
    the directory is created on demand.

    Args:
        batch_data: JSON-serialisable result dictionaries.
        output_dir: Directory holding the results file.
    """
    results_dir = Path(output_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    results_path = results_dir / "reformulation_results.jsonl"

    # Append mode keeps results from earlier batches; records are serialised
    # lazily, one per line, as they are written.
    with open(results_path, "a") as sink:
        sink.writelines(json.dumps(record) + "\n" for record in batch_data)

    logger.debug(f"Batch results appended to {results_path}")
81 |
82 |
def load_processed_problems(output_dir: str = "outputs") -> set:
    """Collect the UUIDs of problems already reformulated with a PASS verdict.

    Scans every ``reformulation_results*.jsonl`` file in *output_dir*. A record
    counts as processed only if it is marked successful AND its judge verdict
    is "pass" (case-insensitive). Malformed lines and records that do not have
    the expected shape are skipped.

    Args:
        output_dir: Directory containing the JSONL result files.

    Returns:
        set: UUIDs of the problems that were processed and passed the judge.
    """
    processed = set()

    # The pattern must also match the un-suffixed "reformulation_results.jsonl"
    # written by save_batch_results; "reformulation_results_*.jsonl" never did,
    # so nothing was ever recognised as processed.
    for output_file in sorted(Path(output_dir).glob("reformulation_results*.jsonl")):
        with open(output_file, encoding="utf-8") as f:
            for line in f:
                try:
                    record = json.loads(line)
                    payload = record.get("data") or {}
                    verdict = payload.get("judge_verdict") or ""
                    if record.get("success") and verdict.lower() == "pass":
                        processed.add(payload["uuid"])
                except (ValueError, KeyError, AttributeError, TypeError):
                    # Corrupt JSON or an unexpected record shape: ignore the
                    # line instead of aborting the whole scan.
                    continue

    logger.info(f"Found {len(processed)} previously processed problems with PASS verdict")
    return processed
104 |
105 |
def log_processing_details(problem: str, result) -> None:
    """Print and log a human-readable report for one reformulated problem.

    Sections are emitted in order: original problem, reformulation process
    details, reformulated problem, judge evaluation (verdict, issues,
    suggestions, corrected version when present) and basic validation results.

    Args:
        problem: The original problem text.
        result: Reformulation result exposing ``reformulation_process``,
            ``reformulated_problem``, the ``judge_*`` fields,
            ``validation_details`` and ``validation_result``.
    """
    passed_label = "✓ PASSED"
    failed_label = "✗ FAILED"

    # --- Original problem ---
    logger.info("=== Original Problem ===")
    print("\n=== Original Problem ===")
    logger.debug(problem)
    print(problem)

    # Display labels for the known process steps (all empty strings here);
    # unknown keys fall back to a title-cased version of the key.
    step_labels = dict.fromkeys(
        (
            "core_mathematical_concept",
            "key_information_extraction",
            "problem_structure_analysis",
            "multiple_choice_removal_strategy",
            "rephrasing_approach",
            "problem_integrity_preservation",
            "answer_format_specification",
        ),
        "",
    )

    # --- Reformulation process ---
    logger.info("=== Reformulation Process Details ===")
    print("\n=== Reformulation Process Details ===")
    for step_name, step_value in result.reformulation_process.items():
        label = step_labels.get(step_name, step_name.replace('_', ' ').title())
        logger.debug(f"{label}: {step_value}")
        print(f"{label}:")
        print(f" {step_value}")

    # --- Reformulated problem ---
    logger.info("=== Reformulated Problem ===")
    print("\n=== Reformulated Problem ===")
    logger.debug(result.reformulated_problem)
    print(result.reformulated_problem)

    # --- Judge evaluation ---
    logger.info("=== Judge Evaluation ===")
    print("\n=== Judge Evaluation ===")
    print(f"Verdict: {result.judge_verdict}")

    issues = result.judge_issues
    if issues:
        print("\nIssues Found:")
        for issue in issues:
            print(f"- {issue}")

    suggestions = result.judge_suggestions
    if suggestions:
        print("\nSuggestions:")
        for suggestion in suggestions:
            print(f"- {suggestion}")

    corrected = result.judge_corrected_reformulated_problem
    if corrected:
        logger.info("=== Judge's Corrected Version ===")
        print("\n=== Judge's Corrected Version ===")
        print(corrected)

    # --- Basic validation ---
    logger.info("=== Basic Validation Results ===")
    print("\n=== Basic Validation Results ===")
    print("Checking the following criteria:")
    check_descriptions = {
        "no_mc_options": "Multiple choice options removed",
        "process_complete": "All reformulation steps completed",
    }

    for check_name, check_passed in result.validation_details.items():
        if check_passed:
            status = passed_label
        else:
            status = failed_label
        description = check_descriptions.get(check_name, check_name)
        report_line = f"{status} | {description}"
        logger.debug(report_line)
        print(report_line)

    if result.validation_result:
        overall = passed_label
    else:
        overall = failed_label
    logger.info(f"Overall Basic Validation: {overall}")
    print(f"\nOverall Basic Validation: {overall}")
    print(f"\n{'=' * 50}\n")
177 |
178 |
def _extract_correct_answer(item: Dict[str, Any]) -> Any:
    """Return the first truthy answer among the known answer fields of *item*."""
    answer_fields = (
        ("final_answer", "final answer"),
        ("answer", "answer"),
        ("source_solution", "source solution"),
        ("solution", "solution"),
    )
    correct_answer = None
    for key, label in answer_fields:
        try:
            correct_answer = item[key]
        except (KeyError, TypeError, IndexError):
            logger.warning(f"No {label} found")
            correct_answer = None
            if key == "solution":
                # Last fallback is missing as well: dump the row for inspection.
                pprint.pprint(item)
        if correct_answer:
            break
    return correct_answer


def _resolve_api_key(api_base: Optional[str]) -> str:
    """Return the API key configured in API_CONFIGS for *api_base*."""
    for config in API_CONFIGS.values():
        if config.get("api_base") == api_base:
            return config.get("api_key", "PROVIDER_API_KEY")
    # No matching endpoint configured: keep the historical placeholder key.
    return "PROVIDER_API_KEY"


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    reraise=True,
)
async def _run_reformulation(reformulator, problem_text: str, correct_answer: Any,
                             api_base: Optional[str], api_key: str, temperature: float):
    """Call the reformulator against one endpoint, retrying failed attempts."""
    with dspy.context(
        lm=dspy.LM(
            MODEL,
            temperature=temperature,
            api_base=api_base,
            api_key=api_key,
        )
    ):
        return await reformulator(problem_text, correct_answer)


async def process_single_item(item: Dict[str, Any], reformulator, api_base: Optional[str] = None) -> Dict[str, Any]:
    """Reformulate a single problem and return a success/failure record.

    The retry policy (3 attempts, exponential back-off) wraps only the model
    call. Previously it decorated this function, whose body swallows every
    exception, so no retry could ever happen.

    Args:
        item: Dataset row. Must contain a non-empty ``problem``; the answer is
            taken from the first truthy value among ``final_answer``,
            ``answer``, ``source_solution`` and ``solution``. A ``uuid`` is
            added in place when missing.
        reformulator: Async callable invoked as
            ``reformulator(problem_text, correct_answer)``.
        api_base: Base URL of the endpoint to use for this item.

    Returns:
        ``{"success": True, "data": {...}}`` on success, otherwise
        ``{"success": False, "error": <message>, "item": item}``. This function
        does not raise for per-item failures.
    """
    try:
        problem_text = item['problem']
        if not problem_text:
            raise ValueError("Empty problem text")

        correct_answer = _extract_correct_answer(item)
        print(f"Correct answer: {correct_answer}")

        # Add uuid if not present for single problems
        if 'uuid' not in item:
            item['uuid'] = str(datetime.now().timestamp())

        # Randomized (near-zero) temperature, recorded with the result.
        temperature = round(random.uniform(0.00001, 0.001), 6)

        result = await _run_reformulation(
            reformulator,
            problem_text,
            correct_answer,
            api_base,
            # Use the key configured for this endpoint instead of a hard-coded one.
            _resolve_api_key(api_base),
            temperature,
        )

        processed_data = {
            'reformulation_process': result.reformulation_process,
            'reformulated_problem': result.reformulated_problem,
            'reformulation_solution': result.reformulation_solution,
            'validation_details': result.validation_details,
            'validation_result': result.validation_result,
            'reasoning': result.reasoning or None,
            'judge_verdict': result.judge_verdict,
            'judge_issues': result.judge_issues,
            'judge_suggestions': result.judge_suggestions,
            'judge_reasoning': result.judge_reasoning,
            'judge_corrected_reformulated_problem': result.judge_corrected_reformulated_problem,
            'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'temperature': temperature,
            'api_base': api_base,
            # Original row fields come last, so they win on key collisions.
            **item,
        }

        logger.info("=== Processing Results ===")
        log_processing_details(problem_text, result)

        return {"success": True, "data": processed_data}
    except Exception as e:
        # Per-item boundary: one bad item must not abort the whole batch.
        logger.error(f"Failed to process problem: {e}")
        return {"success": False, "error": str(e), "item": item}
254 |
async def process_batch_distributed(items, reformulator, api_resources: Dict) -> List[Dict[str, Any]]:
    """Process a batch of items, spreading them round-robin over the endpoints.

    Items without a (truthy) ``uuid`` are skipped with a warning instead of
    raising ``KeyError``. Successful results are appended to the results file.

    Args:
        items: Iterable of dataset rows (dicts).
        reformulator: Async reformulator passed through to ``process_single_item``.
        api_resources: Mapping of endpoint name to a dict with ``api_base``
            and ``api_key``.

    Returns:
        One result record per processed item (successes and failures), or an
        empty list when no item qualifies.

    Raises:
        ValueError: If there are items to process but no endpoint is configured.
    """
    new_items = []
    skipped = 0
    for item in items:
        if item.get('uuid'):
            new_items.append(item)
        else:
            skipped += 1
    if skipped:
        logger.warning(f"Skipping {skipped} item(s) without a uuid")

    if not new_items:
        return []

    api_endpoints = list(api_resources.keys())
    num_endpoints = len(api_endpoints)
    if num_endpoints == 0:
        # Fail with a clear message rather than a ZeroDivisionError below.
        raise ValueError("api_resources must contain at least one API endpoint")

    tasks = []
    for idx, item in enumerate(new_items):
        # Round-robin assignment of items to endpoints.
        resources = api_resources[api_endpoints[idx % num_endpoints]]
        pprint.pprint(item)
        # NOTE(review): this resets the process-wide DSPy default for every item,
        # while process_single_item also sets the LM per call via dspy.context --
        # looks redundant; confirm before removing.
        setup_dspy_with_endpoint(
            api_base=resources['api_base'],
            api_key=resources['api_key'],
            temperature=round(random.uniform(0.00001, 0.001), 6),
        )

        tasks.append(process_single_item(item, reformulator, api_base=resources['api_base']))

    results = await asyncio.gather(*tasks)

    successful_results = [r for r in results if r["success"]]
    if successful_results:
        save_batch_results(successful_results)

    return results
286 |
async def main(problem: Optional[str] = None) -> None:
    """Run the reformulation over the dataset or over one given problem.

    Args:
        problem: When ``None`` (default), the dataset returned by
            ``custom_load_dataset()`` is processed in batches across all
            configured endpoints. Otherwise this single problem text is
            processed using the first endpoint in ``API_CONFIGS``.
    """
    # Initialize API resources
    api_resources = {
        name: {
            'api_base': config['api_base'],
            'api_key': config['api_key'],
        }
        for name, config in API_CONFIGS.items()
    }
    if not api_resources:
        # Without endpoints the batch size would be 0 and there is nothing to call.
        logger.error("No API endpoints configured in API_CONFIGS. Exiting.")
        return

    # Create reformulator instance and make it async
    reformulator = dspy.asyncify(MathProblemReformulator())
    logger.debug("MathProblemReformulator instance created")

    if problem is None:
        # Load and process dataset in parallel
        dataset = custom_load_dataset()
        if not dataset:
            logger.error("No dataset loaded. Exiting.")
            return

        total_problems = len(dataset)
        logger.info(f"Processing {total_problems} problems from dataset")

        print("\nSample row from dataset:")
        first_item = dataset[0]
        print(json.dumps(first_item, indent=2))

        batch_size = len(api_resources) * CONCURRENT_TASKS_PER_API
        total_batches = math.ceil(total_problems / batch_size)
        with tqdm(total=total_problems, desc="Processing problems") as pbar:
            for start in range(0, total_problems, batch_size):
                stop = min(start + batch_size, total_problems)
                # Index row by row instead of slicing: slicing a HuggingFace
                # Dataset (the loader the README suggests) returns a dict of
                # columns, not a list of rows. For plain lists this is
                # equivalent to dataset[start:stop].
                batch = [dataset[row] for row in range(start, stop)]

                logger.info(f"Processing batch {start // batch_size + 1} of {total_batches}")
                await process_batch_distributed(batch, reformulator, api_resources)
                # Advance by the batch size so rows skipped inside the batch
                # still count towards the total.
                pbar.update(len(batch))
    else:
        # Process single problem with the first configured endpoint
        config = next(iter(API_CONFIGS.values()))
        setup_dspy_with_endpoint(config['api_base'], config['api_key'])

        try:
            # Create item dictionary matching dataset format
            item = {
                "problem": problem,
                "source": "manual_input",
            }
            # Pass the endpoint explicitly: process_single_item builds its own
            # LM from api_base and would otherwise be given None.
            result = await process_single_item(item, reformulator, api_base=config['api_base'])
            if result["success"]:
                save_batch_results([result])
                logger.info("Successfully saved single problem result")
        except Exception as e:
            logger.error(f"Failed to process single problem: {e}")


if __name__ == "__main__":
    asyncio.run(main())
--------------------------------------------------------------------------------
/reformulation/modules.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | from loguru import logger
3 |
4 | from signatures import MathReformulationSignature, ReformulationJudgeSignature
5 |
6 |
class MathProblemReformulator(dspy.Module):
    """Reformulates multiple-choice math problems into open-ended format.

    The module chains two predictors: one that rewrites the problem and one
    that judges the rewrite. A lightweight rule-based validation is run on the
    rewrite as well.

    Methods
    -------
    forward(problem, correct_answer) -> dspy.Prediction
        Reformulate the problem, judge the result and return everything in a
        single prediction.
    _validate(reformulation_process, reformulated_problem) -> tuple[bool, dict]
        Rule-based sanity checks on the reformulation output.
    """

    def __init__(self):
        super().__init__()
        self.reformulate = dspy.Predict(MathReformulationSignature)
        self.judge = dspy.Predict(ReformulationJudgeSignature)

    def forward(self, problem: str, correct_answer: str):
        """Reformulate *problem* and, unless it is reported as "N/A", judge it.

        Parameters:
            problem (str): The original multiple-choice math problem.
            correct_answer (str): The correct answer to the original problem.

        Returns:
            dspy.Prediction: Reformulation fields, validation outcome and the
            judge fields (all ``None`` when the judge pass is skipped).
        """
        draft = self.reformulate(original_problem=problem, correct_answer=correct_answer)

        # "N/A" marks a problem that was not multiple choice: skip the judge
        # and return a minimal, valid prediction.
        if draft.reformulated_problem.strip() == "N/A":
            return dspy.Prediction(
                reformulation_process=draft.reformulation_process,
                reformulated_problem="N/A",
                validation_result=True,
                validation_details={"not_multiple_choice": True},
                reasoning=draft.reasoning,
                reformulation_solution=draft.solution,
                judge_corrected_reformulated_problem=None,
                judge_verdict=None,
                judge_issues=None,
                judge_suggestions=None,
                judge_reasoning=None,
            )

        # Second pass: let the judge assess the reformulation.
        verdict = self.judge(
            original_problem=problem,
            correct_answer=correct_answer,
            reformulated_problem=draft.reformulated_problem,
        )

        # Rule-based checks, independent of the judge.
        passed, details = self._validate(
            draft.reformulation_process,
            draft.reformulated_problem,
        )

        # Normalise an empty corrected version to None.
        corrected = verdict.corrected_version or None

        return dspy.Prediction(
            reformulation_process=draft.reformulation_process,
            reformulated_problem=draft.reformulated_problem,
            validation_result=passed,
            validation_details=details,
            reasoning=draft.reasoning,
            reformulation_solution=draft.solution,
            judge_corrected_reformulated_problem=corrected,
            judge_verdict=verdict.verdict,
            judge_issues=verdict.issues,
            judge_suggestions=verdict.suggestions,
            judge_reasoning=verdict.reasoning,
        )

    def _validate(self, reformulation_process: dict, reformulated_problem: str) -> tuple[bool, dict]:
        """Check the reformulation output against simple structural rules.

        Parameters:
            reformulation_process (dict): The process details of the reformulation.
            reformulated_problem (str): The reformulated math problem.

        Returns:
            tuple[bool, dict]: Whether every check passed, and the per-check results.
        """
        # Snake_case step names expected in the model output.
        required_steps = {
            'core_mathematical_concept',
            'key_information_extraction',
            'problem_structure_analysis',
            'multiple_choice_removal_strategy',
            'rephrasing_approach',
            'problem_integrity_preservation',
            'answer_format_specification',
            'is_multiple_choice',
        }

        reported_steps = set(reformulation_process.keys())
        missing_keys = required_steps - reported_steps
        extra_keys = reported_steps - required_steps

        # Surface key mismatches for debugging.
        if missing_keys:
            logger.debug(f"Missing keys: {missing_keys}")
        if extra_keys:
            logger.debug(f"Unexpected extra keys: {extra_keys}")

        # A check for a "\boxed{" instruction is currently disabled.

        # None of the option labels may survive in the reformulated text.
        option_labels = ("(A)", "(B)", "(C)", "(D)", "(E)")
        no_mc_options = all(label not in reformulated_problem for label in option_labels)

        # Complete means exactly the expected keys: nothing missing, nothing extra.
        process_complete = not (missing_keys or extra_keys)

        validation_details = {
            "no_mc_options": no_mc_options,
            "process_complete": process_complete,
        }

        return (no_mc_options and process_complete), validation_details
--------------------------------------------------------------------------------
/reformulation/requirements.txt:
--------------------------------------------------------------------------------
1 | tenacity==9.0.0
2 | dspy==2.5.40
3 | loguru==0.7.2
4 | python-dotenv==1.0.1
5 | hf-transfer==0.1.8
6 | duckdb==1.1.3
7 | httpx==0.25.2
--------------------------------------------------------------------------------
/reformulation/signatures.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import dspy
3 |
4 |
# DSPy signature for rewriting a multiple-choice math problem as an open-ended one.
# NOTE: the class docstring below is prompt text (DSPy takes a signature's
# docstring as its instructions -- confirm against the pinned dspy version), so
# it is a raw string: without the `r` prefix, LaTeX such as `\boxed{}` and
# `\text{}` would be decoded as backspace/tab escape sequences.
class MathReformulationSignature(dspy.Signature):
    r"""You are an AI assistant specializing in transforming multiple-choice math problems into open-ended, solvable questions suitable for an automatic grading system that relies on regex pattern matching within `\boxed{}`. Your task is to reformulate the given problem while adhering to specific guidelines.

**Rules:**
- Never turn problems into proofs, all problems should have an answer that is machine-verifiable
- Reformulated problems should have the same final answer as the original problem (just not in the multiple-choice format)
- It must be reasonably clear from the problem what the student is expected to write inside of \boxed{final_answer}. E.g. "What are its key properties and characteristics?" is almost always a bad reformulation because it is ambiguous what the student should write inside of the `\boxed{}`. "Express your answer in centimeters." is a good answer format specification because it does NOT give away the answer but does specify a clear format.
- Do NOT reformulate problems that are not actually multiple choice or are impossible to reformulate
- If a problem is not multiple choice, return "N/A" for the reformulated problem and other "N/A" fields

**Follow these steps:**

1. **Analyze the original problem:**
- Identify the core mathematical concept.
- Note any crucial information, including numerical values, equations, and key terms.
- Determine the structure of the problem (e.g., scenario-based, direct question, multi-step).
- Preserve all mathematical notations, symbols, and formatting as they appear in the original problem.

2. **Remove multiple-choice options:**
- Eliminate all answer choices and their labels (A, B, C, etc.).
- If any options contain information essential to solving the problem, incorporate that information into the main problem statement.

3. **Rephrase the question:**
- Transform the question into an open-ended format that requires a direct numerical or algebraic answer.
- Ensure the rephrased question is clear, unambiguous, and uses language appropriate for the student's level.

4. **Maintain problem integrity:**
- Keep all original numerical values, equations, figures, and key terms intact.
- Preserve any scenarios, dialogues, or conditional information crucial for solving the problem.
- Do not introduce new information, alter units of measurement, or change the mathematical intent of the problem.
- If the problem references diagrams or figures, ensure that any necessary descriptions are included.

5. **Specify the answer format:**
- Instruct the student to provide their answer using the `\boxed{}` format.
- Do not include placeholders like `\boxed{N}`; instead, guide the student to input their calculated answer within the boxed format.
- For example, "Provide your answer in the form `\boxed{\text{your answer here}}`."

6. **Final check:**
- Ensure the reformulated problem contains all necessary information for independent solving.
- Verify that the problem hasn't become easier or harder than the original.
- Check for any common errors, such as unit inconsistencies or typographical mistakes.
- Confirm that no hints, solution methods, or additional explanations have been inadvertently added.

7. **Is actually multiple choice:**
- Some problems are not actually multiple choice and do NOT actually need to be reformulated.
- If the problem is **NOT** multiple choice, do NOT reformulate it! Note here if it is not multiple choice and return "N/A" for the reformulated problem.

**Before providing the final reformulated problem, please create a `reformulation_process` dictionary (for internal use; do not include this dictionary or its content in the final problem). The dictionary should have the following exact keys:**

- `"core_mathematical_concept"`: Summarize the core mathematical concept.
- `"key_information_extraction"`: List key information (numerical values, equations, terms).
- `"problem_structure_analysis"`: Identify the problem structure.
- `"multiple_choice_removal_strategy"`: Plan how to remove multiple-choice options.
- `"rephrasing_approach"`: Outline the rephrasing strategy.
- `"problem_integrity_preservation"`: Note how to maintain problem integrity.
- `"answer_format_specification"`: Specify the answer format.
- `"is_multiple_choice"`: Whether the problem is actually multiple choice.

**This structured approach will help ensure accuracy and completeness in your final output.**

**After your reformulation process, present the reformulated problem as plain text, making sure it meets all the criteria outlined above. Do not include the `` tags or any of their content in the final problem presented to the student.**

**Examples:**
```


Three squares with the same center and corresponding parallel sides are drawn. The distance between the sides of successive squares is 3 units, and the side length of the largest square is 22 units. What is the perimeter of the smallest square?
(A) 40
(B) 64
(C) 100
(D) 10
(E) 20



{
"core_mathematical_concept": "Geometry - concentric squares and their dimensions",
"key_information_extraction": [
"Three concentric squares",
"Distance between sides of successive squares: 3 units",
"Side length of the largest square: 22 units",
"Need to find the perimeter of the smallest square"
],
"problem_structure_analysis": "Direct question based on given information",
"multiple_choice_removal_strategy": [
"Eliminate all answer choices and their labels",
"No essential information is lost by removing the options"
],
"rephrasing_approach": [
"Keep the description of the squares intact",
"Rephrase the question to ask for the perimeter in an open-ended format",
"Specify units and number format requirements"
],
"problem_integrity_preservation": [
"Preserve all original numerical values and key terms",
"Do not add any new information or hints"
],
"answer_format_specification": [
"Answer should be expressed in units",
"Answer should be given as a whole number",
"No decimal places required based on original options"
],
"is_multiple_choice": true
}



Let me think about this step by step:

1. Analyzing the multiple choice aspect:
- The options are all whole numbers: 40, 64, 100, 10, 20
- No units are specified in the options but units are used in the problem
- The options suggest an exact answer is expected, not an approximation

2. Looking at the essential information:
- The geometric relationships are clearly stated
- All measurements are given in units
- No complex transformations or unit conversions needed

3. What needs to be preserved:
- The concentric squares description
- The 3-unit distance between squares
- The 22-unit largest square side length
- The focus on perimeter rather than area

4. What needs to be clarified:
- The answer should include units
- A whole number is expected
- No decimal places needed

5. How to make it open-ended:
- Remove all multiple choice options
- Keep the geometric scenario exactly as is
- Add specific answer format requirements
- Maintain the same level of difficulty

6. Final considerations:
- The reformulated question should be solvable without reference to the original options
- The specifications should ensure an unambiguous answer format
- The mathematical challenge should remain unchanged



Three squares with the same center and corresponding parallel sides are drawn. The distance between the sides of each successive square is 3 units, and the side length of the largest square is 22 units. Find the perimeter of the smallest square. Express your answer in units as a whole number.


```

```


The teacher picked 9 out of 13 cards labeled with numbers 1 to 13 and stuck them on the foreheads of 9 students. Each student can see the numbers on the other 8 students' foreheads but not their own. The teacher asked, "Those who know the number of divisors of their own number, please raise your hand." Two students raised their hands. After they put their hands down, three students had the following conversation.
- A: "I know my number now."
- B: "I still don't know what my number is, but I have figured out whether it is odd or even."
- C: "My number is 2 less than B's number and 1 more than A's number."
What is the sum of the numbers on the four cards that were not picked?



{
"core_mathematical_concept": "Number theory and logical deduction with combinatorics",
"key_information_extraction": [
"13 cards numbered from 1 to 13",
"9 cards picked and assigned to students",
"Students can see others' numbers but not their own",
"Two students know their number based on number of divisors",
"Conversation between students A, B, and C revealing relationships",
"Need to find sum of the 4 unpicked cards"
],
"problem_structure_analysis": "Complex logical deduction problem combining number theory with information from multiple stages",
"multiple_choice_removal_strategy": "Not applicable as this is not a multiple choice question",
"rephrasing_approach": "Not needed as problem is already in appropriate format",
"problem_integrity_preservation": "No modifications needed",
"answer_format_specification": "Answer should be expressed as a single integer (sum of the unpicked cards)",
"is_multiple_choice": false
}



Let me think about this step by step:

1. First, assess if reformulation is needed:
- This is not a multiple choice question
- It already asks for a specific numerical answer
- The format is clear and unambiguous

2. Analyzing the problem structure:
- It's a complex logic puzzle
- All information provided is necessary
- The sequence of events matters
- The dialogue provides crucial information

3. Consider what might need clarification:
- Answer format is implicitly clear (a sum)
- No unit specification needed
- No decimal places involved (working with integers 1-13)
- No ambiguity in what constitutes an answer

4. Check for potential improvements:
- Current format is optimal for the problem type
- Adding structure might make it harder to follow
- The narrative flow is important for understanding

5. Final decision:
- No reformulation needed
- Problem is already in ideal open-ended format
- Answer requirements are clear (single integer sum)



N/A


```

```


In the cube shown, \( L \) and \( K \) are midpoints of adjacent edges \( AD \) and \( AB \), respectively. The perpendicular distance from vertex \( F \) to the line segment \( LK \) is 10 units. Which of the following measurements is closest to the volume of the cube after subtracting 3 units?

(A) 200 in³
(B) 500 cm³
(C) 0.01039 m³
(D) Tie



{
"core_mathematical_concept": "3D geometry involving cube volume calculation",

"key_information_extraction": [
"Points L and K are midpoints of adjacent edges AD and AB",
"Perpendicular distance from F to LK is 10 units",
"Need to subtract 3 units from final volume"
],

"problem_structure_analysis": "Geometry problem requiring volume calculation",

"multiple_choice_removal_strategy": [
"Remove all comparisons",
"Ask for direct volume calculation"
],

"rephrasing_approach": [
"Keep geometric relationships intact",
"Request exact volume calculation",
"Specify units and precision"
],

"problem_integrity_preservation": [
"Maintain all numerical values and geometric relationships",
"Remove comparative aspect entirely"
],

"answer_format_specification": [
"Answer should be expressed in cubic meters (m³)",
"Round to 5 decimal places",
"Include units in answer"
],

"is_multiple_choice": true
}


Let me think about this step by step:

1. First, I need to identify what makes this multiple choice:

- The core task is actually calculating a specific volume
- The multiple choice aspect is just about comparing to different unit measurements
- The comparison isn't fundamental to the mathematical concept being tested

2. Looking at the original answer choices:

- They use different units (in³, cm³, m³)
- The m³ value has 5 decimal places (0.01039)
- This suggests m³ with 5 decimal places is an appropriate precision level

3. How to remove the multiple choice aspect:

- Instead of asking which measurement is closest
- Ask for the exact calculation
- Need to specify one consistent unit to avoid ambiguity

4. What to preserve:

- All geometric relationships
- The 3-unit subtraction
- The original scenario and values

5. What needs to be added:

- Specific unit requirement (m³)
- Decimal place requirement (5 places)
- Clear instruction about including units

6. The reformulated question should:

- Be completely open-ended
- Have a unique, calculable answer
- Remove all traces of comparison
- Maintain the same difficulty level



In the cube shown, \( L \) and \( K \) are midpoints of adjacent edges \( AD \) and \( AB \), respectively. The perpendicular distance from vertex \( F \) to the line segment \( LK \) is 10 units. Calculate the volume of the cube after subtracting 3 cubic units. Express your answer in cubic meters (m³), rounded to 5 decimal places.


```
"""

    # --- Inputs ---
    original_problem: str = dspy.InputField(
        desc="""The original multiple-choice math problem text, including all options and formatting."""
    )

    correct_answer: str = dspy.InputField(
        desc="The correct answer to the original problem."
    )

    # --- Outputs ---
    # The key list below now includes "is_multiple_choice" so that it agrees
    # with the key list in the instructions above.
    reformulation_process: dict = dspy.OutputField(
        desc="""Structured analysis of the problem reformulation process containing the following exact keys:
{
"core_mathematical_concept": "Identify & summarize state the main mathematical concept",
"key_information_extraction": "List of numerical values, equations, and key terms",
"problem_structure_analysis": "Description of problem type and structure",
"multiple_choice_removal_strategy": "How to handle removing MC options",
"rephrasing_approach": "Strategy for rephrasing as open-ended",
"problem_integrity_preservation": "How to maintain original difficulty and context",
"answer_format_specification": "How the answer should be formatted with \\boxed{}",
"is_multiple_choice": "Whether the problem is actually multiple choice"
}""",
        prefix="\n\n{"
    )

    reasoning: str = dspy.OutputField(
        desc="""Think step by step about how to reformulate the problem while adhering to the guidelines provided. Or, if the problem is NOT a multiple choice problem, justify why it is not.""",
        prefix="\n\n"
    )

    reformulated_problem: str = dspy.OutputField(
        desc=r"""The reformulated open-ended problem that:
- Preserves all mathematical notations and symbols
- Maintains the original difficulty level
- Specifies, and if necessary has clear instructions, for the expected answer format. The student should not be confused about the type of answer that should be presented. (E.g. round your answer to the nearest hundredth, or express your answer as a percentage, list your answer as a coordinate, etc.). If the answer cannot be easily expressed in boxed latex format (for auto-grading), then the problem was not well formulated!
- Contains no multiple-choice options
- Does not add any hints or solution methods
- Retains all necessary context to solve the problem
(If the problem is not a multiple choice problem or cannot be reformulated, then just return N/A)""",
        prefix="\n\n"
    )

    # Raw string: a plain literal would turn the "\b" of "\boxed" into a backspace.
    solution: str = dspy.OutputField(
        desc=r"""The solution to the reformulated problem. What the student should write inside of the `\boxed{}`.""",
        prefix="\n\n"
    )
359 |
360 |
# DSPy signature for judging whether a reformulated problem is a valid
# open-ended version of the original multiple-choice problem.
# NOTE: the class docstring below is prompt text and is a raw string on
# purpose: as a plain literal, LaTeX such as `\neq`, `\frac`, `\right`,
# `\triangle` and `\textbf` would be decoded as newline / form-feed /
# carriage-return / tab escape sequences and corrupt the examples.
class ReformulationJudgeSignature(dspy.Signature):
    r"""You are a IMO judge, specialized in evaluating the quality of IMO math problem reformulations. Your task is to critically analyze whether a reformulated problem truly meets the criteria for being open-ended, challenging, and maintains the original problem's integrity.

**Important:**
- Problems you review are being transformed from multiple choice problems into open-ended, solvable questions suitable for an automatic grading system that relies on regex pattern matching within `\boxed{}`. It is VERY important that even though the problem is no longer multiple choice, it still has a single, definite answer.
- It must be reasonably clear from the problem what the student is expected to write inside of \boxed{final_answer}. E.g. "What are its key properties and characteristics?" is almost always a bad reformulation because it is ambiguous what the student should write inside of the `\boxed{}`. "Express your answer in centimeters." is a good answer format specification because it does NOT give away the answer but does specify a clear format.

Key aspects to evaluate:

1. Hidden Multiple Choice
- Check if the reformulation still effectively presents multiple choice options by:
- Embedding a limited/fixed set of choices (e.g. 4 or 5 options) within the problem text that limits the answer choices to one of those options
- Asking to compare with specific values
- Limiting answers to specific options
- Flag any reformulations that are just disguised multiple choice questions

Example:


For real numbers $t \neq 0,$ the point \[(x,y) = \left( \frac{t + 1}{t}, \frac{t - 1}{t} \right)\]is plotted. All the plotted points lie on what kind of curve? (A) Line (B) Circle (C) Parabola (D) Ellipse (E) Hyperbola Enter the letter of the correct option.


Reformulated failed example (all multiple choice aspect still present):

For real numbers $t \neq 0,$ the point \[(x,y) = \left( \frac{t + 1}{t}, \frac{t - 1}{t} \right)\]is plotted. What type of curve do all the plotted points lie on? Provide your answer as a specific curve type (e.g., line, circle, parabola, ellipse, hyperbola).


Reformulated successful example (multiple choice aspect removed):

For real numbers \( t \neq 0 \), the point
\[
(x, y) = \left( \frac{t + 1}{t}, \frac{t - 1}{t} \right)
\]
is plotted. Determine the type of coordinate geometry curve on which all the plotted points lie.


2. Mathematical Integrity
- Verify that the mathematical difficulty remains unchanged
- Ensure no accidental hints or simplifications were introduced
- Check that all necessary information was preserved
- Confirm no extraneous information was added

Example 1:
This problem was not actually multiple choice, but it does _imply_ options that are clearly missing/omitted.

A resident wants to renovate their house and buys several strips of wood, each with a length of 0.7 meters and 0.8 meters. By connecting some of these wood strips, many different lengths of wood can be obtained. For example, \(0.7 + 0.7 = 1.4\) meters, \(0.7 + 0.8 = 1.5\) meters, etc. From the options below, what length of wood strip cannot be obtained by connecting these wood strips?



Example 2:

$\triangle ABC$ is inscribed in a semicircle of radius $r$ so that its base $AB$ coincides with diameter $AB$. Point $C$ does not coincide with either $A$ or $B$. Let $s=AC+BC$. Then, for all permissible positions of $C$: $\textbf{(A)}\ s^2\le8r^2\qquad \textbf{(B)}\ s^2=8r^2 \qquad \textbf{(C)}\ s^2 \ge 8r^2 \qquad\\ \textbf{(D)}\ s^2\le4r^2 \qquad \textbf{(E)}\ s^2=4r^2$


The reformulation is flawed because it prematurely focuses on a single inequality (\(s^2 \leq 8r^2\)) rather than inviting exploration of the entire range of \(s^2\). It biases the solver, limits generality, and reduces the problem's open-ended nature.

$\triangle ABC$ is inscribed in a semicircle of radius $r$ so that its base $AB$ coincides with diameter $AB$. Point $C$ does not coincide with either $A$ or $B$. Let $s=AC+BC$. Prove or disprove the inequality $s^2\le8r^2$ for all permissible positions of $C$.


An open-ended revision that requires the same analysis and leads to the same conclusion

Let triangle \( ABC \) be inscribed in a semicircle of radius \( r \), with its base \( AB \) coinciding with the diameter \( AB \). Point \( C \) lies on the semicircle but does not coincide with \( A \) or \( B \). Let \( s = AC + BC \). Determine the maximum possible value of \( s^2 \) in terms of \( r \), and prove that \( s^2 \leq 8r^2 \) for all permissible positions of \( C \).


Example 3:

For real numbers $t,$ the point \[(x,y) = \left( \frac{1 - t^2}{1 + t^2}, \frac{2t}{1 + t^2} \right)\]is plotted. All the plotted points lie on what kind of curve? (A) Line (B) Circle (C) Parabola (D) Ellipse (E) Hyperbola Enter the letter of the correct option.


Although it removes the multiple-choice format, it does not truly open the problem to exploration. The answer asks for a description, which introduces too much variability in responses

For real numbers $t,$ the point \[(x,y) = \left( \frac{1 - t^2}{1 + t^2}, \frac{2t}{1 + t^2} \right)\]is plotted. What kind of curve do all the plotted points lie on? Provide a brief description of the curve.


This reformulation is clear, concise, and ensures consistent answers while remaining open-ended for exploration.

For real numbers \( t \), consider the point

\[
(x, y) = \left( \frac{1 - t^2}{1 + t^2},\ \frac{2t}{1 + t^2} \right).
\]

Determine the type of curve on which all such points \((x, y)\) lie. State your answer as the name of the curve.


Example 4:


Define * as an operation on ordered pairs of real numbers, such that $(a, b) *(c, d)=$ $(a c+b d, a d+b c)$. If $(a, b) *(x, y)=(a, b)$, then which of the following is $(x, y)$? (A) $(0,0)$. (B) $(0,1)$. (C) $(1,0)$. (D) $(1,1)$. (E) $(1,-1)$.


Ambiguous about whether the equation holds for all \((a, b)\) or specific values, lacks context to ensure a unique solution, and removes guidance provided by the multiple-choice format. Becomes unbounded.

Define * as an operation on ordered pairs of real numbers, such that $(a, b) *(c, d)=$ $(a c+b d, a d+b c)$. If $(a, b) *(x, y)=(a, b)$, then calculate the ordered pair $(x, y)$. Express your answer in the format $(x, y)$, including parentheses and a comma.


Ideal because it clarifies the scope (applies to all \((a, b)\)), ensures a unique solution, and balances open-ended exploration with clear grading criteria

Define \( * \) as an operation on ordered pairs of real numbers, such that
\[
(a, b) * (c, d) = (ac + bd, ad + bc).
\]
Find the ordered pair \((x, y)\) that satisfies
\[
(a, b) * (x, y) = (a, b)
\]
for **all** real numbers \(a\) and \(b\). Express your answer as \((x, y)\).


3. Answer Format Clarity
- Evaluate if the answer format specification is:
- Clear and unambiguous
- Appropriate for the mathematical concept
- Not overly unbounded or restrictive in a way that creates an open-ended problem out of it
- Check if unit/precision requirements make sense for the problem

4. Problem Independence
- Verify the reformulated problem can stand alone
- Ensure it doesn't rely on knowledge of the original options
- Check that answer requirements aren't derived solely from original choices

Example:

Which of the following is a root of the equation \( x^2 - x - 6 = 0 \)?
(A) \( -3 \) (B) \( -2 \) (C) \( 2 \) (D) \( 3 \) (E) \( 6 \)


Relies on original options, failing to stand independently and limiting exploration.

Find a root of the equation \( x^2 - x - 6 = 0 \). Your answer must be one of the following: \( -3, -2, 2, 3, 6 \).


Ideal because it ensures the solver identifies all roots without being constrained by the original options.

Solve the quadratic equation \( x^2 - x - 6 = 0 \) and find all real roots. Provide your answers in increasing order.


Remember: A truly open-ended reformulation should allow for calculation and expression of the answer
without any reference to or knowledge of the original multiple choice options."""

    # --- Inputs ---
    original_problem: str = dspy.InputField(
        desc="The original multiple-choice math problem text, including all options and formatting."
    )
    correct_answer: str = dspy.InputField(
        desc="The correct answer to the original problem."
    )

    reformulated_problem: str = dspy.InputField(
        desc="The reformulated open-ended version of the problem that needs to be evaluated."
    )

    # --- Outputs ---
    issues: list[str] = dspy.OutputField(
        desc="List of specific issues found with the reformulation, if any.",
        prefix="\n"
    )

    suggestions: list[str] = dspy.OutputField(
        desc="Specific recommendations for improving the reformulation.",
        prefix="\n"
    )

    reasoning: str = dspy.OutputField(
        desc="Step-by-step explanation of the evaluation process and rationale for the verdict.",
        prefix="\n"
    )

    verdict: str = dspy.OutputField(
        desc="Either 'PASS' or 'FAIL' based on comprehensive evaluation of the reformulation.",
        prefix="\n"
    )

    corrected_version: Optional[str] = dspy.OutputField(
        desc="An improved version of the reformulation if issues were found. Should be None if verdict is PASS.",
        prefix="\n"
    )
--------------------------------------------------------------------------------
/signals/README.md:
--------------------------------------------------------------------------------
1 | # Signals
2 |
3 | The `signals` folder contains scripts and resources for implementing rule-based and model-based filters for dataset curation, as described in the [Big-Math paper](https://alon-albalak.github.io/images/Big_MATH.pdf). These filtering techniques are critical for ensuring the quality, diversity, and relevance of the dataset for reinforcement learning.
4 |
5 | ---
6 |
7 | ## 📂 Contents
8 |
9 | ### Rule-based filters
10 | - `add_empty_boxed_signal.py` (only useful for extracting answers from full solutions (e.g. from [NuminaMath](https://huggingface.co/datasets/AI-MO/NuminaMath-CoT)))
11 | - `add_hyperlink_signal.py`
12 | - `add_language_signal.py`
13 | - `add_multipartq_signal.py`
14 | - `add_multiple_choice_signal.py`
15 | - `add_proof_signal.py`
16 | - `add_true_false_signal.py`
17 | - `add_yes_no_signal.py`
18 |
19 | ### Deduplication
20 | - `add_semdedup_signal.py`
21 |
22 | ### Model-based signals
23 | - `model_based_signals.py`
24 |
25 | ### Solve rate
26 | - `rollouts_based_signals/example_solve_rate_script.sh`
27 |
28 |
29 | ## 🚀 Getting Started
30 |
31 | ### Prerequisites
32 | - Python 3.10+
33 | - Install dependencies:
34 | ```bash
35 | pip install -r requirements.txt
36 | ```
37 |
38 | ## 🛠 Usage
39 |
40 | ### Rule-based signals
41 | Rule-based signals are written in separate files, but share the same use pattern. Each signal requires only a single input variable, `dataset_path`. The script will automatically update the source dataset with a new column.
42 |
43 | This is a simple example that will detect whether the `problem` or `solution` columns have a hyperlink, adding 2 new columns, `has_hyperlink_problem` and `has_hyperlink_solution`:
44 | ```bash
45 | DATASET_PATH=
46 | python3 add_hyperlink_signal.py --dataset_path ${DATASET_PATH}
47 | ```
48 |
49 | ### Deduplication
50 | We use SemDeDup to mark duplicated problems. Just like the rule-based signals, this requires only a single input variable, `dataset_path`. The script, by default, uses a similarity threshold of 0.5, but that can easily be adjusted by adding/removing values from the `EPSILONS` variable.
51 | ```bash
52 | DATASET_PATH=
53 | python3 add_semdedup_signal.py --dataset_path ${DATASET_PATH}
54 | ```
55 |
56 | ### Model-based signals
57 | The model-based signals are all included in a single file for efficiency when running multiple filters. Our scripts make use of [SGLang](https://github.com/sgl-project/sglang) for fast inference.
58 |
59 | Model-based signals require a bit more specification than rule-based signals. Here, we need to specify:
60 | - `model_name`: The name of a local or HF model
61 | - `output_model_name`: The name of the model you want to use in the signal's column name. For example, the below code uses `--output_model_name "llama3-70b"` and will add the multiple choice signal in a new column: "multiple_choice_llama3-70b"
62 | - `dataset_name`: A path to HF or local dataset
63 | - `dataset_name_outputs`: Optional - The path to save the updated dataset to. Will save to `dataset_name` when not specified
64 | - `save_folder`: Local directory to save intermediate outputs to
65 | - `save_name`: The name of the file to save intermediate outputs to
66 | - `tp`: tensor parallelism
67 | - Signal options: `--multiple_choice`, `--proof`, `--yes_no`, `--true_false`, `--multiple_part`
68 |
69 | The below example runs all model-based signals using Llama-3.1-70B:
70 | ```bash
71 | OUTPUT_DATASET_NAME=
72 |
73 | python3 model_based_signals.py \
74 | --model_name "meta-llama/Meta-Llama-3.1-70B-Instruct"\
75 | --output_model_name "llama3-70b"\
76 | --dataset_name "SynthLabsAI/Big-Math-RL-Verified"\
77 | --dataset_name_outputs "${OUTPUT_DATASET_NAME}"\
78 | --save_folder "signal_outputs"\
79 | --save_name "model_based_signals"\
80 | --tp 8\
81 | --multiple_choice\
82 | --proof\
83 | --yes_no\
84 | --true_false\
85 | --multiple_part\
86 | > model_based_signal_outputs.txt 2>&1
87 | ```
88 |
89 | ### Solve rate
90 | See `rollouts_based_signals/example_solve_rate_script.sh` for example usage.
91 |
--------------------------------------------------------------------------------
/signals/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SynthLabsAI/big-math/420b9a771a7e97a85b81cbdcbd573b1b0d56f522/signals/__init__.py
--------------------------------------------------------------------------------
/signals/add_empty_boxed_signal.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from functools import partial
3 | import multiprocessing as mp
4 | import re
5 | import argparse
6 |
def is_boxed_empty(row, patterns):
    """Mark a row whose solution contains an empty boxed answer.

    Args:
        row: Mapping with a ``'solution'`` string.
        patterns: Iterable of regex patterns; a hit on any of them counts.

    Returns:
        The same ``row`` with a boolean ``'is_boxed_empty'`` entry added.
    """
    # Start from "not empty" and only flip the flag on the first regex hit.
    row['is_boxed_empty'] = False
    if any(re.search(regex, row['solution']) for regex in patterns):
        row['is_boxed_empty'] = True
    return row
14 |
def main(dataset_path):
    """Add the ``is_boxed_empty`` column to a dataset and push it back.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset`` and
            reused as the ``push_to_hub`` destination.
    """
    # Regexes for a boxed answer that holds nothing but optional whitespace.
    patterns = [
        r'boxed\{\s*\}',
        r'boxed\{[\s\n\r]*\}',
        r'\\boxed\{\s*\}'
    ]
    flag_row = partial(is_boxed_empty, patterns=patterns)

    # Load the train split and tag every row in parallel.
    ds = load_dataset(dataset_path, split="train")
    ds = ds.map(flag_row, num_proc=mp.cpu_count())

    # Publish the dataset, now carrying the 'is_boxed_empty' column.
    ds.push_to_hub(dataset_path)

    # Report how many rows were flagged.
    flagged = ds.filter(lambda example: example['is_boxed_empty'])
    print(f"Boxed empty: {len(flagged)} / {len(ds)}")


if __name__ == '__main__':
    cli = argparse.ArgumentParser(description="Detect empty boxed solutions in a dataset.")
    cli.add_argument('dataset_path', type=str, help='Path to the dataset')
    main(cli.parse_args().dataset_path)
--------------------------------------------------------------------------------
/signals/add_hyperlink_signal.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import multiprocessing as mp
3 | import re
4 | import argparse
5 |
def detect_latex_hyperlinks(row):
    """Flag hyperlinks in a row's problem and solution text.

    A text counts as containing a hyperlink when it includes a LaTeX
    ``\\url`` or ``\\href`` command, or an ``http://`` / ``https://`` prefix.

    Args:
        row: Mapping with ``'problem'`` and ``'solution'`` strings.

    Returns:
        The same ``row`` with boolean ``'has_hyperlink_problem'`` and
        ``'has_hyperlink_solution'`` entries added.
    """
    link_markers = (r'\\url', r'\\href', r'http://', r'https://')

    def mentions_link(text):
        # True as soon as any marker occurs anywhere in the text.
        return any(re.search(marker, text) for marker in link_markers)

    row['has_hyperlink_problem'] = mentions_link(row['problem'])
    row['has_hyperlink_solution'] = mentions_link(row['solution'])
    return row
34 |
def main(dataset_path):
    """Add the hyperlink signal columns to a dataset and push it back.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset`` and
            reused as the ``push_to_hub`` destination.
    """
    ds = load_dataset(dataset_path, split="train")

    # Tag every row in parallel, one worker per CPU.
    ds = ds.map(detect_latex_hyperlinks, num_proc=mp.cpu_count())

    ds.push_to_hub(dataset_path)


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Detect LaTeX hyperlinks in a dataset.")
    cli.add_argument('dataset_path', type=str, help='Path to the dataset')
    main(cli.parse_args().dataset_path)
--------------------------------------------------------------------------------
/signals/add_language_signal.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import fasttext
3 | from functools import partial
4 | from huggingface_hub import hf_hub_download
5 | import multiprocessing as mp
6 | import re
7 | import argparse
8 |
# Fetch the fastText language-identification weights ("model.bin") from the
# Hugging Face Hub and load them at import time. `fasttext_lang_id_model` is
# the module-level model queried by predict_text().
model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
fasttext_lang_id_model = fasttext.load_model(model_path)
12 |
def remove_special_chars(text):
    """Strip digits, punctuation and math symbols, then tidy whitespace.

    Args:
        text: Input string.

    Returns:
        ``text`` without the listed characters, with whitespace runs
        collapsed to single spaces and the ends trimmed.
    """
    without_symbols = re.sub(r'[0-9*,\.+\-=\(\)\/\^\[\]{}|<>~`!@#$%&?_]', '', text)
    return re.sub(r'\s+', ' ', without_symbols).strip()
22 |
def remove_latex_commands(text):
    """Delete LaTeX markup from ``text`` and tidy whitespace.

    Args:
        text: Input string, possibly containing LaTeX.

    Returns:
        ``text`` with the matched LaTeX spans removed, whitespace runs
        collapsed to single spaces and the ends trimmed.
    """
    # Applied in this order; each pattern's matches are replaced by nothing.
    removal_steps = (
        r'\\frac\{[^{}]*\}\{[^{}]*\}',  # \frac{..}{..} including both arguments
        r'\\[a-zA-Z]+\{[^{}]*\}',       # a command together with one braced argument
        r'\\[a-zA-Z]+',                 # any remaining bare command
        r'\$\$(.*?)\$\$',               # $$ ... $$ math
        r'\$(.*?)\$',                   # $ ... $ math
    )
    for step in removal_steps:
        text = re.sub(step, '', text)

    return re.sub(r'\s+', ' ', text).strip()
41 |
def predict_text(text: str):
    """Query the module-level fastText model for language predictions.

    Args:
        text: Text to classify.

    Returns:
        A list of ``(label, score)`` tuples in the order the model returns
        them, with the ``"__label__"`` prefix removed from each label and
        each score converted to ``float``.
    """
    # k=-1 is passed through to the model's predict call unchanged.
    raw_predictions = fasttext_lang_id_model.predict(text, k=-1)

    results = []
    for label, score in zip(*raw_predictions):
        results.append((label.replace("__label__", ""), float(score)))
    return results
45 |
def detect_lang(row, column_name='problem'):
    """Attach language-identification results for one column of a row.

    Two entries are written to ``row``:

    * ``fasttext_lang_<column_name>``: up to five ``"<lang>: <prob>"``
      strings taken from the head of the prediction list (empty list when
      the column is empty).
    * ``<column_name>_language``: ``"No language detected"`` for an empty
      column, otherwise ``"en"`` unless the English probability is below 2%
      and the top prediction does not end in ``"_Latn"``, in which case the
      top prediction's label is stored.

    Args:
        row: Mapping holding the text under ``column_name``.
        column_name: Name of the column to analyse.

    Returns:
        The same ``row`` with the two entries added.
    """
    probs_key = f'fasttext_lang_{column_name}'
    lang_key = f'{column_name}_language'

    if not row[column_name]:
        row[probs_key] = []
        row[lang_key] = "No language detected"
        return row

    cleaned_text = remove_special_chars(remove_latex_commands(row[column_name]).replace("\n", ""))
    lang_probs = predict_text(cleaned_text)

    # reformat the lang_probs into list of strings
    row[probs_key] = [f"{lang}: {prob}" for lang, prob in lang_probs[:5]]

    # default label; also the final answer when the text is too short to judge
    row[lang_key] = "en"
    if len(cleaned_text) < 10:
        return row

    # Default to 0.0 so a prediction list without an "eng_Latn" entry no longer
    # raises UnboundLocalError; the top prediction then decides the label.
    eng_prob = 0.0
    for lang, prob in lang_probs:
        if lang == "eng_Latn":
            eng_prob = prob

    # anything with at least 2% probability of english stays labelled english
    if eng_prob < 0.02 and lang_probs:
        highest_prob_lang = lang_probs[0][0]
        # a top prediction in Latin script is still labelled "en"
        if not highest_prob_lang.endswith("_Latn"):
            row[lang_key] = highest_prob_lang

    return row
80 |
def main(dataset_path):
    """Add language signals for the problem and final-answer columns.

    Runs ``detect_lang`` over the ``"problem"`` and ``"final_answer"``
    columns of the train split and pushes the result back.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset`` and
            reused as the ``push_to_hub`` destination.
    """
    # define the columns to detect language on
    detect_language_on = ["problem", "final_answer"]

    # load the dataset
    dataset = load_dataset(dataset_path, split="train")

    for col in detect_language_on:
        # bind the detector to the current column
        detect_language = partial(detect_lang, column_name=col)

        # run language detection over the full dataset
        dataset = dataset.map(detect_language, num_proc=mp.cpu_count())

    # push the updated dataset
    dataset.push_to_hub(dataset_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Detect language in a dataset.")
    parser.add_argument('dataset_path', type=str, help='Path to the dataset')
    args = parser.parse_args()
    main(args.dataset_path)
--------------------------------------------------------------------------------
/signals/add_multipartq_signal.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import multiprocessing as mp
3 | import re
4 | import argparse
5 |
# Patterns that indicate a question split into enumerated parts. They are
# matched against the ORIGINAL text: several of them contain uppercase Roman
# numerals ("I", "Ⅰ", "Ⅱ"), which can never match lower-cased input.
_MULTI_PART_PATTERNS = (
    r'\([IⅠ\s,\.]\).*\([IⅡ\s,\.]\)',  # Roman numerals
    r'\([1-9][^\)]*\).*\([1-9][^\)]*\)',  # Numbered parts in parentheses
    r'(?:\d+\.).*(?:\d+\.)',  # Numbered parts with period
    r'\(I+\).*\(I+\)',  # Traditional Roman numerals
    r'(?:(?:^|\s)(?:1[.\)]|[Ii][.\)]|①).*(?:\s|$)(?:2[.\)]|[Ii]{2}[.\)]|②))',  # Mixed types
)


def has_multi_part_q(row):
    """Flag a row whose problem text looks like a multi-part question.

    Args:
        row: Mapping with a ``'problem'`` string.

    Returns:
        The same ``row`` with a boolean ``'is_multi_part_q_regex'`` entry,
        True when any of the enumeration patterns matches the problem.
    """
    problem = row['problem']
    row['is_multi_part_q_regex'] = any(
        re.search(pattern, problem) for pattern in _MULTI_PART_PATTERNS
    )
    return row
29 |
def main(dataset_path):
    """Add the ``is_multi_part_q_regex`` column to a dataset and push it back.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset`` and
            reused as the ``push_to_hub`` destination.
    """
    ds = load_dataset(dataset_path, split="train")

    # Tag every row in parallel, one worker per CPU.
    ds = ds.map(has_multi_part_q, num_proc=mp.cpu_count())

    ds.push_to_hub(dataset_path)


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Detect multi-part questions in a dataset.")
    cli.add_argument('dataset_path', type=str, help='Path to the dataset')
    main(cli.parse_args().dataset_path)
--------------------------------------------------------------------------------
/signals/add_multiple_choice_signal.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import multiprocessing as mp
3 | import argparse
4 |
# Answer-option label sets tried by detect_multiple_choice(): a problem is
# flagged when either the letter labels or the numeric labels are found by
# is_multiple_choice().
LETTER_OPTIONS=['A', 'B', 'C', 'D']
NUMERICAL_OPTIONS=['1', '2', '3', '4']
7 |
def is_multiple_choice(question, answer_options):
    """Decide whether ``question`` lists all ``answer_options`` in order.

    Args:
        question: Problem text.
        answer_options: Sequence of option labels, e.g. ``['A', 'B', 'C', 'D']``.

    Returns:
        True when every label occurs in the text and the last occurrences
        appear in the given order; False otherwise.
    """
    # Labels written back-to-back (e.g. "ABCD") are part of the question,
    # not a list of options, so drop that run before searching.
    searchable = question.replace("".join(answer_options), "")

    # Last position of each label, listed from the final label to the first.
    last_seen = [searchable.rfind(label) for label in answer_options[::-1]]
    if -1 in last_seen:
        return False

    # Walking from the final label backwards, positions must never increase.
    return all(pos <= prev_pos for prev_pos, pos in zip(last_seen, last_seen[1:]))
26 |
def detect_multiple_choice(row):
    """Flag a row whose problem looks like a multiple-choice question.

    Args:
        row: Mapping with a ``'problem'`` string.

    Returns:
        The same ``row`` with a boolean ``'is_multiple_choice'`` entry, True
        when either the letter or the numeric option labels are detected.
    """
    problem = row['problem']
    # Letter labels are tried first; numeric labels only if those fail.
    row['is_multiple_choice'] = (
        is_multiple_choice(problem, LETTER_OPTIONS)
        or is_multiple_choice(problem, NUMERICAL_OPTIONS)
    )
    return row
33 |
def main(dataset_path):
    """Add the ``is_multiple_choice`` column to a dataset and push it back.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset`` and
            reused as the ``push_to_hub`` destination.
    """
    ds = load_dataset(dataset_path, split="train")

    # Run multiple-choice detection over every row, one worker per CPU.
    ds = ds.map(detect_multiple_choice, num_proc=mp.cpu_count())

    ds.push_to_hub(dataset_path)


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Detect multiple choice questions in a dataset.")
    cli.add_argument('dataset_path', type=str, help='Path to the dataset')
    main(cli.parse_args().dataset_path)
--------------------------------------------------------------------------------
/signals/add_proof_signal.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import multiprocessing as mp
3 | import nltk
4 | import argparse
5 |
6 | nltk.download('punkt_tab')
7 |
def has_proof(row):
    """Flag a row whose problem asks for a proof.

    Args:
        row: Mapping with ``'problem'`` and ``'source'`` entries.

    Returns:
        The same ``row`` with a boolean ``'is_math_proof'`` entry. It is True
        when the lower-cased problem contains "prove that" or "a proof", or,
        for rows whose source is "olympiads", when any sentence starts with
        "show".
    """
    row['is_math_proof'] = False

    problem_lc = row['problem'].lower()
    if any(phrase in problem_lc for phrase in ('prove that', 'a proof')):
        row['is_math_proof'] = True

    # Olympiad problems get an extra check: a sentence that opens with "show".
    if row['source'] == "olympiads" and not row['is_math_proof']:
        if any(
            sentence.lstrip().lower().startswith('show')
            for sentence in nltk.sent_tokenize(row['problem'])
        ):
            row['is_math_proof'] = True

    return row
26 |
def main(dataset_path):
    """Add the ``is_math_proof`` column to a dataset and push it back.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset`` and
            reused as the ``push_to_hub`` destination.
    """
    ds = load_dataset(dataset_path, split="train")

    # Tag every row in parallel, one worker per CPU.
    ds = ds.map(has_proof, num_proc=mp.cpu_count())

    ds.push_to_hub(dataset_path)


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Detect proofs in a dataset.")
    cli.add_argument('dataset_path', type=str, help='Path to the dataset')
    main(cli.parse_args().dataset_path)
--------------------------------------------------------------------------------
/signals/add_semdedup_signal.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from functools import partial
3 | import multiprocessing as mp
4 | import argparse
5 | from semdedup import semantic_deduplication
6 |
# Values main() iterates over; each one produces a column named
# 'is_semdedup_duplicate_eps<value>'.
# NOTE(review): the visible call to semantic_deduplication() in main() does not
# receive the epsilon value — confirm whether it should.
EPSILONS = [0.5]
8 |
def merge_problem_and_answer(row):
    """Add a ``'problem_answer'`` entry joining the problem and its answer.

    Args:
        row: Mapping with ``'problem'`` and ``'final_answer'`` entries.

    Returns:
        The same ``row`` with ``'problem_answer'`` set to
        ``"Problem: <problem> Answer: <final_answer>"``.
    """
    problem = row['problem']
    answer = row['final_answer']
    row['problem_answer'] = f"Problem: {problem} Answer: {answer}"
    return row
12 |
def is_semdedup_duplicate(row, idx, indices_to_remove, epsilon):
    """Record whether the row at ``idx`` was marked for removal.

    Args:
        row: Mapping to annotate.
        idx: Index of the row.
        indices_to_remove: Container of indices flagged as duplicates.
        epsilon: Value embedded in the name of the output column.

    Returns:
        The same ``row`` with a boolean
        ``'is_semdedup_duplicate_eps<epsilon>'`` entry.
    """
    column = f'is_semdedup_duplicate_eps{epsilon}'
    flagged = idx in indices_to_remove
    row[column] = flagged
    return row
16 |
def main(dataset_path):
    """Add semantic-deduplication flags to a dataset and push it back.

    For every value in ``EPSILONS`` a boolean column
    ``is_semdedup_duplicate_eps<value>`` is added to the train split.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset`` and
            reused as the ``push_to_hub`` destination.
    """
    dataset = load_dataset(dataset_path, split="train")
    print(f"Original Dataset: {dataset}")

    # Deduplication runs on a temporary "problem + answer" text column.
    dataset = dataset.map(merge_problem_and_answer, num_proc=mp.cpu_count())

    # semdedup works on a pandas frame
    df = dataset.data.to_pandas()
    for epsilon in EPSILONS:
        print(f"Running semdedup with epsilon {epsilon}")
        # NOTE(review): epsilon is not passed to semantic_deduplication here,
        # and df is rebound to the function's first return value on every
        # iteration — confirm both are intended before adding more epsilons.
        df, indices_to_remove, _cluster_duplicates = semantic_deduplication(
            df=df,
            required_columns=["problem_answer"],
            num_kmeans_clusters=len(df) // 100,
            embedding_batch_size=2500,
            use_gpu=True,
        )

        # Write the per-row duplicate flag for this epsilon.
        mark_row = partial(is_semdedup_duplicate, indices_to_remove=indices_to_remove, epsilon=epsilon)
        dataset = dataset.map(mark_row, with_indices=True, num_proc=mp.cpu_count())

        print(f"Epsilon {epsilon} Filtered Dataset: {dataset}")
        print(f"Epsilon {epsilon} Number of duplicates: {len(indices_to_remove)}")

    # The helper column is not part of the published dataset.
    dataset = dataset.remove_columns("problem_answer")

    dataset.push_to_hub(dataset_path)


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Detect semantic duplicates in a dataset.")
    cli.add_argument('dataset_path', type=str, help='Path to the dataset')
    main(cli.parse_args().dataset_path)
--------------------------------------------------------------------------------
/signals/add_true_false_signal.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import re
3 | import multiprocessing as mp
4 | import argparse
5 |
def is_true_false(row):
    """Flag a row whose answer is a true/false statement.

    Args:
        row: Mapping with a ``'final_answer'`` entry and, when that entry is
            empty, a ``'hard_math_solutions'`` list of solution strings.

    Returns:
        The same ``row`` with a boolean ``'is_true_false_question'`` entry.
        When the final answer is non-empty only it is inspected; otherwise
        the last line of each solution is inspected.
    """
    row['is_true_false_question'] = False

    final_answer = row['final_answer']
    if final_answer:
        texts_to_check = [final_answer]
    else:
        # No final answer: look at the closing line of every solution.
        texts_to_check = [solution.split('\n')[-1] for solution in row['hard_math_solutions']]

    if any(re.search(r'(true|false)', text, re.IGNORECASE) for text in texts_to_check):
        row['is_true_false_question'] = True

    return row
23 |
def main(dataset_path):
    """Add the ``is_true_false_question`` column to a dataset and push it back.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset`` and
            reused as the ``push_to_hub`` destination.
    """
    ds = load_dataset(dataset_path, split="train")

    # Run true/false detection over every row, one worker per CPU.
    ds = ds.map(is_true_false, num_proc=mp.cpu_count())

    ds.push_to_hub(dataset_path)


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Detect true/false questions in a dataset.")
    cli.add_argument('dataset_path', type=str, help='Path to the dataset')
    main(cli.parse_args().dataset_path)
--------------------------------------------------------------------------------
/signals/add_yes_no_signal.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import re
3 | import argparse
4 |
def add_yes_no_signal(row):
    """Flag a row that looks like a yes/no question.

    Args:
        row: Mapping with ``'final_answer'``, ``'problem'`` and, when the
            final answer is empty, ``'solution'`` entries.

    Returns:
        The same ``row`` with a boolean ``'is_yes_no_question'`` entry. It is
        True when the final answer (or, if that is empty, the last line of
        the solution) contains "yes"/"no", or when the last line of the
        problem starts with one of "is ", "are ", "do ", "does ", "can ".
    """
    row['is_yes_no_question'] = False

    # Prefer the final answer; fall back to the solution's closing line.
    final_answer = row['final_answer']
    if final_answer:
        answer_text = final_answer
    else:
        answer_text = row['solution'].split('\n')[-1]
    if re.search(r'(\s|\{|\b)(yes|no)(\s|\}|\b)', answer_text, re.IGNORECASE):
        row['is_yes_no_question'] = True

    # Independently, check whether the problem itself reads like a yes/no question.
    problem_lines = row['problem'].split('\n')
    question_prefixes = ("is ", "are ", "do ", "does ", "can ")

    # With more than one line, the line before the last must end in
    # punctuation or whitespace for the last line to count as a new sentence.
    standalone_last_line = (
        len(problem_lines) == 1
        or bool(re.search(r'[.,:;!?\s]$', problem_lines[-2]))
    )
    if standalone_last_line and problem_lines[-1].lower().startswith(question_prefixes):
        row['is_yes_no_question'] = True

    return row
33 |
def main(dataset_path):
    """Add the ``is_yes_no_question`` column to a dataset and push it back.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset`` and
            reused as the ``push_to_hub`` destination.
    """
    ds = load_dataset(dataset_path, split="train")

    # Tag every row with the yes/no signal.
    ds = ds.map(add_yes_no_signal)

    ds.push_to_hub(dataset_path)


if __name__ == '__main__':
    cli = argparse.ArgumentParser(description="Detect yes/no questions in a dataset.")
    cli.add_argument('dataset_path', type=str, help='Path to the dataset')
    main(cli.parse_args().dataset_path)
--------------------------------------------------------------------------------
/signals/model_based_signals.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | from datasets import Dataset, load_dataset
4 | from enum import Enum
5 | import json
6 | import os
7 | from pydantic import BaseModel, Field, field_validator
8 | import traceback
9 | from tqdm import tqdm
10 |
11 | from rollouts_based_signals.utils.sglang_util import SGLangServerManager
12 |
class ModelType(str, Enum):
    """Supported LLM model types.

    Each member's value is the model-name string; because the class also
    derives from ``str``, members compare equal to that string. Used as the
    type of ``LLMHyperparams.model_name``.
    """
    # Llama 3.1 instruct checkpoints (values look like Hugging Face repo ids)
    Llama3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    Llama3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    Llama3_1_405B = "meta-llama/Llama-3.1-405B-Instruct-FP8"
    # Other model names (presumably hosted API models — confirm against callers)
    CLAUDE_3_5 = "claude-3-5-sonnet-latest"
    O1_PREVIEW = "o1-preview"
    O1_MINI = "o1-mini"
21 |
class LLMHyperparams(BaseModel):
    """Prompt and sampling settings for one generation run.

    ``model_name`` and ``n`` have no default and must be supplied; every
    other field falls back to the default declared below.
    """

    # System message; the doubled braces stay literal if the string is passed
    # through str.format (presumably how callers fill it — confirm).
    system_prompt: str = """You are a math expert. Given the following math problem, provide your solution in Latex format. Always format your final answer in perfect LaTeX \\boxed{{final_answer}} format."""
    # User-message template with a ``{problem}`` placeholder.
    prompt: str = "{problem}"
    temperature: float = Field(
        default=0.8,
        ge=0.0,
        le=2.0,
        description='Float that controls the randomness of the sampling. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling.'
    )
    top_k: int = Field(
        default=-1,
        description='Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.'
    )
    top_p: float = Field(
        default=0.95,
        ge=0.0,
        le=1.0,
        description='Float that controls the cumulative probability of the top tokens to consider.'
    )
    max_tokens: int = 2048
    model_name: ModelType = Field(
        description='The model to use for generation.',
    )
    n: int = Field(
        description='Number of samples to generate.'
    )

    @field_validator('temperature')
    @classmethod
    def validate_temperature(cls, v):
        """Reject negative temperatures.

        Kept alongside the ``ge=0.0`` constraint on the field, which
        expresses the same lower bound.

        Raises:
            ValueError: If ``v`` is below zero.
        """
        if v < 0:
            raise ValueError("temp can't be negative")
        return v

    def to_json(self) -> str:
        """Return the settings as an indented JSON string."""
        return json.dumps(self.model_dump(), indent=2)

    @classmethod
    def load_json(cls, json_str: str) -> 'LLMHyperparams':
        """Build an instance from a JSON string such as ``to_json`` produces."""
        data = json.loads(json_str)
        return cls(**data)
62 |
63 | no_multiple_parts = [
64 | """\
65 | Ancient Greek mathematicians from the Pythagorean school studied various polygonal numbers, such as triangular numbers 1, 3, 6, 10, ..., with the $n$-th triangular number being $\\frac{n(n+1)}{2} = \\frac{1}{2}n^{2} + \\frac{1}{2}n$. Let the $n$-th $k$-sided polygon number be denoted as $N(n, k)$ ($k \\geq 3$). Below are the expressions for the $n$-th number of some $k$-sided polygon numbers: \nTriangular numbers $N(n, 3) = \\frac{1}{2}n^{2} + \\frac{1}{2}n$ \nSquare numbers $N(n, 4) = n^{2}$ \nPentagonal numbers $N(n, 5) = \\frac{3}{2}n^{2} - \\frac{1}{2}n$ \nHexagonal numbers $N(n, 6) = 2n^{2} - n$ \n... \nFrom this, we can deduce the expression for $N(n, k)$ and calculate $N(8, 12) = \\_\\_\\_\\_\\_\\_$.'
66 | """,
67 | """\
68 | Let x be the dividend, y be the divisor, z be the quotient, and r be the remainder. If y = 3(y1 + y2) + 4, z = 2z1^2 - z2, r = 3r1 + 2, and x = 2x1y1 - x2 + 10, find the values of x, y, z, and r, given that y1 = 2, y2 = 3, z1 = 3, z2 = 5, r1 = 1, x1 = 4, and x2 = 6.
69 | """,
70 | """\
71 | If $x_{1}, x_{2},\ldots ,x_{n}$ are positive real numbers with $x_{1}^2+x_2^{2}+\ldots +x_{n}^{2}=1$ , find the minimum value of $\sum_{i=1}^{n}\frac{x_{i}^{5}}{x_{1}+x_{2}+\ldots +x_{n}-x_{i}}$ .
72 | """,
73 | """
74 | Given that the value of the function \( f(x) = \frac{1}{(x-2)^{2}} - 2x + \cos 2\theta - 3 \sin \theta + 2 \) is always positive for \( x \in (-\infty, 2) \), determine the range of the parameter \( \theta \) within the interval \( (0, \pi) \).
75 | """,
76 | """
77 | A positive integer \( n \) is written on the blackboard. Two players, A and B, take turns performing operations, starting with A. On A's turn, if the number on the board is \( k \), A replaces it with one of \( \frac{k}{2} \), \( \frac{k}{4} \), or \( 3k \) (with the first two choices being allowed only if \( \frac{k}{2} \) or \( \frac{k}{4} \) are integers). On B's turn, if the number on the board is \( k \), B replaces it with either \( k+1 \) or \( k-1 \). Player A wins the game if the number 3 is written on the blackboard at any point. For which positive integers \( n \) does player A have a winning strategy?
78 | """,
79 | """\
80 | A student was studying the properties of the function $f(x) = x^2e^x$ and came to the following conclusions:
81 | - ① The interval where $f(x)$ is monotonically decreasing is $(-2, 0)$;
82 | - ② $f(x)$ has neither a minimum nor a maximum value;
83 | - ③ The graph of $f(x)$ and its tangent line at $(0,0)$ intersect at two points;
84 | - ④ The graph of $f(x)$ and the line $x - y + 2012 = 0$ intersect at two points.
85 |
86 | Among these conclusions, the correct ones are __________.
87 | """,
88 | """\
89 | Prove that the set \(A=\left\{2,2^{2}, \cdots\right., \left.2^{n}, \cdots\right\}\) satisfies the following:
90 |
91 | (1) For every \(a \in A\) and \(b \in \mathbf{N}^{*}\), if \(b < 2a - 1\), then \(b(b+1)\) is not a multiple of \(2a\).
92 |
93 | (2) For every \(a \in \bar{A}\) (where \(\bar{A}\) denotes the complement of \(A\) in \(\mathbf{N}^{*}\)) and \(a \neq 1\), there exists \(b \in \mathbf{N}^{*}\) such that \(b < 2a - 1\) and \(b(b+1)\) is a multiple of \(2a\).
94 | """,
95 | """\
96 | You have a 200 liters mixture of four chemicals: W, X, Y, and Z which are in the ratio of 3:4:6:7. You add 40 liters of chemical W, 35 liters of chemical X, 50 liters of chemical Y and 25 liters of chemical Z to this mixture. What is the new percentage of each chemical W, X, Y and Z in the resulting mixture?
97 | """,
98 | """\
99 | Which of the following statements are correct?
100 |
101 | A: If a sample of size $5$ is drawn from a population of $50$ individuals using simple random sampling, the probability of individual $m$ being selected is $0.1$.
102 |
103 | B: Given a set of data $1$, $2$, $m$, $6$, $7$ with an average of $4$, then the variance of this data set is $\frac{{26}}{5}$.
104 |
105 | C: The $70$th percentile of the data $13$, $27$, $24$, $12$, $14$, $30$, $15$, $17$, $19$, $23$ is $23$.
106 |
107 | D: If the standard deviation of a sample data $x_{1}$, $x_{2}$, $\cdots$, $x_{10}$ is $8$, then the standard deviation of the data $2x_{1}-1$, $2x_{2}-1$, $\cdots$, $2x_{10}-1$ is $32$.
108 | """,
109 | """\
110 | A transparent, sealed cubic container is exactly half filled with water. When this cube is rotated arbitrarily, the shape of the water surface inside the container can be: (1) triangle; (2) rectangle; (3) square; (4) regular hexagon. Among these, the correct conclusions are __________.
111 | """,
112 | """\
113 | Given a 100-sided polygon \( P \) in the Cartesian coordinate plane, satisfying:
114 | (i) The coordinates of the vertices of \( P \) are all integers;
115 | (ii) The sides of \( P \) are parallel to the coordinate axes;
116 | (iii) The side lengths of \( P \) are all odd numbers.
117 |
118 | Prove that the area of \( P \) is an odd number.
119 | """,
120 | """\
121 | Given a positive integer $n \ge 2$. Find all $n$-tuples of positive integers $(a_1,a_2,\ldots,a_n)$, such that $1 0)$ intersects curve $C_1$ at point $A$ (other than the pole) and intersects curve $C_2$ at point $B$. Find $|AB|$.
138 | """,
139 | """\
140 | Given the function $f(x)=|x+2|-|2x-a|$, $(a\in\mathbb{R})$. (I) When $a=3$, solve the inequality $f(x) > 0$; (II) When $x \in [0, +\infty)$, $f(x) < 3$ always holds, find the range of $a$.
141 | """,
142 | """\
143 | Given an ellipse $C_1$: $\frac{x^2}{a^2} + \frac{y^2}{b^2} = 1$ ($a > b > 0$) with a major axis length of 4 and an eccentricity of $\frac{1}{2}$, where $F_1$ and $F_2$ are its left and right foci, respectively. A moving circle passes through point $F_2$ and is tangent to the line $x = -1$. (Ⅰ) (i) Find the equation of the ellipse $C_1$; (ii) Find the equation of the trajectory of the center $C$ of the moving circle; (Ⅱ) On the curve $C$, there are two points $M$ and $N$, and on the ellipse $C_1$, there are two points $P$ and $Q$, satisfying that $MF_2$ and $\overrightarrow{NF_2}$ are collinear, $\overrightarrow{PF_2}$ and $\overrightarrow{QF_2}$ are collinear, and $\overrightarrow{PF_2} \cdot \overrightarrow{MF_2} = 0$, find the minimum value of the area of quadrilateral $PMQN$.
144 | """,
145 | """\
146 | In the rectangular coordinate system $xOy$, a polar coordinate system is established with the coordinate origin as the pole and the positive semi-axis of the $x$-axis as the polar axis. The polar coordinate equation of circle $C$ is $\rho^2 - 2m\rho\cos\theta + 4\rho\sin\theta = 1 - 2m$. (1) Find the rectangular coordinate equation of $C$ and its radius. (2) When the radius of $C$ is the smallest, the curve $y = \sqrt{3}|x - 1| - 2$ intersects $C$ at points $A$ and $B$, and point $M(1, -4)$. Find the area of $\triangle MAB$.
147 | """,
148 | """\
149 | In recent years, the emergence of "shared bicycles" has greatly facilitated the "green travel" for citizens. A shared bicycle company "Mobie" plans to invest a total of 1.2 million yuan in two cities, A and B. According to industry regulations, each city must receive an investment of at least 400,000 yuan. Preliminary market research shows that the profit $P$ in city A and the investment $a$ (in units of 10,000 yuan) satisfy $P=3 \sqrt{2a}-6$, and the profit $Q$ in city B and the investment $a$ (in units of 10,000 yuan) satisfy $Q= \frac{1}{4}a+2$. Let the investment in city A be $x$ (in units of 10,000 yuan), and the total profit of the two cities be $f(x)$ (in units of 10,000 yuan). $(1)$ When the investment in city A is 500,000 yuan, calculate the total profit of the company at this time; $(2)$ How should the investments in cities A and B be arranged to maximize the total profit?
150 | """,
151 | """
152 | From points \( M \) and \( K \), which are 70 km apart, a bus and a cyclist set out towards each other simultaneously. They met after 1 hour and 24 minutes. Continuing at the same speed, the bus arrived at \( K \) and left for the return journey after a 20-minute stop. Find the speeds of the bus and the cyclist if the bus overtook the cyclist 2 hours and 41 minutes after their first meeting.
153 | """,
154 | """\
155 | Let $ L$ denote the set of all lattice points of the plane (points with integral coordinates). Show that for any three points $ A,B,C$ of $ L$ there is a fourth point $ D,$ different from $ A,B,C,$ such that the interiors of the segments $ AD,BD,CD$ contain no points of $ L.$ Is the statement true if one considers four points of $ L$ instead of three?
156 | """,
157 | """\
158 | For n real numbers $a_{1},\, a_{2},\, \ldots\, , a_{n},$ let $d$ denote the difference between the greatest and smallest of them and $S = \sum_{i a > 1\), and \(a\) does not divide \(b\), and a given sequence of positive integers \(\{b_n\}_{n=1}^{\infty}\) satisfying \(b_{n+1} \geq 2b_n\) for all positive integers \(n\). Does there always exist a sequence of positive integers \(\{a_n\}_{n=1}^{\infty}\) such that for all positive integers \(n\), \(a_{n+1} - a_n \in \{a, b\}\), and for all positive integers \(m\) and \(l\) (which can be the same), \(a_m + a_l \notin \{b_n\}\) for all \(n\)?
165 | """,
166 | """\
167 | Let \( f(x) = x^n, x \in D, n \in \mathbf{N}^{+} \). Determine whether \( f(x) \) is a solution to the functional inequality
168 | \[
169 | f(x) + f(1-x) > 1
170 | \]
171 | If so, find the domain \( D \); if not, provide an explanation.
172 | """,
173 | """\
174 | In a right angled-triangle $ABC$, $\angle{ACB} = 90^o$. Its incircle $O$ meets $BC$, $AC$, $AB$ at $D$,$E$,$F$ respectively. $AD$ cuts $O$ at $P$. If $\angle{BPC} = 90^o$, prove $AE + AP = PD$.
175 | """,
176 | """\
177 | A(x,y), B(x,y), and C(x,y) are three homogeneous real-coefficient polynomials of x and y with degree 2, 3, and 4 respectively. we know that there is a real-coefficient polinimial R(x,y) such that $B(x,y)^2-4A(x,y)C(x,y)=-R(x,y)^2$. Show that there exist 2 polynomials F(x,y,z) and G(x,y,z) such that $F(x,y,z)^2+G(x,y,z)^2=A(x,y)z^2+B(x,y)z+C(x,y)$ if for any x, y, z real numbers $A(x,y)z^2+B(x,y)z+C(x,y)\ge 0$
178 | """,
179 | """\
180 | Prove \[\frac{1}{\cos 0^\circ \cos 1^\circ} + \frac{1}{\cos 1^\circ \cos 2^\circ} + \cdots + \frac{1}{\cos 88^\circ \cos 89^\circ} = \frac{\cos 1^\circ}{\sin^2 1^\circ}.\]
181 | """,
182 | ]
183 |
# Few-shot examples of questions that ask for an answer rather than a proof.
# They are interpolated into the "proof" prompt built by `formatted_items`.
#
# The literals are raw strings on purpose: in a regular string literal the
# LaTeX command "\times" is read as a TAB followed by "imes" and "\frac" as a
# form feed followed by "rac", which silently corrupts the example text (and
# the other commands trigger invalid-escape warnings). Each entry still ends
# with a single trailing newline, as before.
no_proof_questions = [
    r"""In a $100 \times 25$ rectangular table, each cell is filled with a non-negative real number. The number in the $i$-th row and $j$-th column is denoted by $x_{i, j}$ $(i=1,2,\ldots, 100; j=1,2,\ldots, 25)$ (Table 1). The numbers in each column of Table 1 are then reordered in descending order to create Table 2 such that $x_{1, j}^{\prime} \geq x_{2, j}^{\prime} \geq \cdots \geq x_{100, j}^{\prime}$ $(j=1,2,\ldots, 25)$. Find the smallest natural number $k$ such that if the numbers in Table 1 satisfy $\sum_{j=1}^{25} x_{i, j} \leq 1$ $(i=1,2,\ldots, 100)$, then for $i \geq k$, Table 2 satisfies $\sum_{j=1}^{25} x_{i, j}^{\prime} \leq 1$ $(i=1,2,\ldots, 100)$.
""",
    r"""We are given $2n$ natural numbers
\[1, 1, 2, 2, 3, 3, \ldots, n - 1, n - 1, n, n.\]
Find all $n$ for which these numbers can be arranged in a row such that for each $k \leq n$, there are exactly $k$ numbers between the two numbers $k$.
""",
    r"""Determine all positive integers $n$, $n\ge2$, such that the following statement is true: If $(a_1,a_2,...,a_n)$ is a sequence of positive integers with $a_1+a_2+\cdots+a_n=2n-1$, then there is block of (at least two) consecutive terms in the sequence with their (arithmetic) mean being an integer.
""",
    r"""Turbo the snail sits on a point on a circle with circumference $1$. Given an infinite sequence of positive real numbers $c_1, c_2, c_3, \dots$, Turbo successively crawls distances $c_1, c_2, c_3, \dots$ around the circle, each time choosing to crawl either clockwise or counterclockwise.
Determine the largest constant $C > 0$ with the following property: for every sequence of positive real numbers $c_1, c_2, c_3, \dots$ with $c_i < C$ for all $i$, Turbo can (after studying the sequence) ensure that there is some point on the circle that it will never visit or crawl across.
""",
    r"""For an even integer positive integer $n$ Kevin has a tape of length $4 n$ with marks at $-2 n,-2 n+1, \ldots, 2 n-1,2 n$. He then randomly picks $n$ points in the set $-n,-n+1,-n+2, \ldots, n-1, n$, and places a stone on each of these points. We call a stone 'stuck' if it is on $2 n$ or $-2 n$, or either all the points to the right, or all the points to the left, all contain stones. Then, every minute, Kevin shifts the unstuck stones in the following manner: He picks an unstuck stone uniformly at random and then flips a fair coin. If the coin came up heads, he then moves that stone and every stone in the largest contiguous set containing that stone one point to the left. If the coin came up tails, he moves every stone in that set one point right instead. He repeats until all the stones are stuck. Let $p_{k}$ be the probability that at the end of the process there are exactly $k$ stones in the right half. Evaluate $$\frac{p_{n-1}-p_{n-2}+p_{n-3}-\ldots+p_{3}-p_{2}+p_{1}}{p_{n-1}+p_{n-2}+p_{n-3}+\ldots+p_{3}+p_{2}+p_{1}}$$ in terms of $n$.
""",
    r"""A 0-1 sequence of length $2^k$ is given. Alice can pick a member from the sequence, and reveal it (its place and its value) to Bob. Find the largest number $s$ for which Bob can always pick $s$ members of the sequence, and guess all their values correctly.

Alice and Bob can discuss a strategy before the game with the aim of maximizing the number of correct guesses of Bob. The only information Bob has is the length of the sequence and the member of the sequence picked by Alice.
""",
]
209 |
210 |
def formatted_items(problem, filter_type):
    """Build the single-message chat prompt for one problem and one filter.

    Args:
        problem: Text of the question to classify; inserted verbatim into the prompt.
        filter_type: One of "multiple_choice", "proof", "yes_no", "true_false"
            or "multiple_part".

    Returns:
        A one-element list holding a ``{"role": "user", "content": ...}`` message.

    Raises:
        ValueError: If ``filter_type`` is not one of the supported values.
    """
    supported_filters = ("multiple_choice", "proof", "yes_no", "true_false", "multiple_part")
    if filter_type not in supported_filters:
        raise ValueError(f"Invalid type: {filter_type}")

    if filter_type == "multiple_choice":
        content = f"""\
Given this question: {problem}

Is this a multiple choice question (a question that provides specific options to choose from, typically labeled as A, B, C, D or 1, 2, 3, 4)?

Return only "yes" or "no" without any additional explanation.
"""
    elif filter_type == "proof":
        # Few-shot prompt: the examples come from the module-level question lists.
        content = f"""\
Given this question: {problem}

Is this a mathematical proof question (a question that asks to prove a statement, theorem, or property...)?

Examples of proof indicators:
- "Prove that..."
- "Show that..."
- "Demonstrate why..."
- "Justify your answer..."
- "Explain why..."
etc.
Here are examples of proof questions:
Example 1: {proof_questions[0]}
Example 2: {proof_questions[1]}
Example 3: {proof_questions[2]}
Example 4: {proof_questions[3]}
Example 5: {proof_questions[4]}
Here are examples of non-proof questions:
Example 1: {no_proof_questions[0]}
Example 2: {no_proof_questions[1]}
Example 3: {no_proof_questions[2]}
Example 4: {no_proof_questions[3]}
Example 5: {no_proof_questions[4]}
Example 6: {no_proof_questions[5]}

Return only "yes" or "no" without any additional explanation.
"""
    elif filter_type == "yes_no":
        content = f"""\
Given this question: {problem}

Is this a yes/no question (a question that asks to choose between two options, typically labeled as yes or no)?

Return only "yes" or "no" without any additional explanation.
"""
    elif filter_type == "true_false":
        content = f"""\
Given this question: {problem}

Is this a true/false question (a question that asks to choose between two options, typically labeled as true or false)?

Return only "true" or "false" without any additional explanation.
"""
    else:
        # Only "multiple_part" can reach this branch thanks to the guard above.
        content = f"""\
Your task is to determine if the given question contains multiple sub-questions, sub-parts, or sub-tasks.
A multi-part question requires separate answers for different components, rather than a single comprehensive answer.
Besides that, if the question is multiple choice and only requires to select one option, it is not a multi-part question.

Here are examples of multi-part questions that require multiple distinct answers:
Example 1: {multiple_parts[0]}
Example 2: {multiple_parts[1]}
Example 3: {multiple_parts[2]}
Example 4: {multiple_parts[3]}
Example 5: {multiple_parts[4]}
Example 6: {multiple_parts[5]}
Example 7: {multiple_parts[6]}
Example 8: {multiple_parts[7]}
Example 9: {multiple_parts[8]}

Here are examples of single-part questions that require only one answer:
Example 1: {no_multiple_parts[0]}
Example 2: {no_multiple_parts[1]}
Example 3: {no_multiple_parts[2]}
Example 4: {no_multiple_parts[3]}
Example 5: {no_multiple_parts[4]}
Example 6: {no_multiple_parts[5]}
Example 7: {no_multiple_parts[6]}
Please analyze this question: {problem}

Does this question contain multiple parts requiring separate answers?
Return only "yes" or "no" without any additional explanation.
"""

    return [{"role": "user", "content": content}]
299 |
300 |
async def main():
    """Query the model with every enabled filter prompt and store its answers.

    Configuration is read from the module-level ``args`` namespace produced by
    the argument parser. For each enabled filter a column named
    ``f"{args.output_model_name}_{filter_type}"`` is filled on the dataframe.
    Checkpoints are written every ``args.save_interval`` processed
    (batch, filter) requests, and the final dataframe is written to
    ``<save_folder>/<save_name>.parquet`` and pushed to the Hugging Face Hub.
    """
    # Configuration
    model_name = args.model_name
    dataset_name = args.dataset_name
    save_name = args.save_name
    save_folder = args.save_folder
    problem_column_name = args.problem_column_name
    # Fall back to overwriting the input dataset when no output name is given.
    hf_save_dataset_name = args.dataset_name_outputs or dataset_name
    os.makedirs(save_folder, exist_ok=True)

    batch_size = args.batch_size
    save_interval = args.save_interval

    # Sampling hyperparameters
    llm_params = LLMHyperparams(
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        model_name=ModelType(model_name),
        n=args.n,
        max_tokens=args.max_tokens
    )

    print("Hyperparameters:")
    print(llm_params.to_json())

    # Load and preprocess dataset
    ds = load_dataset(dataset_name, split=args.dataset_split)
    df = ds.to_pandas()

    # if only using partial dataset, slice it
    if args.end == -1:
        args.end = len(df)
    df = df.iloc[args.start:args.end]
    df = df.reset_index(drop=True)

    print(f"Total dataset: {len(df)}")

    # prepare data for generation: one prompt list per enabled filter, built
    # from a single read of the problem column instead of five iterrows() passes
    problems = df[problem_column_name].tolist()
    filter_types = ("multiple_choice", "proof", "yes_no", "true_false", "multiple_part")
    items = {
        filter_type: [formatted_items(problem, filter_type) for problem in problems]
        for filter_type in filter_types
        if getattr(args, filter_type)
    }

    # Create the output columns up front with object dtype so that a cell can
    # hold a list of responses (n > 1); writing a list through `.at` into a
    # column that does not exist yet is not reliable.
    for filter_type in items:
        output_column = f"{args.output_model_name}_{filter_type}"
        if output_column not in df.columns:
            df[output_column] = None

    # Process items in batches
    total_items = len(df)
    count = 0
    with SGLangServerManager(model_name, tp=args.tp) as server_handler:
        for idx in tqdm(range(0, total_items, batch_size)):
            print(f"Processing indices {idx}:{idx+batch_size}...")
            for filter_type, prompts in items.items():
                output_column = f"{args.output_model_name}_{filter_type}"
                batch_items = prompts[idx:idx+batch_size]
                batch_outputs = await server_handler.get_chat_responses(
                    batch_items,
                    n=llm_params.n,
                    top_p=llm_params.top_p,
                    temperature=llm_params.temperature,
                    max_tokens=llm_params.max_tokens
                )
                batch_responses = []
                for resp in batch_outputs:
                    try:
                        batch_responses.append(resp[-1]["responses"])
                    except Exception:
                        # Best effort: log the malformed response and keep a
                        # placeholder so row alignment is preserved.
                        print(f"Response: {resp}")
                        traceback_str = traceback.format_exc()
                        print(f"Error processing response: {traceback_str}")
                        batch_responses.append([""])

                count += 1
                # Assign responses to dataframe
                for i, response_list in enumerate(batch_responses):
                    if isinstance(response_list, list) and len(response_list) > args.n:
                        print(f"Response List Error: Length > n: {response_list}")
                        response_list = response_list[:args.n]
                    # if processing a single output (this is the usual case), extract the response from the list;
                    # an empty response is stored as "" instead of raising IndexError
                    if args.n == 1:
                        response_list = response_list[0] if response_list else ""
                    df.at[idx+i, output_column] = response_list

                # Save checkpoint
                if count % save_interval == 0:
                    try:
                        df.iloc[:idx+batch_size].to_parquet(
                            os.path.join(save_folder, f"{save_name}_{count}_batch.parquet")
                        )
                        ds = Dataset.from_pandas(df)
                        ds.push_to_hub(hf_save_dataset_name, private=True)
                    except Exception as e:
                        print(f"Error saving checkpoint: {e}")
    # Save final results
    try:
        final_path = os.path.join(save_folder, f"{save_name}.parquet")
        df.to_parquet(final_path)
        # Report the local file as soon as it exists, even if the upload below fails.
        print(f"Saved to {final_path}")
        ds = Dataset.from_pandas(df)
        ds.push_to_hub(hf_save_dataset_name, private=True)
    except Exception as e:
        print(f"Error saving final results: {e}")
410 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Model and dataset configuration
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument(
        "--output_model_name",
        type=str,
        required=True,
        help="Name to be prepended to new column names.",
    )
    parser.add_argument("--dataset_name", type=str, required=True)
    parser.add_argument(
        "--dataset_name_outputs",
        type=str,
        help="To save the outputs to a different HF dataset, specify here.",
    )
    parser.add_argument("--dataset_split", type=str, default="train")
    parser.add_argument("--problem_column_name", type=str, default="problem")

    # Filters configuration: one boolean switch per supported filter type
    for filter_flag in ("multiple_choice", "proof", "yes_no", "true_false", "multiple_part"):
        parser.add_argument(f"--{filter_flag}", action="store_true")

    # Save configuration
    parser.add_argument("--save_folder", type=str, required=True)
    parser.add_argument("--save_name", type=str, required=True)
    parser.add_argument(
        "--save_interval",
        type=int,
        default=10000,
        help="Save every n batches.",
    )

    # SGLang server configuration
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument(
        "--batch_size",
        type=int,
        default=10000,
        help="Total batch size will be args.batch_size * args.n.",
    )

    # LLM Hyperparameters
    parser.add_argument("--top_k", type=int, default=-1)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max_tokens", type=int, default=5)
    parser.add_argument("--n", type=int, default=1)

    # Dataset slicing: rows [start, end) are processed; end == -1 means "to the end"
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=-1)

    # `args` stays a module-level name because main() reads it as a global.
    args = parser.parse_args()

    asyncio.run(main())

# Example Usage:
# python3 model_based_signals.py --model_name "meta-llama/Meta-Llama-3.1-70B-Instruct" --output_model_name "llama3-70b" --dataset_name "RLAIF/INTERNAL-ONLY-Big-Math-RL-Verified-MC-Rewrites" --dataset_name_outputs "RLAIF/model_filter_testing" --save_folder "outputs" --save_name "model_based_filters" --multiple_choice --proof --yes_no --true_false --multiple_part --tp 8 > model_based_filter_outputs.txt 2>&1
--------------------------------------------------------------------------------
/signals/requirements.txt:
--------------------------------------------------------------------------------
1 | asynciolimiter==1.0.0
2 | datasets==3.0.1
3 | faiss_cpu==1.9.0.post1
4 | fasttext==0.9.3
5 | huggingface_hub==0.26.1
6 | loguru==0.7.2
7 | nltk==3.9.1
8 | openai==1.63.2
9 | pandas==2.2.3
10 | pydantic==2.10.6
11 | sympy==1.13.3
12 | tenacity==9.0.0
13 | torch==2.4.0
14 | tqdm==4.66.5
15 | transformers==4.45.2
16 |
--------------------------------------------------------------------------------
/signals/rollouts_based_signals/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SynthLabsAI/big-math/420b9a771a7e97a85b81cbdcbd573b1b0d56f522/signals/rollouts_based_signals/__init__.py
--------------------------------------------------------------------------------
/signals/rollouts_based_signals/evaluate_responses.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | from datasets import load_dataset
4 | from functools import partial
5 | from tqdm.asyncio import tqdm_asyncio
6 | from typing import Union
7 |
8 | from math_eval import MathEvaluator, is_correct_no_judge, get_answer_expr
9 |
async def rate_limit_is_correct(evaluator: MathEvaluator, answer: str, pred: str, sem: asyncio.Semaphore):
    """
    Asynchronously checks if the predicted answer is correct, with rate limiting.

    Args:
        evaluator (MathEvaluator): Object whose awaitable ``is_correct(answer, pred)`` is called.
        answer (str): The ground truth answer.
        pred (str): The predicted answer.
        sem (asyncio.Semaphore): Semaphore bounding the number of concurrent
            evaluations. ``None`` disables rate limiting (``async with None``
            would otherwise fail before the evaluator is even called).

    Returns:
        Whatever ``evaluator.is_correct(answer, pred)`` returns.
    """
    if sem is None:
        return await evaluator.is_correct(answer, pred)
    async with sem:
        return await evaluator.is_correct(answer, pred)
16 |
async def evaluate_preds_async(preds: Union[list[list],list], ground_truth_answers: list, evaluator: MathEvaluator, sem: Union[asyncio.Semaphore, None] = None):
    """
    Asynchronously evaluates predictions against ground truth answers using a provided evaluator.

    Args:
        preds (Union[list[list], list]): A list of predictions or a list of lists of predictions.
            The layout is decided from the first element only.
        ground_truth_answers (list): A list of ground truth answers corresponding to the predictions.
        evaluator (MathEvaluator): An instance of MathEvaluator used to evaluate the predictions.
        sem (asyncio.Semaphore, optional): Semaphore forwarded to ``rate_limit_is_correct`` to limit
            the number of concurrent evaluations. Defaults to None.

    Returns:
        list: A flat list of evaluation results. For nested predictions the results of each
        row follow one another in row order. An empty ``preds`` yields an empty list.

    """
    print("Evaluating predictions...")
    # Nothing to evaluate; also avoids an IndexError on preds[0] below.
    if len(preds) == 0:
        return []
    if isinstance(preds[0], list):
        tasks = [
            asyncio.create_task(rate_limit_is_correct(evaluator, answer, p, sem))
            for pred, answer in zip(preds, ground_truth_answers)
            for p in pred
        ]
    else:
        tasks = [
            asyncio.create_task(rate_limit_is_correct(evaluator, answer, pred, sem))
            for pred, answer in zip(preds, ground_truth_answers)
        ]
    return await tqdm_asyncio.gather(*tasks)
40 |
def extract_answers(row, evaluator: MathEvaluator, response_column_name: str, predicted_answer_column_name: str):
    """Store the answer(s) extracted from a row's response(s) in a new column.

    A list of responses produces a list of extracted answers; any other value
    is treated as a single response. The row is modified in place and returned.
    """
    responses = row[response_column_name]
    if isinstance(responses, list):
        extracted = [evaluator.get_answer_expr(response) for response in responses]
    else:
        extracted = evaluator.get_answer_expr(responses)
    row[predicted_answer_column_name] = extracted
    return row
48 |
def extract_and_evaluate_answers(
        row,
        response_column_name: str,
        predicted_answer_column_name: str,
        predicted_answer_correctness_column_name: str,
        ground_truth_answer_column_name: str,
    ):
    """Extract the answer(s) from a row's response(s) and grade them without an LLM judge.

    Writes the extracted answer(s) and their correctness to the two given
    output columns (lists when the response column holds a list) and returns
    the modified row.
    """
    responses = row[response_column_name]
    many_responses = isinstance(responses, list)

    if many_responses:
        row[predicted_answer_column_name] = [get_answer_expr(response) for response in responses]
    else:
        row[predicted_answer_column_name] = get_answer_expr(responses)

    # Read both values back from the row only after the extracted answers are stored.
    ground_truth = row[ground_truth_answer_column_name]
    predicted = row[predicted_answer_column_name]
    if many_responses:
        row[predicted_answer_correctness_column_name] = [
            is_correct_no_judge(ground_truth, candidate) for candidate in predicted
        ]
    else:
        row[predicted_answer_correctness_column_name] = is_correct_no_judge(ground_truth, predicted)
    return row
64 |
async def async_main(args):
    """Extract answers and grade them with ``MathEvaluator``, then push the dataset.

    Used when the LLM judge backup is enabled. Answer extraction is skipped
    when either derived column already exists, and grading is skipped when the
    correctness column already exists. The dataset is pushed back to
    ``args.predictions_dataset_name`` in every case.

    Args:
        args: Parsed command-line namespace; reads ``response_column_name``,
            ``ground_truth_answer_column_name``, ``predictions_dataset_name``,
            ``dataset_split``, ``calculate_greedy_accuracy`` and ``semaphore_limit``.
    """
    # define the column names

    response_column_name = args.response_column_name
    ground_truth_answer_column_name = args.ground_truth_answer_column_name

    # define the answer and correctness columns
    predicted_answer_column_name = f"{args.response_column_name}_extracted_answers"
    predicted_answer_correctness_column_name = f"{args.response_column_name}_correctness"

    # load the dataset
    preds_dataset = load_dataset(args.predictions_dataset_name, split=args.dataset_split)

    # Make sure we can fulfill the evaluation requirements
    if args.calculate_greedy_accuracy:
        assert "greedy_correct" in preds_dataset.column_names, "Predictions dataset does not contain greedy_correct column"

    evaluator = MathEvaluator()

    # Extract the answers from responses
    if predicted_answer_column_name in preds_dataset.column_names or predicted_answer_correctness_column_name in preds_dataset.column_names:
        # if answers are already extracted, or correctness is already calculated, skip extraction
        print("Answers already extracted. Skipping extraction")
    else:
        # extract the answers from the response column and add them as a separate column
        print(f"Extracting answers to {predicted_answer_column_name}...")
        partial_extract_answers = partial(
            extract_answers,
            evaluator=evaluator,
            response_column_name=response_column_name,
            predicted_answer_column_name=predicted_answer_column_name
        )
        # TODO: Rewrite the math_evaluator.is_correct() function so that it can live outside of the MathEvaluator object
        # otherwise, this can only run with num_proc=1
        # HuggingFace datasets cannot pickle the MathEvaluator object
        # returns error: "TypeError: cannot pickle 'SSLContext' object"
        preds_dataset = preds_dataset.map(partial_extract_answers)

    if predicted_answer_correctness_column_name in preds_dataset.column_names:
        # if correctness is already calculated, skip evaluation
        print("Correctness already calculated. Skipping evaluation")
    else:
        # calculate the correctness of predicted answers
        print(f"Evaluating correctness of {predicted_answer_column_name}...")
        sem = asyncio.Semaphore(args.semaphore_limit)
        # Read the column once. Indexing the dataset by column name inside the
        # reshape loop re-read the whole column on every iteration.
        predicted_answers = preds_dataset[predicted_answer_column_name]
        answer_correctness = await evaluate_preds_async(
            preds=predicted_answers,
            ground_truth_answers=preds_dataset[ground_truth_answer_column_name],
            evaluator=evaluator,
            sem=sem
        )

        # reshape the flat correctness list to match the shape of the predictions
        if len(predicted_answers) > 0 and isinstance(predicted_answers[0], list):
            reshaped_answer_correctness = []
            cur_start = 0
            for row_answers in predicted_answers:
                cur_len = len(row_answers)
                reshaped_answer_correctness.append(answer_correctness[cur_start:cur_start+cur_len])
                cur_start += cur_len
            answer_correctness = reshaped_answer_correctness

        # add the correctness column to the dataset
        preds_dataset = preds_dataset.add_column(predicted_answer_correctness_column_name, answer_correctness)

    # push the dataset back to the hub
    preds_dataset.push_to_hub(args.predictions_dataset_name, private=True)
132 |
def main(args):
    """Extract answers and grade them without an LLM judge, then push the dataset.

    Both steps run in a single ``Dataset.map`` pass using ``args.num_proc``
    worker processes, and the result is pushed back to
    ``args.predictions_dataset_name``.

    Args:
        args: Parsed command-line namespace; reads ``response_column_name``,
            ``ground_truth_answer_column_name``, ``predictions_dataset_name``,
            ``dataset_split``, ``calculate_greedy_accuracy`` and ``num_proc``.
    """
    # define the column names

    response_column_name = args.response_column_name
    ground_truth_answer_column_name = args.ground_truth_answer_column_name

    # define the answer and correctness columns
    predicted_answer_column_name = f"{args.response_column_name}_extracted_answers"
    predicted_answer_correctness_column_name = f"{args.response_column_name}_correctness"

    # load the dataset
    preds_dataset = load_dataset(args.predictions_dataset_name, split=args.dataset_split)

    # Make sure we can fulfill the evaluation requirements
    if args.calculate_greedy_accuracy:
        assert "greedy_correct" in preds_dataset.column_names, "Predictions dataset does not contain greedy_correct column"

    # Extract answers and evaluate correctness in a single step
    print(f"Extracting and evaluating correctness of {predicted_answer_column_name}...")
    partial_extract_and_evaluate_answers = partial(
        extract_and_evaluate_answers,
        response_column_name=response_column_name,
        predicted_answer_column_name=predicted_answer_column_name,
        predicted_answer_correctness_column_name=predicted_answer_correctness_column_name,
        ground_truth_answer_column_name=ground_truth_answer_column_name
    )

    # Honor the --num_proc flag; the worker count used to be hard-coded to 64,
    # which ignored the flag and oversubscribed smaller machines.
    preds_dataset = preds_dataset.map(partial_extract_and_evaluate_answers, num_proc=args.num_proc)

    # push the dataset back to the hub
    preds_dataset.push_to_hub(args.predictions_dataset_name, private=True)
164 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate rollouts")

    # Model and dataset arguments
    parser.add_argument(
        "--predictions_dataset_name",
        type=str,
        required=True,
        help="The name of the dataset containing the predictions.",
    )
    parser.add_argument(
        "--response_column_name",
        type=str,
        required=True,
        help="The name of the column containing the model's responses.",
    )
    parser.add_argument(
        "--ground_truth_answer_column_name",
        type=str,
        required=True,
        help="The name of the column containing the ground truth answers.",
    )
    parser.add_argument("--dataset_split", type=str, default="train")
    parser.add_argument("--calculate_greedy_accuracy", action="store_true")

    # Evaluation config
    parser.add_argument(
        "--use_llm_judge_backup",
        action="store_true",
        help="Whether to use the LLM judge as a backup for evaluation.",
    )

    # System configuration
    parser.add_argument(
        "--num_proc",
        type=int,
        default=1,
        help="Number of processes to use for data processing.",
    )
    parser.add_argument(
        "--semaphore_limit",
        type=int,
        default=20,
        help="The maximum number of concurrent requests to the evaluator, when using LLM judge as backup.",
    )

    args = parser.parse_args()

    # The judge-backed path is asynchronous; the rule-based path is a plain call.
    if not args.use_llm_judge_backup:
        main(args)
    else:
        asyncio.run(async_main(args))

# Example usage:
# No LLM-judge
# python3 evaluate_responses.py --predictions_dataset_name RLAIF/Big-Math-needs-llama3-8b-rollouts --response_column_name responses --ground_truth_answer_column_name final_answer
# With LLM-judge
# python3 evaluate_responses.py --predictions_dataset_name RLAIF/Big-Math-needs-llama3-8b-rollouts --response_column_name responses --ground_truth_answer_column_name final_answer --use_llm_judge_backup
--------------------------------------------------------------------------------
/signals/rollouts_based_signals/example_solve_rate_script.sh:
--------------------------------------------------------------------------------
# Example pipeline: sample rollouts from a model, then grade the responses.

# stop at the first failing step so evaluation never runs on missing rollouts
set -e

# define some variables

MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"

# path to the dataset to generate rollouts on
EVAL_DATASET_PATH="SynthLabsAI/Big-Math-RL-Verified"

# path to save the updated dataset to: change this to a valid path
SAVE_DATASET_PATH=""

# local path to save intermediate results to
INTERMEDIATE_STEPS_SAVE_DIR="llama8b_solve_rate"

# column to save the models responses to
RESPONSE_COLUMN_NAME="llama8b_response"

# set the number of processes to use
HF_NUM_PROC=16

# make sure the user changed the SAVE_DATASET_PATH
if [ -z "${SAVE_DATASET_PATH}" ]; then
    echo "Please change the SAVE_DATASET_PATH variable to a valid path"
    exit 1
fi

# first, we need to run the inference script
echo "Running inference on ${MODEL} on ${EVAL_DATASET_PATH} and saving to ${SAVE_DATASET_PATH}"
mkdir -p "${INTERMEDIATE_STEPS_SAVE_DIR}"
python3 sample_from_model.py --model_name "${MODEL}" --dataset_name "${EVAL_DATASET_PATH}" --dataset_name_outputs "${SAVE_DATASET_PATH}" --response_column_name "${RESPONSE_COLUMN_NAME}" --save_folder "${INTERMEDIATE_STEPS_SAVE_DIR}" --save_name llama8b_inference --max_tokens 2048 --n 64

# then, we evaluate the response correctness
# (evaluate_responses.py defines no --save_folder/--save_name options, so they are not passed)
echo "Evaluating responses"
python3 evaluate_responses.py --predictions_dataset_name "${SAVE_DATASET_PATH}" --response_column_name "${RESPONSE_COLUMN_NAME}" --ground_truth_answer_column_name final_answer --num_proc "${HF_NUM_PROC}"

echo "Done!"
--------------------------------------------------------------------------------
/signals/rollouts_based_signals/math_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | # math_eval.py
3 | """
4 |
5 | # TODO: replace is_equiv and sympy_match with math-verify (https://github.com/huggingface/Math-Verify)
6 |
7 |
8 | import re
9 | import signal
10 | import traceback
11 | from typing import AsyncIterator, Dict, List, Optional
12 |
13 | import sympy
14 | from asynciolimiter import StrictLimiter
15 | from loguru import logger
16 | from openai import AsyncOpenAI
17 | from sympy.parsing.latex import parse_latex
18 | from tenacity import AsyncRetrying, stop_after_attempt, wait_exponential
19 |
20 |
21 | def last_boxed_only_string(string: str) -> Optional[str]:
22 | """
23 | Extracts the last boxed expression from a string.
24 |
25 | Args:
26 | string (str): The input string.
27 |
28 | Returns:
29 | Optional[str]: The last boxed expression or None.
30 | """
31 | idx = string.rfind("\\boxed")
32 | if "\\boxed " in string:
33 | return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
34 | if idx < 0:
35 | idx = string.rfind("\\fbox")
36 | if idx < 0:
37 | return None
38 |
39 | i = idx
40 | right_brace_idx = None
41 | num_left_braces_open = 0
42 | while i < len(string):
43 | if string[i] == "{":
44 | num_left_braces_open += 1
45 | if string[i] == "}":
46 | num_left_braces_open -= 1
47 | if num_left_braces_open == 0:
48 | right_brace_idx = i
49 | break
50 | i += 1
51 |
52 | return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None
53 |
def remove_boxed(s: str) -> Optional[str]:
    r"""
    Strips the \boxed or \fbox wrapper from a string.

    The wrapper forms are tried in the order "\boxed ", "\boxed{", "\fbox{".
    A string containing none of them is returned unchanged.

    Args:
        s (str): The input string.

    Returns:
        Optional[str]: The content inside the wrapper, or the input itself.

    Raises:
        AssertionError: If a wrapper is present but the string does not start
            with it, or a braced wrapper does not end with "}".
    """
    # (prefix, whether a closing brace has to be stripped as well)
    wrappers = (
        ("\\boxed ", False),
        ("\\boxed{", True),
        ("\\fbox{", True),
    )
    for prefix, braced in wrappers:
        if prefix not in s:
            continue
        assert s[: len(prefix)] == prefix
        if not braced:
            return s[len(prefix) :]
        assert s[-1] == "}"
        return s[len(prefix) : -1]
    return s
83 |
def get_answer_expr(answer: str) -> str:
    """
    Pulls the final answer expression out of a response.

    The content of the last boxed expression is preferred. If extracting it
    fails for any reason (for example, there is no boxed expression), the last
    line of the response is used instead.

    Args:
        answer (str): The answer string.

    Returns:
        str: Extracted expression.
    """
    try:
        boxed = last_boxed_only_string(answer)
        return remove_boxed(boxed)
    except Exception:
        # Fallback: treat the final line of the response as the answer.
        return answer.split("\n")[-1]
99 |
def remove_right_units(string: str) -> str:
    r"""
    Drops a trailing unit written as "\text{ ...}".

    Only strings containing the marker "\text{ " exactly once are shortened;
    everything from the marker onwards is discarded.

    Args:
        string (str): The input string.

    Returns:
        str: String without units.
    """
    pieces = string.split("\\text{ ")
    # Exactly two pieces means the marker occurred exactly once.
    return pieces[0] if len(pieces) == 2 else string
115 |
def fix_sqrt(string: str) -> str:
    r"""
    Wraps a bare single-character square-root argument in braces.

    "\sqrt2" becomes "\sqrt{2}"; arguments that already start with "{" are
    kept. If a "\sqrt" has nothing after it, the input is returned unchanged.

    Args:
        string (str): The input string.

    Returns:
        str: String with fixed square roots.
    """
    head, *tails = string.split("\\sqrt")
    if not tails:
        return string

    rebuilt = [head]
    for tail in tails:
        if tail.startswith("{"):
            rebuilt.append("\\sqrt" + tail)
        elif not tail:
            # Dangling "\sqrt" with no argument: give up on the whole string.
            return string
        else:
            rebuilt.append(f"\\sqrt{{{tail[0]}}}" + tail[1:])
    return "".join(rebuilt)
140 |
def fix_fracs(string: str) -> str:
    r"""
    Fixes improperly formatted fractions in a LaTeX string.

    Adds the missing braces to shorthand fractions: "\frac12" becomes
    "\frac{1}{2}" and "\frac1{72}" becomes "\frac{1}{72}". Fractions whose
    numerator is already braced are kept as they are. If a "\frac" is followed
    by fewer than two characters, the input is returned unchanged.

    Args:
        string (str): The input string.

    Returns:
        str: String with fixed fractions.
    """
    pieces = string.split("\\frac")
    fixed = [pieces[0]]
    for piece in pieces[1:]:
        fixed.append("\\frac")
        if piece.startswith("{"):
            fixed.append(piece)
            continue
        if len(piece) < 2:
            return string
        numerator, second, rest = piece[0], piece[1], piece[2:]
        if second == "{":
            # The denominator is already braced ("\frac1{72}"): brace only the
            # numerator. Bracing the "{" as well produced "\frac{1}{{}72}".
            fixed.append(f"{{{numerator}}}{second}{rest}")
        else:
            fixed.append(f"{{{numerator}}}{{{second}}}{rest}")
    return "".join(fixed)
177 |
def fix_a_slash_b(string: str) -> str:
    r"""
    Rewrites a plain integer ratio "a/b" as a LaTeX fraction.

    The conversion only happens when the string is exactly two integers in
    canonical form separated by one slash; anything else is returned as is.

    Args:
        string (str): The input string.

    Returns:
        str: Modified string with fractions fixed.
    """
    numerator_text, slash, denominator_text = string.partition("/")
    # Require exactly one slash.
    if not slash or "/" in denominator_text:
        return string
    try:
        numerator = int(numerator_text)
        denominator = int(denominator_text)
    except ValueError:
        return string
    # Reject spellings int() accepts but that are not canonical (" 1/2", "01/2").
    if string != f"{numerator}/{denominator}":
        return string
    return f"\\frac{{{numerator}}}{{{denominator}}}"
201 |
def strip_string(string: str) -> str:
    """
    Canonicalizes a LaTeX answer string so equivalent answers compare equal.

    Applies, in order: literal clean-ups, trailing unit removal, percent and
    leading-decimal fixes, dropping a short "x=" style left-hand side,
    square-root and fraction brace fixes, and whitespace removal.

    Args:
        string (str): The input string.

    Returns:
        str: Normalized string.
    """
    # Literal substitutions applied before units are removed (order matters).
    for old, new in (
        ("\n", ""),
        ("\\!", ""),
        ("\\\\", "\\"),
        ("tfrac", "frac"),
        ("dfrac", "frac"),
        ("\\left", ""),
        ("\\right", ""),
        ("^{\\circ}", ""),
        ("^\\circ", ""),
        ("\\$", ""),
    ):
        string = string.replace(old, new)

    string = remove_right_units(string)

    # Literal substitutions applied after units are removed (order matters).
    for old, new in (
        ("\\%", ""),
        (r"\%", ""),
        (" .", " 0."),
        ("{.", "{0."),
    ):
        string = string.replace(old, new)

    if not string:
        return string
    if string.startswith("."):
        string = "0" + string

    # Drop a short left-hand side such as "x=" or "k =".
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    string = fix_sqrt(string)
    string = string.replace(" ", "")
    string = fix_fracs(string)

    if string == "0.5":
        string = "\\frac{1}{2}"

    return fix_a_slash_b(string)
247 |
def is_equiv(str1: Optional[str], str2: Optional[str], verbose: bool = False) -> bool:
    """
    Compares two answer strings after normalizing both with ``strip_string``.

    Args:
        str1 (Optional[str]): First string.
        str2 (Optional[str]): Second string.
        verbose (bool): If True, prints the two normalized strings.

    Returns:
        bool: True when the normalized strings are identical. Two ``None``
        inputs count as equivalent (a warning is printed); exactly one
        ``None`` never is.
    """
    if str1 is None or str2 is None:
        both_missing = str1 is None and str2 is None
        if both_missing:
            print("WARNING: Both None")
        return both_missing

    try:
        left, right = strip_string(str1), strip_string(str2)
        if verbose:
            print(left, right)
        return left == right
    except Exception:
        # Normalization is best-effort; fall back to comparing the raw inputs.
        return str1 == str2
274 |
def sympy_match(str1: str, str2: str) -> bool:
    """
    Checks if two mathematical expressions are equivalent using SymPy.

    Both strings are parsed as LaTeX and considered equivalent when their
    difference simplifies to zero. The work is guarded by a 3 second
    ``SIGALRM`` timer to prevent hanging on complex expressions.

    The timer is always cancelled and the previously installed ``SIGALRM``
    handler is put back before returning, so calling this function does not
    leave the process with a modified signal disposition.

    Args:
        str1 (str): First expression.
        str2 (str): Second expression.

    Returns:
        bool: True if equivalent. False if not equivalent, if parsing or
        simplification fails or times out, or if the alarm cannot be armed
        (no ``SIGALRM`` on this platform, or not running in the main thread).
    """
    def timeout_handler(signum, frame):
        raise TimeoutError("Parsing took too long")

    try:
        previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    except (AttributeError, ValueError):
        # AttributeError: platform without SIGALRM.
        # ValueError: signal handlers can only be installed from the main thread.
        return False

    try:
        # Set timeout of 3 seconds
        signal.alarm(3)
        expr1 = parse_latex(str1)
        expr2 = parse_latex(str2)
        diff = sympy.simplify(expr1 - expr2)
        return diff == 0
    except Exception:
        # Covers parse errors as well as the TimeoutError raised by the handler.
        return False
    finally:
        # Disable alarm first so the handler cannot fire while being swapped out.
        signal.alarm(0)
        # signal.signal() reports None for handlers not installed from Python;
        # those cannot be re-installed, so leave the disposition alone then.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
307 |
def is_correct_no_judge(correct_answer: str, proposed_answer: str) -> bool:
    """
    Checks if the provided answer is correct without consulting an LLM judge.

    The final expression is pulled out of ``proposed_answer`` and compared to
    ``correct_answer`` first by normalized string equality and then by SymPy
    equivalence. This is the first two stages of the process described at
    https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/eval_details.md#math

    Args:
        correct_answer (str): The reference answer.
        proposed_answer (str): The full model response.

    Returns:
        bool: True if either check accepts the extracted answer.
    """
    candidate = get_answer_expr(proposed_answer)

    # Nothing extracted (or whitespace only) can never be correct.
    if not candidate.strip():
        return False

    return bool(
        is_equiv(candidate, correct_answer)
        or sympy_match(candidate, correct_answer)
    )
325 |
class MathEvaluator:
    """
    Judges whether a proposed answer to a math problem matches the reference.

    Answers are compared in three stages: normalized string equality, SymPy
    equivalence, and finally an OpenAI chat model acting as judge.

    either pass in an openai api key, or set the OPENAI_API_KEY environment variable.
    export OPENAI_API_KEY="sk-proj-..."

    usage:
    ```python
    evaluator = MathEvaluator()
    result = await evaluator.is_correct(correct_answer="4", proposed_answer="4")
    assert result
    ```
    """

    def __init__(
        self,
        model_name: str = "gpt-4o-mini",
        rate_limit: float = 10000 / 60,
        api_key: Optional[str] = None,
    ):
        """
        Initializes the MathEvaluator with its OpenAI client and rate limiter.

        Args:
            model_name (str): OpenAI model name to use for judging equality.
            rate_limit (float): Rate limit for asynchronous operations, passed to StrictLimiter.
            api_key (Optional[str]): OpenAI API key. If None, it will use the OPENAI_API_KEY environment variable.
        """
        self.model_name = model_name
        self.rate_limiter = StrictLimiter(rate_limit)
        self.openai_client = AsyncOpenAI(api_key=api_key)

    async def judge_equality(self, expr1: str, expr2: str) -> bool:
        """
        Determines if two mathematical expressions are equivalent using the OpenAI client.

        Args:
            expr1 (str): Generated answer.
            expr2 (str): True answer.

        Returns:
            bool: True if the model answers exactly "yes" (case-insensitive).
            False otherwise, including when the request fails; failures are
            logged and never raised.
        """
        EQUALITY_TEMPLATE = """
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

Expression 1: $2x+3$
Expression 2: $3+2x$

Yes

Expression 1: 3/2
Expression 2: 1.5

Yes

Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$

No

Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$

Yes

Expression 1: 3245/5
Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

Expression 1: 2/(-3)
Expression 2: -2/3

Yes
(trivial simplifications are allowed)

Expression 1: 72 degrees
Expression 2: 72

Yes
(give benefit of the doubt to units)

Expression 1: 64
Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

Expression 1: 2^{{n-1}} - n
Expression 2: 4 - \\frac{{n+2}}{{2^{{n-1}}}}

No

Expression 1: 8
Expression 2: Therefore, there are 8 cats.

Yes
(simple conclusion sentences giving an answer are allowed)

Expression 1: 3n^2
Expression 2: a_n = 3n^2

Yes
(variable names are allowed)

Expression 1: a=3, b=4, e=7
Expression 2: {{a=3, b=4, e=7}}

Yes

Expression 1: 453.6235
Expression 2: 454.0231

No
(approximately equal is not equivalent)

Expression 1: 1/3
Expression 2: So we have that $A$

No
(anything that appears cut off or nonsensical is not equivalent)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

Expression 1: {expr1}
Expression 2: {expr2}
""".strip()
        prompt = EQUALITY_TEMPLATE.format(expr1=expr1, expr2=expr2)

        await self.rate_limiter.wait()
        try:
            response = await self.openai_client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=20,
                n=1,
                temperature=0.0,
            )
            result = response.choices[0].message.content.strip()
            return result.lower().strip() == "yes"
        except Exception as e:
            logger.error(f"error in judge_equality: {e}")
            traceback.print_exc()
            return False

    async def is_correct(self, correct_answer: str, proposed_answer: str) -> bool:
        """
        checks if the provided answer is correct.

        The final expression is extracted from ``proposed_answer`` and run
        through string normalization, SymPy and the LLM judge in that order,
        stopping at the first stage that accepts it. The whole check is
        retried up to 3 times with exponential backoff if it raises.
        """
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(3),
            wait=wait_exponential(multiplier=1, min=4, max=120),
            reraise=True,
        ):
            with attempt:
                extracted_answer = self.get_answer_expr(proposed_answer)

                # If the extracted answer is empty, it's not correct.
                if extracted_answer.strip() == "":
                    return False

                # Three-stage process based off of https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/eval_details.md#math
                if self.is_equiv(extracted_answer, correct_answer):
                    return True
                elif self.sympy_match(extracted_answer, correct_answer):
                    return True
                else:
                    return await self.judge_equality(extracted_answer, correct_answer)

    async def is_correct_anywhere(self, correct_answer: str, proposed_answer: str) -> bool:
        """
        checks if the correct answer appears anywhere in the proposed answer.

        First applies ``is_correct`` to the final answer; if that fails, every
        boxed expression found by ``extract_boxed_expressions`` is checked
        with the same three stages.
        """
        if await self.is_correct(correct_answer, proposed_answer):
            return True

        boxed_expressions = self.extract_boxed_expressions(proposed_answer)

        for expr in boxed_expressions:
            extracted_answer = self.remove_boxed(expr)
            if self.is_equiv(extracted_answer, correct_answer):
                return True
            elif self.sympy_match(extracted_answer, correct_answer):
                return True
            elif await self.judge_equality(extracted_answer, correct_answer):
                return True

        return False

    async def __call__(self, split: str) -> AsyncIterator[Dict]:
        """
        Allows the MathEvaluator to be called as an async generator.

        The datasets are read from the ``ds_train`` / ``ds_test`` attributes.
        ``__init__`` does not create them, so they must be assigned on the
        instance before iterating.

        Args:
            split (str): The dataset split to use ('train' or 'test').

        Yields:
            Dict: The next item in the dataset.

        Raises:
            ValueError: If an invalid split is provided.
            AttributeError: If the dataset for the split has not been set.
        """
        if split == "train":
            attribute = "ds_train"
        elif split == "test":
            attribute = "ds_test"
        else:
            raise ValueError("split must be 'train' or 'test'")

        dataset = getattr(self, attribute, None)
        if dataset is None:
            # Same exception type as a plain attribute access, but with an
            # actionable message.
            raise AttributeError(
                f"MathEvaluator.{attribute} is not set; assign a dataset to it "
                f"before iterating the '{split}' split"
            )

        for item in dataset:
            yield item

    @staticmethod
    def has_formatted_answer(answer: str) -> bool:
        """
        Checks if the answer contains a formatted solution.

        Args:
            answer (str): The answer string.

        Returns:
            bool: True if a non-empty boxed answer can be extracted, False otherwise.
        """
        try:
            if MathEvaluator.remove_boxed(MathEvaluator.last_boxed_only_string(answer)):
                return True
            return False
        except Exception:
            return False

    @staticmethod
    def get_answer_expr(answer: str) -> str:
        """
        Extracts the mathematical expression from the answer.

        Args:
            answer (str): The answer string.

        Returns:
            str: Contents of the last boxed expression, or the last line of
            the answer when no well-formed boxed expression is found.
        """
        try:
            answer = MathEvaluator.remove_boxed(MathEvaluator.last_boxed_only_string(answer))
        except Exception:
            # No boxed expression (or a malformed one): use the final line.
            answer = answer.split("\n")[-1]
        return answer

    @staticmethod
    def extract_boxed_expressions(string: str) -> List[str]:
        """
        extracts all \\boxed{...} and \\boxed ... expressions from the string.

        Every match is returned re-wrapped as ``\\boxed{...}``. The brace
        pattern stops at the first closing brace, so nested braces inside a
        box are truncated; ``all_boxed_strings`` is the brace-aware variant.
        """
        boxed_expressions = []

        pattern_braces = r"\\boxed\s*\{([^}]*)\}"
        boxed_expressions += re.findall(pattern_braces, string)

        pattern_space = r"\\boxed\s+([^\s\$]+)"
        boxed_expressions += re.findall(pattern_space, string)

        return ["\\boxed{" + expr + "}" for expr in boxed_expressions]

    @staticmethod
    def remove_boxed(s: str) -> Optional[str]:
        """
        Removes the \\boxed or \\fbox formatting from a string.

        Args:
            s (str): The input string; when it contains a box marker it must
                start with that marker (and end with "}" for the brace forms).

        Returns:
            Optional[str]: String without boxed formatting; the input itself
            when it contains no box marker.

        Raises:
            AssertionError: If a box marker is present but the string is not
                exactly one boxed expression.
        """
        if "\\boxed " in s:
            left = "\\boxed "
            assert s[: len(left)] == left
            return s[len(left) :]
        elif "\\boxed{" in s:
            left = "\\boxed{"
            assert s[: len(left)] == left
            assert s[-1] == "}"
            return s[len(left) : -1]
        elif "\\fbox{" in s:
            left = "\\fbox{"
            assert s[: len(left)] == left
            assert s[-1] == "}"
            return s[len(left) : -1]
        else:
            return s

    @staticmethod
    def _find_closing_brace(string: str, start: int) -> Optional[int]:
        """
        Returns the index of the "}" that closes the first "{" at or after
        ``start``, or None when the braces never balance.
        """
        num_left_braces_open = 0
        for position in range(start, len(string)):
            char = string[position]
            if char == "{":
                num_left_braces_open += 1
            elif char == "}":
                num_left_braces_open -= 1
                if num_left_braces_open == 0:
                    return position
        return None

    @staticmethod
    def last_boxed_only_string(string: str) -> Optional[str]:
        """
        Extracts the last boxed expression from a string.

        If the space form ``\\boxed <expr>`` occurs anywhere, the text after
        its last occurrence (up to the next "$") is returned. Otherwise the
        last ``\\boxed{...}`` is returned, or the last ``\\fbox{...}`` when
        there is no ``\\boxed`` at all, with nested braces matched.

        Args:
            string (str): The input string.

        Returns:
            Optional[str]: The last boxed expression or None.
        """
        idx = string.rfind("\\boxed")
        if "\\boxed " in string:
            return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
        if idx < 0:
            idx = string.rfind("\\fbox")
            if idx < 0:
                return None

        right_brace_idx = MathEvaluator._find_closing_brace(string, idx)
        return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None

    @staticmethod
    def all_boxed_strings(string: str) -> List[str]:
        """
        Extracts all boxed expressions from a string in order of appearance.

        Handles ``\\boxed{...}`` and ``\\fbox{...}`` with nested braces, plus
        the space form ``\\boxed <token>`` (one whitespace-delimited token,
        cut at the first "$").

        Args:
            string (str): The input string.

        Returns:
            List[str]: List of all boxed expressions in order of appearance.
        """
        space_marker = "\\boxed "
        results = []
        i = 0

        while i < len(string):
            # Find next occurrence of either \boxed or \fbox, whichever is first
            found = [
                position
                for position in (string.find("\\boxed", i), string.find("\\fbox", i))
                if position != -1
            ]
            if not found:
                break
            idx = min(found)

            if string.startswith(space_marker, idx):
                # Handle \boxed space case. The marker is 7 characters long
                # (backslash + "boxed" + space).
                tokens = string[idx + len(space_marker):].split("$")[0].split()
                if tokens:
                    expr = space_marker + tokens[0]
                    results.append(expr)
                    i = idx + len(expr)
                else:
                    # Marker with nothing usable after it: skip past it.
                    i = idx + len(space_marker)
            else:
                # Handle \boxed{...} or \fbox{...} case
                right_brace_idx = MathEvaluator._find_closing_brace(string, idx)
                if right_brace_idx is not None:
                    results.append(string[idx:right_brace_idx + 1])
                    i = right_brace_idx + 1
                else:
                    i = idx + 1

        return results

    @staticmethod
    def is_equiv(str1: Optional[str], str2: Optional[str], verbose: bool = False) -> bool:
        """
        Checks if two strings are equivalent after normalization.

        Args:
            str1 (Optional[str]): First string.
            str2 (Optional[str]): Second string.
            verbose (bool): If True, prints the normalized strings.

        Returns:
            bool: True if equivalent, False otherwise. Two None inputs count
            as equivalent (with a printed warning).
        """
        if str1 is None and str2 is None:
            print("WARNING: Both None")
            return True
        if str1 is None or str2 is None:
            return False

        try:
            ss1 = MathEvaluator.strip_string(str1)
            ss2 = MathEvaluator.strip_string(str2)
            if verbose:
                print(ss1, ss2)
            return ss1 == ss2
        except Exception:
            # Normalization is best-effort; fall back to the raw strings.
            return str1 == str2

    @staticmethod
    def sympy_match(str1: str, str2: str) -> bool:
        """
        Checks if two mathematical expressions are equivalent using SymPy.
        Times out after 3 seconds to prevent hanging on complex expressions.

        The alarm is always cancelled and the previous SIGALRM handler is
        restored before returning.

        Args:
            str1 (str): First expression.
            str2 (str): Second expression.

        Returns:
            bool: True if equivalent. False if not equivalent, on any parse
            error or timeout, or when the alarm cannot be armed (platform
            without SIGALRM, or not running in the main thread).
        """
        def timeout_handler(signum, frame):
            raise TimeoutError("Parsing took too long")

        try:
            previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
        except (AttributeError, ValueError):
            # AttributeError: no SIGALRM on this platform.
            # ValueError: handlers can only be installed from the main thread.
            return False

        try:
            # Set timeout of 3 seconds
            signal.alarm(3)
            expr1 = parse_latex(str1)
            expr2 = parse_latex(str2)
            diff = sympy.simplify(expr1 - expr2)
            return diff == 0
        except Exception:
            # Covers parse errors and the TimeoutError raised by the handler.
            return False
        finally:
            # Disable alarm before swapping the handler back.
            signal.alarm(0)
            # None means the old handler was not installed from Python and
            # cannot be re-installed.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    @staticmethod
    def strip_string(string: str) -> str:
        """
        Normalizes a LaTeX string by removing unnecessary characters and formatting.

        Args:
            string (str): The input string.

        Returns:
            str: Normalized string.
        """
        string = string.replace("\n", "")
        string = string.replace("\\!", "")
        string = string.replace("\\\\", "\\")
        string = string.replace("tfrac", "frac")
        string = string.replace("dfrac", "frac")
        string = string.replace("\\left", "")
        string = string.replace("\\right", "")
        string = string.replace("^{\\circ}", "")
        string = string.replace("^\\circ", "")
        string = string.replace("\\$", "")
        string = MathEvaluator.remove_right_units(string)
        # Applied twice on purpose: the first removal can join neighbouring
        # characters into a new "\%".
        string = string.replace("\\%", "")
        string = string.replace(r"\%", "")
        string = string.replace(" .", " 0.")
        string = string.replace("{.", "{0.")

        if len(string) == 0:
            return string
        if string[0] == ".":
            string = "0" + string

        # "k = 5" style answers: keep only the right-hand side when the left is tiny.
        if len(string.split("=")) == 2:
            if len(string.split("=")[0]) <= 2:
                string = string.split("=")[1]

        string = MathEvaluator.fix_sqrt(string)
        string = string.replace(" ", "")
        string = MathEvaluator.fix_fracs(string)

        if string == "0.5":
            string = "\\frac{1}{2}"

        string = MathEvaluator.fix_a_slash_b(string)

        return string

    @staticmethod
    def fix_fracs(string: str) -> str:
        """
        Fixes improperly formatted fractions in a LaTeX string.

        Adds the braces LaTeX lets authors omit: ``\\frac12`` becomes
        ``\\frac{1}{2}`` and ``\\frac1{72}`` becomes ``\\frac{1}{72}``.
        Fractions whose numerator is already braced are left untouched.

        Args:
            string (str): The input string.

        Returns:
            str: String with fixed fractions, or the original string when a
            ``\\frac`` is followed by fewer than two characters.
        """
        head, *tails = string.split("\\frac")
        new_str = head
        for substr in tails:
            new_str += "\\frac"
            if substr.startswith("{"):
                new_str += substr
                continue
            if len(substr) < 2:
                return string
            a, b, post_substr = substr[0], substr[1], substr[2:]
            if b != "{":
                # Both arguments are single bare characters.
                new_str += f"{{{a}}}{{{b}}}{post_substr}"
            else:
                # The denominator already opens its own group: brace only the
                # numerator and keep the existing "{...".
                new_str += f"{{{a}}}{b}{post_substr}"
        return new_str

    @staticmethod
    def fix_a_slash_b(string: str) -> str:
        """
        Converts a simple a/b format to LaTeX fraction if applicable.

        Args:
            string (str): The input string.

        Returns:
            str: Modified string with fractions fixed.
        """
        parts = string.split("/")
        if len(parts) != 2:
            return string
        a, b = parts
        try:
            a = int(a)
            b = int(b)
            # Round-trip check rejects spellings such as "01/2" or " 1/2".
            if string == f"{a}/{b}":
                return f"\\frac{{{a}}}{{{b}}}"
            else:
                return string
        except ValueError:
            return string

    @staticmethod
    def remove_right_units(string: str) -> str:
        """
        Removes units described within \\text{ } at the end of the string.

        Only applies when ``\\text{ `` (with the trailing space) occurs
        exactly once; everything from that marker onward is dropped.

        Args:
            string (str): The input string.

        Returns:
            str: String without units.
        """
        if "\\text{ " in string:
            splits = string.split("\\text{ ")
            if len(splits) == 2:
                return splits[0]
        return string

    @staticmethod
    def fix_sqrt(string: str) -> str:
        """
        Ensures that square roots in the string are properly formatted with braces.

        ``\\sqrt3`` becomes ``\\sqrt{3}``; already braced arguments are kept.

        Args:
            string (str): The input string.

        Returns:
            str: String with fixed square roots, or the original string when
            a ``\\sqrt`` has nothing after it.
        """
        if "\\sqrt" not in string:
            return string
        splits = string.split("\\sqrt")
        new_string = splits[0]
        for split in splits[1:]:
            if not split.startswith("{"):
                if len(split) < 1:
                    return string
                a = split[0]
                new_substr = f"\\sqrt{{{a}}}" + split[1:]
            else:
                new_substr = "\\sqrt" + split
            new_string += new_substr
        return new_string
--------------------------------------------------------------------------------
/signals/rollouts_based_signals/sample_from_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | from datasets import Dataset, load_dataset
4 | from enum import Enum
5 | import json
6 | import os
7 | from pydantic import BaseModel, Field, field_validator
8 | import tenacity
9 | import traceback
10 | from tqdm import tqdm
11 |
12 | from utils.sglang_util import SGLangServerManager
13 |
class ModelType(str, Enum):
    """Closed set of model identifiers this script accepts.

    Each member's value is the exact identifier string of the model; because
    the enum also derives from ``str``, members compare equal to that string.
    """

    # Llama 3.1 instruct checkpoints
    Llama3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    Llama3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    Llama3_1_405B = "meta-llama/Llama-3.1-405B-Instruct-FP8"

    # Other model identifiers
    CLAUDE_3_5 = "claude-3-5-sonnet-latest"
    O1_PREVIEW = "o1-preview"
    O1_MINI = "o1-mini"
22 |
class LLMHyperparams(BaseModel):
    """Validated bundle of prompt and sampling settings for one generation run.

    ``model_name`` and ``n`` have no defaults and must be supplied; every
    other field falls back to the default declared below.
    """

    # NOTE(review): this prompt is handed to the model as-is (format_items
    # does not call str.format on it), so the doubled braces around
    # final_answer reach the model literally - confirm that is intended.
    system_prompt: str = """You are a math expert. Given the following math problem, provide your solution in Latex format. Always format your final answer in perfect LaTeX \\boxed{{final_answer}} format."""
    # User-message template; "{problem}" is the placeholder for the problem text.
    prompt: str = "{problem}"
    temperature: float = Field(
        default=0.8,
        ge=0.0,
        le=2.0,
        description='Float that controls the randomness of the sampling. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling.'
    )
    top_k: int = Field(
        default=-1,
        description='Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.'
    )
    top_p: float = Field(
        default=0.95,
        ge=0.0,
        le=1.0,
        description='Float that controls the cumulative probability of the top tokens to consider.'
    )
    # Upper bound on generated tokens per sample.
    max_tokens: int = 2048
    model_name: ModelType = Field(
        description='The model to use for generation.',
    )
    n: int = Field(
        description='Number of samples to generate.'
    )

    @field_validator('temperature')
    @classmethod
    def validate_temperature(cls, v):
        """Reject negative temperatures (the field's ``ge=0.0`` bound enforces the same limit)."""
        if v < 0:
            raise ValueError("temp can't be negative")
        return v

    def to_json(self) -> str:
        """Serialize all fields to an indented JSON string."""
        return json.dumps(self.model_dump(), indent=2)

    @classmethod
    def load_json(cls, json_str: str) -> 'LLMHyperparams':
        """Build an instance from a JSON object string whose keys are field names."""
        data = json.loads(json_str)
        return cls(**data)
63 |
def construct_few_shot_prompt(problem):
    """Build the chat messages used to solve ``problem`` in the house style.

    The conversation consists of a system message with formatting guidelines
    and three worked examples, one complete user/assistant exchange serving
    as an in-context demonstration, and a final user message carrying
    ``problem`` prefixed with "Problem: ".

    Returns:
        list[dict]: Four ``{"role": ..., "content": ...}`` messages in the
        order system, user, assistant, user.
    """
    # Persona, formatting rules and worked examples (raw string: backslashes
    # are passed through exactly as written).
    guidelines = r"""From this moment forward, adopt the persona of a tenured mathematics professor. You receive math problems and you solve them, step by step, following the formatting instructions below.

# Mathematical Solution Formatting Guidelines

## Mathematical Notation Rules

**Use $$ $$ for:**
- Standalone equations
- Complex mathematical expressions
- Multi-line equations
- Final answers

**Use single $ $ for:**
- Simple inline expressions
- Variables
- Numbers with mathematical meaning
- Parts of expressions being discussed

**For equations with multiple lines:**
`latex
$$\\begin{aligned}
equation1 &= expression1 \\\\
equation2 &= expression2
\\end{aligned}$$
`

**Place the final answer inside \boxed{answer_here}**
After solving and verifying your full solution, write the final answer in \\boxed{answer} notation.

## Thought Structure
Each thought must:
1. Begin with active voice ("I" statements or direct observations)
2. Express exactly ONE logical step
3. Integrate mathematical notation naturally within sentences
4. Use appropriate notation based on context
5. No more than 1 sentence + interleaved math notation long

### Example:
"I can see that when $x = 2$, the expression $$y = x^2 + 3x + 1$$ becomes $$y = 4 + 6 + 1 = 11$$."

## Common Mathematical Patterns
- Use $$ $$ for equations being solved
- Use $ $ for discussing components: "where $m$ is the slope"
- Keep expressions inline when discussing specific parts
- Use block notation ($$) for key steps and results

## Visual Formatting
- Two blank lines between thoughts (\n\n)
- No bullet points or numbered lists
- No section headers within the solution
- Mathematical expressions should flow naturally within sentences
- No explicit step labeling
- Human-style reasoning

---

## Examples

### Problem 1:

A standard die is rolled six times. What is the probability that the product of all six rolls is odd? Express your answer as a common fraction.

### Solution 1:

I need to find the probability by dividing favorable outcomes by total outcomes.

I know that the total number of possible outcomes when rolling a die six times is $$6^6$$.

For the product to be odd, each individual roll must be odd.

Looking at a standard die, the odd numbers are 1, 3, and 5.

Therefore, for each roll, I have only 3 choices to maintain an odd product.

The total number of favorable outcomes is $$3^6$$ since I have 3 choices for each of the 6 rolls.

The probability is thus $$\frac{3^6}{6^6}$$.

This simplifies to $$\left( \frac{1}{2} \right)^6 = \\boxed{\frac{1}{64}}$$.

### Problem 2:

Solve over the integers: $$2^a+4^b+8^c=328$$

### Solution 2:

I notice that all terms are powers of 2, so I can rewrite everything with base 2.

I know that $$4 = 2^2$$ and $$8 = 2^3$$, so the equation becomes $$2^a + 2^{2b} + 2^{3c} = 328$$.

To solve this, I can decompose 328 into powers of 2.

The largest power of 2 less than 328 is 256 ($$2^8$$).

Subtracting 256 from 328 leaves 72.

The largest power of 2 less than 72 is 64 ($$2^6$$).

Subtracting 64 leaves 8 ($$2^3$$).

Therefore, $$328 = 2^8 + 2^6 + 2^3$$.

Comparing terms: $$2^a = 2^8$$, $$2^{2b} = 2^6$$, and $$2^{3c} = 2^3$$.

Solving these equations: $$a = 8$$, $$b = 3$$, and $$c = 1$$.

My final answer is \\boxed{(a, b, c) = (8, 3, 1)}

### Problem 3:

Find the equation of the circle which passes through $$(2, 3)$$ and $$(4, 5)$$ and whose center lies on the straight line $$y - 4x + 3 = 0$$.

### Solution 3:

I know that the center lies on the line, so I can write its coordinates as $$(x, 4x - 3)$$.

The radius can be found using the distance from the center to $$(2, 3)$$.

Using the distance formula, I get $$r^2 = (x - 2)^2 + (4x - 3 - 3)^2 = (x - 2)^2 + (4x - 6)^2$$.

Expanding and simplifying: $$r^2 = 17x^2 - 52x + 40$$.

Since the circle also passes through $$(4, 5)$$, I can write $$r^2 = (x - 4)^2 + (4x - 3 - 5)^2$$.

This simplifies to $$r^2 = 17x^2 - 72x + 80$$.

Since both expressions equal $$r^2$$, I can write $$17x^2 - 52x + 40 = 17x^2 - 72x + 80$$.

I can now simplify: $$20x = 40$$, so $$x = 2$$.

The center is therefore $$(2, 4(2) - 3) = (2, 5)$$.

The radius squared is $$r^2 = 17(2)^2 - 52(2) + 40 = 16$$, so $$r = 4$$.

The equation of the circle is $$(x - 2)^2 + (y - 5)^2 = 16$$.

The final answer is the full equation of the circle, so the correct final answer is \\boxed{(x - 2)^2 + (y - 5)^2 = 16}

"""

    # One complete exchange shown to the model as a prior user/assistant turn.
    demo_problem = r"Given real numbers \( a, b, c \) and a positive number \( \lambda \) such that the polynomial \( f(x) = x^3 + a x^2 + b x + c \) has three real roots \( x_1, x_2, x_3 \), and the conditions \( x_2 - x_1 = \lambda \) and \( x_3 > \frac{1}{2}(x_1 + x_2) \) are satisfied, find the maximum value of \( \frac{2 a^3 + 27 c - 9 a b}{\lambda^3} \)."

    demo_solution = r"""
I need to find the maximum value of the given expression, which involves the coefficients of the polynomial and the difference between its roots.

Since \( x_1, x_2, x_3 \) are the roots of the polynomial \( f(x) = x^3 + a x^2 + b x + c \), by Vieta's formulas, I know that $$x_1 + x_2 + x_3 = -a$$.

Also from Vieta's formulas, $$x_1 x_2 + x_2 x_3 + x_3 x_1 = b$$ and $$x_1 x_2 x_3 = -c$$.

Given that \( x_2 - x_1 = \lambda \), I can express \( x_2 \) as \( x_2 = x_1 + \lambda \).

The condition \( x_3 > \frac{1}{2}(x_1 + x_2) \) implies that \( x_3 > x_1 + \frac{\lambda}{2} \).

However, to simplify the problem, I will consider the equality case, which should give me the maximum value.

Thus, I set \( x_3 = x_1 + \frac{\lambda}{2} \).

Now, I substitute \( x_2 \) and \( x_3 \) in terms of \( x_1 \) and \( \lambda \) into the expression for \( a \): $$-a = x_1 + (x_1 + \lambda) + \left( x_1 + \frac{\lambda}{2} \right ).$$

This simplifies to $$-a = 3x_1 + \frac{3\lambda}{2}.$$

I can also express \( b \) and \( c \) in terms of \( x_1 \) and \( \lambda \):
$$b = x_1(x_1 + \lambda) + (x_1 + \lambda)\left( x_1 + \frac{\lambda}{2} \right ) + \left( x_1 + \frac{\lambda}{2} \right ) x_1,$$
and
$$-c = x_1(x_1 + \lambda)\left( x_1 + \frac{\lambda}{2} \right ).$$

Substituting these expressions into the given expression, I get $$\frac{2 a^3 + 27 c - 9 a b}{\lambda^3}$$ in terms of \( x_1 \) and \( \lambda \).

However, since I want the maximum value of this expression, which is independent of \( x_1 \), I can choose a specific value of \( x_1 \) to simplify the calculation.

Let's set \( x_1 = 0 \) to get the maximum value.

Substituting \( x_1 = 0 \) into the expressions for \( a \), \( b \), and \( c \), I get
$$-a = \frac{3\lambda}{2},$$
$$b = \frac{\lambda^2}{2},$$
and
$$-c = 0.$$

Now, I substitute these values into the given expression:
$$\frac{2 a^3 + 27 c - 9 a b}{\lambda^3} = \frac{2\left( -\frac{3\lambda}{2} \right )^3 + 27(0) - 9\left( -\frac{3\lambda}{2} \right )\left( \frac{\lambda^2}{2} \right )}{\lambda^3}.$$

Simplifying the numerator:
$$2\left( -\frac{27\lambda^3}{8} \right ) + \frac{27\lambda^3}{4} = -\frac{27\lambda^3}{4} + \frac{27\lambda^3}{4} = 0.$$

Therefore, the maximum value of the expression is $$\frac{0}{\lambda^3} = 0.$$

\\boxed{0}
"""

    # System guidelines, the demonstration exchange, then the real question.
    return [
        {"role": "system", "content": guidelines},
        {"role": "user", "content": demo_problem},
        {"role": "assistant", "content": demo_solution},
        {"role": "user", "content": f"Problem: {problem}"},
    ]
264 |
def format_items(problem, system_prompt):
    """Wrap a problem in a two-message chat: the system prompt, then the problem as the user turn."""
    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": problem}
    return [system_message, user_message]
270 |
@tenacity.retry(
    stop=tenacity.stop_after_attempt(5),
    wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
)
async def get_chat_responses(server_handler, batch_items, n, top_p, temperature, max_tokens):
    """Forward one batch to ``server_handler.get_chat_responses`` with the given sampling settings.

    The call is retried on any exception, up to 5 attempts in total, with
    exponential backoff between attempts (configured above).
    """
    sampling_settings = {
        "n": n,
        "top_p": top_p,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    return await server_handler.get_chat_responses(batch_items, **sampling_settings)
280 |
async def main():
    """Sample model responses for every problem in a dataset and save them.

    Configuration is read from the module-level ``args`` namespace (populated
    by the ``__main__`` block). The function:

    1. loads ``args.dataset_name`` / ``args.dataset_split`` and keeps rows
       ``[args.start, args.end)`` (``end == -1`` means "to the end");
    2. turns every problem into a chat prompt (few-shot when
       ``args.generate_training_data`` is set, system + user otherwise);
    3. sends the prompts in batches through an ``SGLangServerManager`` and
       stores the list of sampled responses of each problem in
       ``args.response_column_name``;
    4. every ``args.save_interval`` batches, and once at the end, writes a
       parquet file into ``args.save_folder`` and pushes the frame to the
       Hugging Face Hub as a private dataset.

    Save/push failures are printed and do not abort the run.
    """
    # Configuration
    model_name = args.model_name
    dataset_name = args.dataset_name
    save_name = args.save_name
    save_folder = args.save_folder
    problem_column_name = args.problem_column_name
    response_column_name=args.response_column_name
    # Push results back to the input dataset unless a separate output name was given.
    if args.dataset_name_outputs:
        hf_save_dataset_name = args.dataset_name_outputs
    else:
        hf_save_dataset_name = dataset_name
    os.makedirs(save_folder, exist_ok=True)

    batch_size = args.batch_size
    save_interval = args.save_interval

    # Sampling hyperparameters
    if args.greedy:
        # greedy decoding
        # Temperature 0 with a single sample; top_k/top_p are left at the
        # LLMHyperparams defaults.
        llm_params = LLMHyperparams(
            temperature=0.0,
            model_name=ModelType(model_name),
            n=1,
            max_tokens=args.max_tokens
        )
    else:
        llm_params = LLMHyperparams(
            top_k=args.top_k,
            top_p=args.top_p,
            temperature=args.temperature,
            model_name=ModelType(model_name),
            n=args.n,
            max_tokens=args.max_tokens
        )

    print("Hyperparameters:")
    print(llm_params.to_json())

    # Load and preprocess dataset
    ds = load_dataset(dataset_name, split=args.dataset_split)
    df = ds.to_pandas()

    # if only using partial dataset, slice it
    if args.end == -1:
        args.end = len(df)
    df = df.iloc[args.start:args.end]
    # Re-number rows from 0 so that batch offsets below can be used as index labels.
    df = df.reset_index(drop=True)

    print(f"Total dataset: {len(df)}")

    # Prepare items for processing
    if args.generate_training_data:
        items = [construct_few_shot_prompt(row[problem_column_name]) for _, row in df.iterrows()]
    else:
        items = [format_items(row[problem_column_name], llm_params.system_prompt) for _, row in df.iterrows()]

    # Create the output column up front; rows stay None until their batch is processed.
    df[response_column_name] = None
    # Process items in batches
    total_items = len(items)
    count = 0  # number of batches completed so far
    with SGLangServerManager(model_name, tp=args.tp) as server_handler:
        for idx in tqdm(range(0, total_items, batch_size)):
            batch_items = items[idx:idx+batch_size]
            # NOTE(review): llm_params.top_k is not forwarded to the request.
            batch_outputs = await get_chat_responses(
                server_handler=server_handler,
                batch_items=batch_items,
                n=llm_params.n,
                top_p=llm_params.top_p,
                temperature=llm_params.temperature,
                max_tokens=llm_params.max_tokens
            )
            batch_responses = []
            for resp in batch_outputs:
                # Each output is read as a message list whose last message
                # carries the sampled completions under "responses".
                try:
                    batch_responses.append(resp[-1]["responses"])
                except Exception as e:
                    print(f"Response: {resp}")
                    traceback_str = traceback.format_exc()
                    print(f"Error processing response: {traceback_str}")
                    # Keep row alignment by storing a single empty response.
                    batch_responses.append([""])

            count += 1
            # Assign responses to dataframe
            for i, response_list in enumerate(batch_responses):
                df.at[idx+i, response_column_name] = response_list
            if count % save_interval == 0:
                try:
                    # The parquet checkpoint holds only the rows processed so far ...
                    df.iloc[:idx+batch_size].to_parquet(
                        os.path.join(save_folder, f"{save_name}_{count}_batch.parquet")
                    )
                    # ... while the hub push uploads the whole frame, including
                    # rows whose response is still None.
                    ds = Dataset.from_pandas(df)
                    ds.push_to_hub(hf_save_dataset_name, private=True)
                except Exception as e:
                    print(f"Error saving checkpoint: {e}")
    # Save final results
    try:
        df.to_parquet(os.path.join(save_folder, f"{save_name}.parquet"))
        ds = Dataset.from_pandas(df)
        ds.push_to_hub(hf_save_dataset_name, private=True)
        print(f"Saved to {os.path.join(save_folder, f'{save_name}.parquet')}")
    except Exception as e:
        print(f"Error saving final results: {e}")
384 |
# Script entry point: parse the command line into the module-level ``args``
# namespace (read directly by ``main``) and run the async pipeline.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Model and dataset configuration
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--dataset_name", type=str, required=True)
    # When omitted, main() pushes the results back to --dataset_name.
    parser.add_argument("--dataset_name_outputs", type=str,
                        help="To save the outputs to a different HF dataset, specify here.")
    parser.add_argument("--dataset_split", type=str, default="train")
    parser.add_argument("--problem_column_name", type=str, default="problem")
    parser.add_argument("--response_column_name", type=str, required=True)
    parser.add_argument("--generate_training_data", action="store_true",
                        help="Use a few-shot prompt to encourage the model to follow a specific format.")

    # Save configuration
    parser.add_argument("--save_folder", type=str, required=True)
    parser.add_argument("--save_name", type=str, required=True)
    parser.add_argument("--save_interval", type=int, default=10000,
                        help="Save every n batches.")

    # SGLang server configuration
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=250,
                        help="Total batch size will be args.batch_size * args.n.")

    # LLM Hyperparameters
    # NOTE(review): --top_k is stored in LLMHyperparams by main() but is not
    # passed on to get_chat_responses, so it does not reach the server request.
    parser.add_argument("--top_k", type=int, default=-1)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--n", type=int, default=64)
    parser.add_argument("--greedy", action="store_true", help="Use greedy decoding. Ignores temperature, top_p, and top_k.")

    # dataset slicing
    # Rows [start, end) are processed; main() replaces end == -1 with the dataset length.
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=-1)
    args = parser.parse_args()

    asyncio.run(main())
423 |
424 | # Example Usage:
425 | # python3 sample_from_model.py --model_name meta-llama/Llama-3.1-405B-Instruct-FP8 --dataset_name RLAIF/Big-Math-Competition-Problems --save_folder math_competition_problems --save_name llama405b_greedy --response_column_name llama_405b_greedy_response --greedy --tp 8
426 |
427 | # python3 sample_from_model.py --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --dataset_name RLAIF/Big-Math-Competition-Problems --save_folder math_competition_problems --save_name llama8b_greedy --response_column_name llama_8b_greedy_response --greedy
--------------------------------------------------------------------------------
/signals/rollouts_based_signals/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SynthLabsAI/big-math/420b9a771a7e97a85b81cbdcbd573b1b0d56f522/signals/rollouts_based_signals/utils/__init__.py
--------------------------------------------------------------------------------
/signals/rollouts_based_signals/utils/openai_server.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import signal
4 |
5 | import openai
6 | from tqdm.asyncio import tqdm_asyncio
7 |
8 |
class OpenAIServerManager:
    """Context manager around locally launched OpenAI-compatible servers.

    Entering the context calls :meth:`launch_servers` (implemented by
    subclasses), then creates one ``openai.AsyncClient`` and one
    ``asyncio.Semaphore(64)`` per returned port. Requests are distributed
    round-robin over those clients. Leaving the context sends ``SIGTERM`` to
    the process group of every launched subprocess.
    """

    def __init__(
        self, model_name, start_port=2341, tp=1, max_time=600, trust_remote_code=False,
    ):
        """Store the launch configuration; nothing is started until ``__enter__``.

        Args:
            model_name: Model identifier/path handed to ``launch_servers``.
            start_port: First port handed to ``launch_servers``.
            tp: Tensor-parallel size handed to ``launch_servers``.
            max_time: Readiness timeout (seconds) handed to ``launch_servers``.
            trust_remote_code: Flag kept on the instance for subclasses to read.
        """
        self.model_name = model_name
        self.start_port = start_port
        self.tp = tp
        self.max_time = max_time
        self.trust_remote_code = trust_remote_code

    def __enter__(self):
        """Launch the servers and build one client + semaphore per port."""
        self.ports, self.subprocs = self.launch_servers(
            self.model_name, self.start_port, self.tp, self.max_time,
        )
        self.async_clients = [
            openai.AsyncClient(base_url=f"http://localhost:{port}/v1", api_key="EMPTY")
            for port in self.ports
        ]
        # At most 64 in-flight requests per client.
        self.sems = [asyncio.Semaphore(64) for _ in self.ports]
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Terminate the process group of every launched server."""
        for proc in self.subprocs:
            try:
                os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
            except ProcessLookupError:
                # The server already exited; nothing left to terminate.
                pass

    def launch_servers(
        self, model_name, start_port=8001, tp=1, max_time=600
    ):
        """Start the servers; must return ``(ports, subprocesses)``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def _request_kwargs(kwargs, **required):
        """Copy ``kwargs`` and apply the per-request fields and defaults.

        ``model`` is always ``"default"``, ``required`` entries override the
        caller's values, and ``max_tokens`` falls back to 2048 when absent.
        """
        request = dict(kwargs)
        request["model"] = "default"
        request.update(required)
        # They can go realllyy reallly long if you just yolo it.
        request.setdefault("max_tokens", 2048)
        return request

    async def get_chat_responses(self, chats, **kwargs) -> list[list[dict]]:
        """
        Get chat completions for every chat from the launched servers.

        A failed request is not retried here: the error is printed and the
        chat gets a single empty-string response instead.

        :param chats: list of chats (each a list of message dicts)
        :param kwargs: extra arguments for ``chat.completions.create``
        :return: the input chats, in order; the last message of each chat is
            mutated in place to carry the completions under ``"responses"``
        """

        async def response_wrapper(client: "openai.AsyncClient", sem, **request):
            messages = request["messages"]
            async with sem:
                try:
                    out = await client.chat.completions.create(**request)
                    messages[-1]["responses"] = [
                        choice.message.content for choice in out.choices
                    ]
                except Exception as e:
                    print(f"Error in response_wrapper: {str(e)}")
                    messages[-1]["responses"] = [""]
                return messages

        pending = []
        for i, chat in enumerate(chats):
            slot = i % len(self.ports)  # round-robin over the clients
            pending.append(
                response_wrapper(
                    self.async_clients[slot],
                    self.sems[slot],
                    **self._request_kwargs(kwargs, messages=chat),
                )
            )
        return await tqdm_asyncio.gather(*pending)

    async def get_completion_responses(
        self, prompts, **kwargs
    ) -> list[dict]:
        """
        Get text completions for every prompt from the launched servers.

        Errors raised by a request propagate to the caller.

        :param prompts: list of prompts
        :param kwargs: extra arguments for ``completions.create``
        :return: one dict per prompt, in order, holding the request arguments
            that were sent plus the completion texts under ``"responses"``
        """

        async def response_wrapper(client: "openai.AsyncClient", sem, **request):
            async with sem:
                out = await client.completions.create(**request)
                request["responses"] = [choice.text for choice in out.choices]
                return request

        pending = []
        for i, prompt in enumerate(prompts):
            slot = i % len(self.ports)  # round-robin over the clients
            pending.append(
                response_wrapper(
                    self.async_clients[slot],
                    self.sems[slot],
                    # Send the per-prompt arguments (model, prompt, max_tokens
                    # default), not the caller's raw kwargs.
                    **self._request_kwargs(kwargs, prompt=prompt),
                )
            )
        return await tqdm_asyncio.gather(*pending)
113 |
--------------------------------------------------------------------------------
/signals/rollouts_based_signals/utils/sglang_util.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import os
3 | import signal
4 | import subprocess
5 | import time
6 |
7 | import openai
8 |
9 | from .openai_server import OpenAIServerManager
10 |
11 |
def get_sglang_response(port):
    """
    Probe the sglang server on ``port`` with a one-token text completion.

    The reply is printed; any error raised by the OpenAI client (for example
    while the server is still starting) propagates to the caller.

    :param port: local port the server listens on
    :return: None
    """
    base_url = f"http://localhost:{port}/v1"
    probe_client = openai.Client(base_url=base_url, api_key="EMPTY")

    # Text completion
    probe_request = {
        "model": "default",
        "prompt": "The capital of France is",
        "temperature": 0,
        "max_tokens": 1,
    }
    print(probe_client.completions.create(**probe_request))
28 |
def kill_process_group(pid):
    """Send SIGTERM to the whole process group that ``pid`` belongs to.

    A process that no longer exists is silently ignored.
    """
    try:
        group_id = os.getpgid(pid)
        os.killpg(group_id, signal.SIGTERM)
    except ProcessLookupError:
        # Process already terminated
        return
35 |
class SGLangServerManager(OpenAIServerManager):
    """``OpenAIServerManager`` that serves the model with ``sglang.launch_server``."""

    def launch_servers(self, model_name, start_port=1234, tp=1, max_time=600):
        """
        Launches an sglang server on all available devices.

        A single ``sglang.launch_server`` process is started on ``start_port``
        with ``--tp tp`` and ``--dp len(devices) // tp``. The call blocks
        until the server answers a probe request or ``max_time`` elapses.

        While waiting, temporary SIGINT/SIGTERM handlers kill the server's
        process group and raise ``KeyboardInterrupt``; the previous handlers
        are always restored before this method returns or raises. If waiting
        fails for any reason the server's process group is terminated.

        Args:
            model_name (str): Path to the model.
            start_port (int): Port to start on.
            tp (int): Tensor parallelism.
            max_time (int): Maximum time (in seconds) to wait for the server to become ready.

        Returns:
            tuple: (ports, subprocesses) where ports is a list of ports and subprocesses is a list of Popen objects.

        Raises:
            TimeoutError: If the server is not responsive within ``max_time``.
        """
        # Get list of devices from env var (defaulting to 0-7 if not set)
        devices = os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(",")
        dp = len(devices) // tp

        # Only one process is launched, so every entry is ``start_port``; the
        # list keeps one slot per group of ``tp`` devices for compatibility
        # with the base class, which builds one client per entry.
        ports = [start_port for _ in range(0, len(devices), tp)]

        # Build the command as a list to avoid using the shell
        cmd = [
            "python",
            "-m",
            "sglang.launch_server",
            "--model-path", model_name,
            "--port", str(start_port),
            "--tp", str(tp),
            "--dp", str(dp),
            "--log-level", "error",
        ]
        if self.trust_remote_code:
            cmd.append("--trust-remote-code")

        # Launch the server process in its own process group.
        process = subprocess.Popen(cmd, start_new_session=True)
        subprocesses = [process]

        # Ensure that the child process group is killed when the parent exits.
        atexit.register(kill_process_group, process.pid)

        # While waiting for readiness, make SIGINT/SIGTERM tear the server down.
        def _signal_handler(sig, frame):
            kill_process_group(process.pid)
            raise KeyboardInterrupt

        original_sigint = signal.getsignal(signal.SIGINT)
        original_sigterm = signal.getsignal(signal.SIGTERM)
        signal.signal(signal.SIGINT, _signal_handler)
        signal.signal(signal.SIGTERM, _signal_handler)

        try:
            # Wait until at least one port is responsive or timeout is reached.
            start_time = time.monotonic()
            ports_working = []

            while time.monotonic() - start_time < max_time:
                for port in ports:
                    if port in ports_working:
                        continue
                    try:
                        get_sglang_response(port)
                        ports_working.append(port)
                    except (openai.APITimeoutError, openai.APIConnectionError):
                        print(f"Port {port} not ready yet.")
                if ports_working:
                    break
                time.sleep(1)  # shorter sleep interval for faster feedback
            else:
                # Loop ended without a break: the timeout was reached.
                raise TimeoutError("Server did not become ready within the allotted time.")
        except BaseException:
            # The context manager never finishes entering on this path, so
            # clean the server up here before propagating.
            kill_process_group(process.pid)
            raise
        finally:
            # Restore original signal handlers on success and on failure.
            signal.signal(signal.SIGINT, original_sigint)
            signal.signal(signal.SIGTERM, original_sigterm)

        return ports, subprocesses
--------------------------------------------------------------------------------
/signals/semdedup.py:
--------------------------------------------------------------------------------
1 | import faiss
2 | from transformers import AutoModel, AutoTokenizer
3 | import pandas as pd
4 | import torch
5 | from tqdm import tqdm
6 |
def concat_columns(df: pd.DataFrame, columns: list[str]):
    """Concatenate the columns into a new column.

    The values of each row are joined with a single space, in the order
    given by ``columns``.

    Args:
        df (pd.DataFrame): The dataframe.
        columns (List[str]): The columns to combine.

    Returns:
        pd.Series: One joined string per row, indexed like ``df``.

    Raises:
        ValueError: If a column has neither ``object`` dtype nor a pandas
            string dtype.
    """

    # validate that the columns are of type string; besides ``object`` this
    # also accepts the dedicated pandas string dtypes, which hold strings too
    for col in columns:
        dtype = df[col].dtype
        if dtype != "object" and not pd.api.types.is_string_dtype(dtype):
            raise ValueError(f"Column {col} is not of type object or string.")
    return df[columns].apply(lambda row: " ".join(row), axis=1)
20 |
def semantic_deduplication(
    df: pd.DataFrame,
    required_columns: list[str],
    num_kmeans_clusters: int,
    epsilon: float = 0.99,
    similarity_metric: str = "cosine",
    keep_central: bool = True,
    kmeans_with_cosine_distance: bool = False,
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    embedding_batch_size: int = 100,
    use_gpu: bool = False,
):
    """
    Perform semantic deduplication on a dataframe.

    The required columns are concatenated, embedded, and clustered with
    k-means; inside every cluster, items whose embedding similarity to an
    earlier item (in centroid-distance order) exceeds ``epsilon`` are
    reported as duplicates. No rows are dropped here.

    Note:
        The index labels of ``df`` are used directly as row positions of the
        embedding matrix, so ``df`` is expected to carry a default
        ``0..len(df)-1`` index. The temporary columns ``dedup_content``,
        ``distance_to_centroid`` and ``kmeans_cluster`` are written onto the
        passed-in ``df``; only the returned frame has them removed.

    Args:
        df (pd.DataFrame): The dataframe.
        required_columns (List[str]): The columns to use for deduplication. Will be concatenated in order.
        epsilon (float): The epsilon value to use for semantic deduplication.
            Pairs of items with similarity above epsilon will be considered duplicates.
        num_kmeans_clusters (int): The number of clusters to use in kmeans.
        similarity_metric (str): The metric used to rank items by distance to
            their centroid, "cosine" or "l2".
        keep_central (bool): Whether to keep the item closest (if True) or farther (if False)
            from the cluster centroid when determining which item to keep.
        kmeans_with_cosine_distance (bool): Whether to use cosine distance for kmeans,
            only False currently implemented.
        model_name (str): The model name to use for embedding.
        embedding_batch_size (int): The batch size to use for embedding.
        use_gpu (bool): Whether to run the embedding model and kmeans on GPU.

    Returns:
        tuple: ``(df, indices_to_remove, cluster_duplicates_dfs)`` where
        ``df`` is the frame without the temporary columns,
        ``indices_to_remove`` lists the index labels flagged as duplicates,
        and ``cluster_duplicates_dfs`` maps the index label of each
        duplicate group's head item to a DataFrame holding that row followed
        by the rows flagged as its duplicates.

    Raises:
        ValueError: If ``similarity_metric`` is not "cosine" or "l2".
    """

    # Fail before any expensive work instead of hitting an unbound variable
    # in the middle of the cluster loop.
    if similarity_metric not in ("cosine", "l2"):
        raise ValueError(
            f"Unsupported similarity_metric {similarity_metric!r}; expected 'cosine' or 'l2'."
        )

    # Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[
            0
        ]  # First element of model_output contains all token embeddings
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def embed_sentences_batch(sentences):
        # Tokenize sentences
        encoded_input = tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )

        if use_gpu:
            encoded_input = encoded_input.to("cuda")

        # Compute token embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)

        # Perform pooling
        sentence_embeddings = mean_pooling(
            model_output, encoded_input["attention_mask"]
        )

        # Normalize embeddings so that dot products are cosine similarities
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1
        )

        return sentence_embeddings

    def embed_sentences(sentences, batch_size=embedding_batch_size):
        """
        Embed a list of sentences using the sentence-transformers model.

        Parameters
        ----------
        sentences : List[str]
            The list of sentences to embed.
        batch_size : int
            The batch size for embedding.

        Returns
        -------
        torch.Tensor
            The sentence embeddings.
        """

        # iterate over the sentences in batches
        sentence_embeddings = []
        for i in tqdm(
            range(0, len(sentences), batch_size),
            dynamic_ncols=True,
            desc="Embedding sentences...",
        ):
            batch = sentences[i : i + batch_size]
            sentence_embeddings.append(embed_sentences_batch(batch))

        return torch.cat(sentence_embeddings)

    def semdedup(cluster, eps=0.95):
        # compute pairwise cosine similarity between cluster items
        pairwise_similarity_matrix = cluster @ cluster.T

        # keep only the strict upper triangle: drops the diagonal
        # (self-similarity) and counts every unordered pair once
        upper_triangular = torch.triu(pairwise_similarity_matrix, diagonal=1)

        # pairs whose similarity is above epsilon;
        # head is the row (earlier item), tail is the column (later item)
        head_duplicates, tail_duplicates = (upper_triangular > eps).nonzero(
            as_tuple=True
        )

        # the later item of every pair is the one to remove
        indices_to_remove = tail_duplicates.tolist()

        return (
            list(set(indices_to_remove)),
            head_duplicates.tolist(),
            tail_duplicates.tolist(),
        )

    content_col = "dedup_content"
    df[content_col] = concat_columns(df, required_columns)

    # Load model from HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    if use_gpu:
        model = model.to("cuda")

    # embed the content
    embedded_content = embed_sentences(
        df[content_col].tolist(), batch_size=embedding_batch_size
    )

    if use_gpu:
        embedded_content = embedded_content.to("cpu")

    kmeans = faiss.Kmeans(
        d=embedded_content.size(1),
        k=num_kmeans_clusters,
        niter=20,
        verbose=True,
        seed=42,
        spherical=kmeans_with_cosine_distance,  # only true if using cosine distance
        gpu=use_gpu,
    )

    # train the kmeans object
    kmeans.train(embedded_content)
    centroids = kmeans.centroids

    # get the nearest centroid for each data point
    dist_to_cent, nearest_cent = kmeans.index.search(embedded_content, 1)
    dist_to_cent, nearest_cent = dist_to_cent.squeeze(), nearest_cent.squeeze()

    # assign the distance and cluster to the dataframe
    df["distance_to_centroid"] = dist_to_cent
    df["kmeans_cluster"] = nearest_cent

    indices_to_remove = []
    cluster_duplicates_dfs = {}

    for cluster_id in tqdm(
        range(num_kmeans_clusters), desc="Iterating over clusters..."
    ):
        cluster_df = df[df["kmeans_cluster"] == cluster_id]

        # if cluster is empty, skip
        if len(cluster_df) == 0:
            continue

        # get only items from this cluster
        cluster_idxs = cluster_df.index.tolist()
        cluster_embeddings = embedded_content[cluster_idxs]

        # compute the distance to the centroid
        if similarity_metric == "cosine":
            if kmeans_with_cosine_distance:
                # if cosine distance was used for kmeans clustering, don't recompute
                cluster_dists_to_cent = 1 - cluster_df["distance_to_centroid"]
            else:
                # compute the cosine similarity to the centroid
                cluster_centroid = torch.tensor(centroids[cluster_id])
                sim_to_cent = torch.nn.functional.cosine_similarity(
                    cluster_embeddings, cluster_centroid
                )
                cluster_dists_to_cent = 1 - sim_to_cent
        else:  # "l2"
            cluster_dists_to_cent = cluster_df["distance_to_centroid"]

        # sort the cluster items by distance to centroid; semdedup keeps the
        # EARLIER item of every duplicate pair, so keeping the most central
        # item requires ascending distance (closest first)
        cluster_sorted = sorted(
            zip(cluster_idxs, cluster_embeddings, cluster_dists_to_cent),
            key=lambda x: x[2],
            reverse=not keep_central,
        )

        # get the sorted indices
        sorted_cluster_idxs = [x[0] for x in cluster_sorted]
        sorted_cluster_embeddings = torch.stack([x[1] for x in cluster_sorted])

        # use semdedup to determine which items to remove
        (
            cluster_indices_to_remove,
            cluster_head_duplicates,
            cluster_tail_duplicates,
        ) = semdedup(sorted_cluster_embeddings, eps=epsilon)

        # Group every duplicate under its head item. Once an item has been
        # recorded as somebody's duplicate, pairs where it acts as the head
        # are skipped so duplicates are not counted multiple times.
        absorbed = set()
        for head_idx, tail_idx in zip(
            cluster_head_duplicates, cluster_tail_duplicates
        ):
            if head_idx in absorbed:
                continue

            head_label = sorted_cluster_idxs[head_idx]
            tail_label = sorted_cluster_idxs[tail_idx]

            # first pair for this head: create its dataframe, starting with
            # the head row itself
            if head_label not in cluster_duplicates_dfs:
                cluster_duplicates_dfs[head_label] = pd.DataFrame(
                    columns=df.columns
                )
                cluster_duplicates_dfs[head_label].loc[head_label] = df.loc[
                    head_label
                ]

            # add the tail row to the head's duplicates dataframe
            cluster_duplicates_dfs[head_label].loc[tail_label] = df.loc[tail_label]
            absorbed.add(tail_idx)

        # convert cluster indices to global indices
        global_indices_to_remove = [
            sorted_cluster_idxs[i] for i in cluster_indices_to_remove
        ]

        indices_to_remove.extend(global_indices_to_remove)

    # The duplicates are only reported; dropping them is left to the caller.

    # remove the temporary columns
    df = df.drop(
        columns=["dedup_content", "distance_to_centroid", "kmeans_cluster"], axis=1
    )

    return df, indices_to_remove, cluster_duplicates_dfs
--------------------------------------------------------------------------------