├── .DS_Store
├── README.md
├── code
│   ├── __init__.py
│   ├── constraint_checker.py
│   ├── constraint_registry.py
│   ├── constraint_util.py
│   ├── eval_if.py
│   ├── requirement.txt
│   ├── script
│   │   ├── download.sh
│   │   ├── eval_if.sh
│   │   └── vllm_if.sh
│   └── vllm_core.py
└── data
    ├── aime_double.jsonl
    ├── aime_single.jsonl
    ├── aime_triple.jsonl
    ├── gsm8k_double.jsonl
    ├── gsm8k_single.jsonl
    ├── gsm8k_triple.jsonl
    ├── math500_double.jsonl
    ├── math500_single.jsonl
    ├── math500_triple.jsonl
    ├── minerva_double.jsonl
    ├── minerva_single.jsonl
    ├── minerva_triple.jsonl
    ├── olympiad_double.jsonl
    ├── olympiad_single.jsonl
    ├── olympiad_triple.jsonl
    └── toy.jsonl

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TingchenFu/MathIF/e09c04f4ab40cec0f65105f7b78fcc47a59cf8d2/.DS_Store
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 |

2 | # MathIF: Instruction-Following Benchmark for Large Reasoning Models
3 |

4 | 5 | [![Python 3.9+](https://img.shields.io/badge/python-3.9%2B-brightgreen)]() [![CUDA 12.4](https://img.shields.io/badge/CUDA-12.4-red)]() [![License](https://img.shields.io/badge/license-MIT-blue)]() 6 | 7 | MathIF is a dedicated benchmark for evaluating the instruction-following capabilities of large reasoning models (LRMs) on mathematical reasoning tasks. It exposes a fundamental trade-off between a model’s problem-solving strength and its ability to comply with user-specified constraints. 8 | 9 | 10 |
11 |

12 | 📖 Paper • 🔧 Usage • 📊 Leaderboard • 🤗 Data • 🐦 Twitter
25 |

26 |
27 |
28 |
29 | # 📖Features
30 |
31 | - **Compositional Constraints**
32 |   15 Python-verifiable constraint types in four categories (length, lexical, format, affix), combined into single, dual, and triple constraints.
33 |
34 | - **Diverse Math Sources**
35 |   Problems drawn from GSM8K, MATH-500, Minerva, Olympiad, and AIME, totaling 420 high-quality evaluation samples.
36 |
37 | - **Fine-Grained Metrics**
38 |   - **Hard Accuracy (HAcc):** fraction of examples satisfying _all_ constraints
39 |   - **Soft Accuracy (SAcc):** average fraction of satisfied constraints per example (a worked scoring sketch appears just before the leaderboard tables)
40 |
41 | - **vLLM-Powered Inference**
42 |   Efficient decoding with nucleus sampling (T=1.0, p=0.95) and generation of up to 16k tokens (a minimal decoding sketch appears after the Acknowledgements section).
43 |
44 | # ✨Getting Started
45 |
46 | ## Prerequisites
47 |
48 | - Python 3.9 or later
49 | - CUDA 12.4
50 | - `git`, `bash`
51 |
52 | ## Installation
53 |
54 | ```bash
55 | git clone https://github.com/TingchenFu/MathIF.git
56 | cd MathIF
57 |
58 | # Create and activate virtual environment
59 | python -m venv venv
60 | source venv/bin/activate  # Windows: venv\Scripts\activate
61 |
62 | # Install dependencies
63 | pip install -r code/requirement.txt
64 | ```
65 |
66 |
67 |
68 | # 🔧Usage
69 |
70 | ## Inference
71 |
72 | ```bash
73 | bash code/script/vllm_if.sh
74 | ```
75 |
76 | ## Evaluation
77 |
78 | ```bash
79 | bash code/script/eval_if.sh
80 | ```
81 |
82 | ## Dataset Format
83 |
84 | Each line in the JSONL file contains:
85 |
86 | | Field | Description |
87 | | ----------------- | --------------------------------- |
88 | | `source` | Original data source |
89 | | `id` | Unique example identifier |
90 | | `question` | Math problem statement |
91 | | `answer` | Ground-truth solution |
92 | | `constraint_desc` | Human-readable constraint summary |
93 | | `constraint_name` | Constraint category |
94 | | `constraint_args` | Arguments used for verification |
95 |
96 | ## Project Structure
97 |
98 | ```
99 | .
100 | ├── data/                # MathIF JSONL files
101 | ├── code/
102 | │   ├── script/          # Inference & evaluation scripts
103 | │   ├── requirement.txt  # Python dependencies
104 | │   └── ...              # Model wrappers and utilities
105 | ├── output/              # Generated predictions & logs
106 | └── README.md            # This overview
107 | ```
108 |
109 |
128 |
129 |
130 |
131 |
132 | # 📊Leaderboard
133 | 📢 **Showcase Your Model’s Instruction-Following Capability**
134 |
135 | Feel free to contribute results from your own models—we welcome community submissions!
136 | We currently support evaluation of newly added models on our platform. To be included on the leaderboard, please provide the Hugging Face model link for verification and testing. The Correctness column in the tables below reports answer accuracy measured with the constraint applied (w/ const.).
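For concreteness, here is a minimal sketch of how a single example in the Dataset Format above is scored with the checkers in `code/constraint_checker.py`, and how HAcc and SAcc are aggregated. The record and the `constraint_name` values are made up for illustration, and the checkers are instantiated directly rather than looked up through `code/constraint_registry.py`; the actual pipeline in `code/eval_if.py` may differ in detail.

```python
# Minimal scoring sketch (illustrative; the real pipeline lives in code/eval_if.py).
from constraint_checker import NumberOfWords, KeywordChecker  # run with code/ on PYTHONPATH

# One dual-constraint example in the Dataset Format described above
# (field values and constraint names are hypothetical).
example = {
    "question": "What is 2 + 2?",
    "constraint_name": ["num_words", "keyword"],
    "constraint_args": [
        {"num_words": 100, "relation": "less than"},
        {"keywords": ["Therefore"]},
    ],
}
response = "2 + 2 = 4. Therefore, the answer is 4."

# Instantiate the corresponding checkers directly (bypassing constraint_registry)
# and test the response against each constraint.
checkers = [NumberOfWords(), KeywordChecker()]
satisfied = []
for checker, args in zip(checkers, example["constraint_args"]):
    checker.build_description(**args)            # fixes the constraint parameters
    satisfied.append(checker.check_following(response))

# HAcc credits an example only if every constraint is satisfied;
# SAcc credits the fraction of satisfied constraints per example.
per_example = [satisfied]                        # one entry per evaluated example
hacc = sum(all(flags) for flags in per_example) / len(per_example)
sacc = sum(sum(flags) / len(flags) for flags in per_example) / len(per_example)
print(f"HAcc = {hacc:.2%}, SAcc = {sacc:.2%}")
```

Over the full benchmark, HAcc and SAcc are these quantities averaged over all 420 evaluation examples.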
137 | 138 | ## **≤ 4B Models** 139 | 140 | | Model | HAcc | SAcc | Correctness | 141 | | ----------------------------- | ----- | ----- | ----------------------- | 142 | | [Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B) | 44.05 | 61.43 | 58.57 | 143 | | [Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) | 30.24 | 50.24 | 51.19 | 144 | | [Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) | 27.86 | 50.44 | 32.14 | 145 | | [L1-Qwen-1.5B-Exact](https://huggingface.co/l3lab/L1-Qwen-1.5B-Exact) | 19.76 | 39.60 | 42.86 | 146 | | [L1-Qwen-1.5B-Max](https://huggingface.co/l3lab/L1-Qwen-1.5B-Max) | 19.76 | 39.40 | 45.71 | 147 | | [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) | 17.14 | 36.62 | 31.67 | 148 | | [DeepScaler-1.5B-Preview](https://huggingface.co/agentica-org/DeepScaleR-1.5B-Preview) | 14.52 | 34.52 | 36.19 | 149 | | [Qwen2.5-1.5B-SimpleRL-Zoo](https://huggingface.co/hkust-nlp/Qwen-2.5-1.5B-SimpleRL-Zoo) | 9.05 | 24.33 | 22.38 | 150 | | [Qwen2.5-Math-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B-Instruct) | 7.62 | 21.39 | 44.29 | 151 | 152 | ## **7B–14B Models** 153 | | Model | HAcc | SAcc | Correctness | 154 | | ----------------------------- | ----- | ----- | ----------------------- | 155 | | [Qwen3-14B](https://huggingface.co/Qwen/Qwen3-14B) | 50.71 | 67.06 | 64.29 | 156 | | [DeepSeek-R1-Distill-Qwen-14B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B) | 39.28 | 60.55 | 50.95 | 157 | | [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) | 37.86 | 57.34 | 66.43 | 158 | | [DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 26.43 | 44.96 | 48.57 | 159 | | [DeepSeek-R1-Distill-Llama-8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) | 22.14 | 44.04 | 36.43 | 160 | | [Open-Reasoner-Zero-7B](https://huggingface.co/Open-Reasoner-Zero/Open-Reasoner-Zero-7B) | 13.57 | 32.26 | 51.90 | 161 | | [Qwen2.5-Math-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Math-7B-Instruct) | 9.05 | 25.60 | 37.14 | 162 | 163 | 164 | ## **≥ 32B Models** 165 | | Model | HAcc | SAcc | Correctness | 166 | | ----------------------------- | ----- | ----- | ----------------------- | 167 | | [Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) | 43.81 | 62.82 | 70.00 | 168 | | [DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | 42.62 | 60.91 | 57.62 | 169 | | [DeepSeek-R1-Distill-Llama-70B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B) | 41.43 | 61.07 | 54.05 | 170 | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | 40.24 | 59.99 | 68.81 | 171 | | [OlympicCoder-32B](https://huggingface.co/open-r1/OlympicCoder-32B) | 35.95 | 57.97 | 54.52 | 172 | | [s1-32B](https://huggingface.co/simplescaling/s1-32B) | 20.95 | 41.78 | 60.95 | 173 | | [Open-Reasoner-Zero-32B](https://huggingface.co/Open-Reasoner-Zero/Open-Reasoner-Zero-32B) | 15.47 | 35.52 | 67.62 | 174 | 175 | 176 | 177 | 178 | # 🌻Acknowledgements 179 | 180 | MathIF is inspired by prior work on [IFEval](https://huggingface.co/datasets/google/IFEval) and [ComplexBench](https://github.com/thu-coai/ComplexBench), and leverages [vLLM](https://github.com/vllm-project/vllm) for efficient inference. 
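As a reference for the decoding setup mentioned under Features (nucleus sampling with temperature 1.0, top-p 0.95, and up to 16k generated tokens), the snippet below is a minimal standalone vLLM sketch. It is not the project's `code/vllm_core.py` / `vllm_if.sh` pipeline; the model name, prompt construction, and the appended constraint wording are placeholders.

```python
from vllm import LLM, SamplingParams

# Decoding configuration described in the Features section: nucleus sampling with
# temperature 1.0, top-p 0.95, and a generation budget of 16k tokens.
sampling_params = SamplingParams(temperature=1.0, top_p=0.95, max_tokens=16384)

# Placeholder checkpoint; any evaluated model can be substituted here.
llm = LLM(model="Qwen/Qwen3-4B")

# A math question with a verifiable constraint appended to the prompt (illustrative).
prompt = (
    "Natalia sold clips to 48 of her friends in April, and then she sold half "
    "as many clips in May. How many clips did Natalia sell altogether? "
    "Answer with less than 200 words."
)
outputs = llm.generate([prompt], sampling_params)
print(outputs[0].outputs[0].text)
```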
181 | 182 | # 📖Citation 183 | 184 | ``` 185 | @article{fu2025scaling, 186 | title={Scaling Reasoning, Losing Control: Evaluating Instruction Following in Large Reasoning Models}, 187 | author={Fu, Tingchen and Gu, Jiawei and Li, Yafu and Qu, Xiaoye and Cheng, Yu}, 188 | journal={arXiv preprint arXiv:2505.14810}, 189 | year={2025} 190 | } 191 | ``` 192 | 193 | # 📬Contact 194 | 195 | For questions, feedback, or collaboration inquiries, please contact: 196 | - **Tingchen Fu**: lucas.futingchen@gmail.com 197 | - **Yafu Li**: yafuly@gmail.com 198 | -------------------------------------------------------------------------------- /code/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TingchenFu/MathIF/e09c04f4ab40cec0f65105f7b78fcc47a59cf8d2/code/__init__.py -------------------------------------------------------------------------------- /code/constraint_checker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library of Constraints.""" 16 | 17 | import collections 18 | import json 19 | import logging 20 | import random 21 | import re 22 | import string 23 | from typing import Dict, Optional, Sequence, Union 24 | 25 | import langdetect 26 | import nltk 27 | 28 | from constraint_util import LANGUAGE_CODES, generate_keywords, count_words 29 | 30 | 31 | 32 | 33 | 34 | _LANGUAGES = LANGUAGE_CODES 35 | 36 | # The relational operation for comparison. 37 | _COMPARISON_RELATION = ("less than", "at least") 38 | 39 | # The maximum number of sentences. 40 | _MAX_NUM_SENTENCES = 20 41 | 42 | # The number of placeholders. 43 | _NUM_PLACEHOLDERS = 4 44 | 45 | # The number of bullet lists. 46 | _NUM_BULLETS = 5 47 | 48 | # The options of constrained response. 49 | _CONSTRAINED_RESPONSE_OPTIONS = ( 50 | "My answer is yes.", 51 | "My answer is no.", 52 | "My answer is maybe.", 53 | ) 54 | 55 | # The options of starter keywords. 56 | _STARTER_OPTIONS = ( 57 | "I would say", 58 | "My answer is", 59 | "I believe", 60 | "In my opinion", 61 | "I think", 62 | "I reckon", 63 | "I feel", 64 | "From my perspective", 65 | "As I see it", 66 | "According to me", 67 | "As far as I'm concerned", 68 | "To my understanding", 69 | "In my view", 70 | "My take on it is", 71 | "As per my perception", 72 | ) 73 | 74 | # The options of ending keywords. 75 | # TODO(jeffreyzhou) add more ending options 76 | _ENDING_OPTIONS = ("Any other questions?", "Is there anything else I can help with?") 77 | 78 | # The number of highlighted sections. 79 | _NUM_HIGHLIGHTED_SECTIONS = 4 80 | 81 | # The section splitter. 82 | _SECTION_SPLITER = ("Section", "SECTION") 83 | 84 | # The number of sections. 85 | _NUM_SECTIONS = 5 86 | 87 | # The number of paragraphs. 88 | _NUM_PARAGRAPHS = 5 89 | 90 | # The postscript marker. 
91 | _POSTSCRIPT_MARKER = ("P.S.", "P.P.S") 92 | 93 | # The number of keywords. 94 | _NUM_KEYWORDS = 2 95 | 96 | # The occurrences of a single keyword. 97 | _KEYWORD_FREQUENCY = 3 98 | 99 | # The occurrences of a single letter. 100 | _LETTER_FREQUENCY = 10 101 | 102 | # The occurrences of words with all capital letters. 103 | _ALL_CAPITAL_WORD_FREQUENCY = 20 104 | 105 | # The number of words in the response. 106 | # LEVEL_1_NUM_WORDS_LOWER_LIMIT = 100 107 | # LEVEL_1_NUM_WORDS_UPPER_LIMIT = 500 108 | 109 | NUM_WORDS_LOWER_LIMIT = 128 110 | NUM_WORDS_UPPER_LIMIT = 1024 111 | 112 | 113 | 114 | class Constraint: 115 | """An Constraint template.""" 116 | 117 | def __init__(self, constraint_id=0): 118 | self.id = constraint_id 119 | 120 | def build_description(self, **kwargs): 121 | raise NotImplementedError("`build_description` not implemented.") 122 | 123 | def get_constraint_args(self): 124 | raise NotImplementedError("`get_constraint_args` not implemented.") 125 | 126 | def get_constraint_args_keys(self): 127 | raise NotImplementedError("`get_constraint_args_keys` not implemented.") 128 | 129 | def check_following(self, value): 130 | raise NotImplementedError("`check_following` not implemented.") 131 | 132 | 133 | class ResponseLanguageChecker(Constraint): 134 | """Check the language of the entire response.""" 135 | 136 | def build_description(self, *, language=None): 137 | """Build the Constraint description. 138 | 139 | Args: 140 | language: A string representing the expected language of the response. The 141 | language has to comply to the 97 types defined in 142 | `langid.py` (https://pypi.org/project/langid/1.1.5/), which follows 143 | ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes); 144 | for example, `en` for English, `zh` for Chinese, `fr` for French. 145 | 146 | Returns: 147 | A string representing the Constraint description. 148 | """ 149 | self._language = language 150 | if self._language is None: 151 | self._language = random.choice(list(_LANGUAGES.keys())) 152 | 153 | self._description_pattern = ( 154 | "Your answer should be in {language} language, no other language is allowed. " 155 | ) 156 | return self._description_pattern.format(language=_LANGUAGES[self._language]) 157 | 158 | def get_constraint_args(self): 159 | """Returns the keyword args of `build_description`.""" 160 | return {"language": self._language} 161 | 162 | 163 | def check_following(self, value): 164 | """Check if the language of the entire response follows the Constraint. 165 | 166 | Args: 167 | value: A string representing the response. 168 | 169 | Returns: 170 | True if the language of `value` follows Constraint; otherwise False. 171 | """ 172 | assert isinstance(value, str) 173 | 174 | try: 175 | return langdetect.detect(value) == self._language 176 | except langdetect.LangDetectException as e: 177 | # Count as Constraint is followed. 178 | logging.error( 179 | "Unable to detect language for text %s due to %s", value, e 180 | ) # refex: disable=pytotw.037 181 | return True 182 | 183 | 184 | # class NumberOfSentences(Constraint): 185 | # """Check the number of sentences.""" 186 | 187 | # def build_description(self, *, num_sentences=None, relation=None): 188 | # """Build the Constraint description. 189 | 190 | # Args: 191 | # num_sentences: An integer specifying the number of sentences as a 192 | # threshold. 193 | # relation: A string in (`less than`, `at least`), defining the relational 194 | # operator for comparison. 
195 | # Two relational comparisons are supported for now: 196 | # if 'less than', the actual number of sentences < the threshold; 197 | # if 'at least', the actual number of sentences >= the threshold. 198 | 199 | # Returns: 200 | # A string representing the Constraint description. 201 | # """ 202 | # # The number of sentences as a threshold for comparison. 203 | # self._num_sentences_threshold = num_sentences 204 | # if self._num_sentences_threshold is None or self._num_sentences_threshold < 0: 205 | # self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES) 206 | 207 | # if relation is None: 208 | # self._comparison_relation = random.choice(_COMPARISON_RELATION) 209 | # elif relation not in _COMPARISON_RELATION: 210 | # raise ValueError( 211 | # "The supported relation for comparison must be in " 212 | # f"{_COMPARISON_RELATION}, but {relation} is given." 213 | # ) 214 | # else: 215 | # self._comparison_relation = relation 216 | 217 | # self._description_pattern = ( 218 | # "Your response should contain {relation} {num_sentences} sentences." 219 | # ) 220 | # return self._description_pattern.format( 221 | # relation=self._comparison_relation, 222 | # num_sentences=self._num_sentences_threshold, 223 | # ) 224 | 225 | # def get_constraint_args(self): 226 | # """Returns the keyword args of `build_description`.""" 227 | # return { 228 | # "num_sentences": self._num_sentences_threshold, 229 | # "relation": self._comparison_relation, 230 | # } 231 | 232 | # def get_constraint_args_keys(self): 233 | # """Returns the args keys of `build_description`.""" 234 | # return ["num_sentences", "relation"] 235 | 236 | # def check_following(self, value): 237 | # """Check if the number of sentences follows the Constraint. 238 | 239 | # Args: 240 | # value: A string representing the response. 241 | 242 | # Returns: 243 | # True if the response follows the Constraint. 244 | 245 | # Raise: 246 | # ValueError if the string in `Constraint_args` is not in 247 | # [`less_than`, `at_least`]. 248 | # """ 249 | # num_sentences = Constraints_util.count_sentences(value) 250 | # if self._comparison_relation == _COMPARISON_RELATION[0]: 251 | # return num_sentences < self._num_sentences_threshold 252 | # elif self._comparison_relation == _COMPARISON_RELATION[1]: 253 | # return num_sentences >= self._num_sentences_threshold 254 | 255 | 256 | # class PlaceholderChecker(Constraint): 257 | # """Check the placeholders in template writing.""" 258 | 259 | # def build_description(self, *, num_placeholders=None): 260 | # """Build the Constraint description. 261 | 262 | # Args: 263 | # num_placeholders: An integer denoting the minimum number of 264 | # placeholders required in the response. 265 | 266 | # Returns: 267 | # A string representing the Constraint description. 268 | # """ 269 | # self._num_placeholders = num_placeholders 270 | # if self._num_placeholders is None or self._num_placeholders < 0: 271 | # self._num_placeholders = random.randint(1, _NUM_PLACEHOLDERS) 272 | # self._description_pattern = ( 273 | # "The response must contain at least {num_placeholders} placeholders " 274 | # + "represented by square brackets, such as [address]." 
275 | # ) 276 | # return self._description_pattern.format(num_placeholders=self._num_placeholders) 277 | 278 | # def get_constraint_args(self): 279 | # """Returns the keyword args of `build_description`.""" 280 | # return {"num_placeholders": self._num_placeholders} 281 | 282 | # def get_constraint_args_keys(self): 283 | # """Returns the args keys of `build_description`.""" 284 | # return ["num_placeholders"] 285 | 286 | # def check_following(self, value): 287 | # """Check if the number of placeholders follows the Constraint. 288 | 289 | # Args: 290 | # value: A string representing the response. 291 | 292 | # Returns: 293 | # True if the actual number of placeholders in the response is greater than 294 | # or equal to `num_placeholders`; otherwise, False. 295 | # """ 296 | # placeholders = re.findall(r"\[.*?\]", value) 297 | # num_placeholders = len(placeholders) 298 | # return num_placeholders >= self._num_placeholders 299 | 300 | 301 | class BulletListChecker(Constraint): 302 | """Checks the bullet list in the prompt.""" 303 | 304 | def build_description(self, *, num_bullets=None): 305 | """Build the Constraint description. 306 | 307 | Args: 308 | num_bullets: An integer specifying the exact number of bullet lists 309 | that is required to appear in the response. 310 | 311 | Returns: 312 | A string representing the Constraint description. 313 | """ 314 | self._num_bullets = num_bullets 315 | if self._num_bullets is None or self._num_bullets < 0: 316 | self._num_bullets = random.randint(1, _NUM_BULLETS) 317 | self._description_pattern = ( 318 | "Your answer must contain exactly {num_bullets} bullet points. " 319 | + "Use the markdown bullet points such as:\n" 320 | + "* This is point 1. \n" 321 | + "* This is point 2" 322 | ) 323 | return self._description_pattern.format(num_bullets=self._num_bullets) 324 | 325 | def get_constraint_args(self): 326 | """Returns the keyword args of `build_description`.""" 327 | return {"num_bullets": self._num_bullets} 328 | 329 | def get_constraint_args_keys(self): 330 | """Returns the args keys of `build_description`.""" 331 | return ["num_bullets"] 332 | 333 | def check_following(self, value): 334 | r"""Check if the number of bullet lists meets the requirement. 335 | 336 | Args: 337 | value: A string representing the response. The response is expected to 338 | contain some bullet lists that start with `\*`. 339 | 340 | Returns: 341 | True if the actual number of bullet lists in the response meets the 342 | requirement. 343 | """ 344 | bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE) 345 | bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE) 346 | num_bullet_lists = len(bullet_lists) + len(bullet_lists_2) 347 | return num_bullet_lists == self._num_bullets 348 | 349 | 350 | # class ConstrainedResponseChecker(Constraint): 351 | # """Checks the constrained response.""" 352 | 353 | # def build_description(self): 354 | # """Build the Constraint description.""" 355 | # # A sequence of string(s) representing the options of the expected response. 
356 | # self._constrained_responses = _CONSTRAINED_RESPONSE_OPTIONS 357 | # self._description_pattern = ( 358 | # "Answer with one of the following options: {response_options}" 359 | # ) 360 | # return self._description_pattern.format( 361 | # response_options=self._constrained_responses 362 | # ) 363 | 364 | # def get_constraint_args(self): 365 | # """Returns the keyword args of `build_description`.""" 366 | # return None 367 | 368 | # def get_constraint_args_keys(self): 369 | # """Returns the args keys of `build_description`.""" 370 | # return [] 371 | 372 | # def check_following(self, value): 373 | # """Checks if the response matches the constrained options. 374 | 375 | # Args: 376 | # value: A string representing the response. 377 | 378 | # Returns: 379 | # True if the actual response contains one of the options in the constrained 380 | # responses; otherwise False. 381 | # """ 382 | # value = value.strip() 383 | # for constrained_response in self._constrained_responses: 384 | # if constrained_response in value: 385 | # return True 386 | # return False 387 | 388 | 389 | # class ConstrainedStartChecker(Constraint): 390 | # """Checks the response start.""" 391 | 392 | # def build_description(self, *, starter=None): 393 | # """Build the Constraint description. 394 | 395 | # Args: 396 | # starter: A string representing the keyword that the response should start 397 | # with. 398 | 399 | # Returns: 400 | # A string representing the Constraint description. 401 | # """ 402 | # self._starter = starter.strip() if isinstance(starter, str) else starter 403 | # if self._starter is None: 404 | # self._starter = random.choice(_STARTER_OPTIONS) 405 | # self._description_pattern = ( 406 | # "During the conversation, when it is your turn, " 407 | # + "please always start with {starter}" 408 | # ) 409 | # return self._description_pattern.format(starter=self._starter) 410 | 411 | # def get_constraint_args(self): 412 | # """Returns the keyword args of `build_description`.""" 413 | # return {"starter": self._starter} 414 | 415 | # def get_constraint_args_keys(self): 416 | # """Returns the args keys of `build_description`.""" 417 | # return ["starter"] 418 | 419 | # def check_following(self, value): 420 | # """Checks if the response starts with the constrained keyword or phrase. 421 | 422 | # Args: 423 | # value: A string representing the response. 424 | 425 | # Returns: 426 | # True if the response starts with the given phrase or keyword that is 427 | # contained in `Constraint_args`; otherwise, False. 428 | # """ 429 | # response_pattern = r"^\s*" + self._starter + r".*$" 430 | # response_with_constrained_start = re.search( 431 | # response_pattern, value, flags=re.MULTILINE 432 | # ) 433 | # return True if response_with_constrained_start else False 434 | 435 | 436 | class HighlightSectionChecker(Constraint): 437 | """Checks the highlighted section.""" 438 | 439 | def build_description(self, *, num_highlights=None): 440 | """Build the Constraint description. 441 | 442 | Args: 443 | num_highlights: An integer specifying the minimum number of highlighted 444 | sections. 445 | 446 | Returns: 447 | A string representing the Constraint description. 448 | """ 449 | self._num_highlights = num_highlights 450 | if self._num_highlights is None or self._num_highlights < 0: 451 | self._num_highlights = random.randint(1, _NUM_HIGHLIGHTED_SECTIONS) 452 | 453 | self._description_pattern = ( 454 | "Highlight at least {num_highlights} sections in your answer with " 455 | + "markdown, i.e. *highlighted section*." 
456 | ) 457 | 458 | return self._description_pattern.format(num_highlights=self._num_highlights) 459 | 460 | def get_constraint_args(self): 461 | """Returns the keyword args of `build_description`.""" 462 | return {"num_highlights": self._num_highlights} 463 | 464 | def get_constraint_args_keys(self): 465 | """Returns the args keys of `build_description`.""" 466 | return ["num_highlights"] 467 | 468 | def check_following(self, value): 469 | """Checks if the number of highlighted sections meets the requirement. 470 | 471 | Args: 472 | value: a string representing the response. The response is expected to 473 | contain highlighted sections in the format of *highlighted*. 474 | 475 | Returns: 476 | True if the actual number of highlighted sections in the format of 477 | *highlighted sections* meets the minimum requirement; otherwise False. 478 | """ 479 | num_highlights = 0 480 | highlights = re.findall(r"\*[^\n\*]*\*", value) 481 | double_highlights = re.findall(r"\*\*[^\n\*]*\*\*", value) 482 | for highlight in highlights: 483 | if highlight.strip("*").strip(): 484 | num_highlights += 1 485 | for highlight in double_highlights: 486 | if highlight.removeprefix("**").removesuffix("**").strip(): 487 | num_highlights += 1 488 | 489 | return num_highlights >= self._num_highlights 490 | 491 | 492 | class SectionChecker(Constraint): 493 | """Checks the sections.""" 494 | 495 | def build_description(self, *, section_spliter=None, num_sections=None): 496 | """Build the Constraint description. 497 | 498 | Args: 499 | section_spliter: A string represents the section spliter keyword that 500 | marks a new section, i.e., `Section` or `SECTION`. 501 | num_sections: An integer specifying the number of sections. 502 | 503 | Returns: 504 | A string representing the Constraint description. 505 | """ 506 | self._section_spliter = ( 507 | section_spliter.strip() 508 | if isinstance(section_spliter, str) 509 | else section_spliter 510 | ) 511 | if self._section_spliter is None: 512 | self._section_spliter = random.choice(_SECTION_SPLITER) 513 | 514 | self._num_sections = num_sections 515 | if self._num_sections is None or self._num_sections < 0: 516 | self._num_sections = random.randint(1, _NUM_SECTIONS) 517 | 518 | self._description_pattern = ( 519 | "Your response must have {num_sections} sections. Mark the beginning " 520 | + "of each section with {section_spliter} X, such as:\n" 521 | + "{section_spliter} 1\n" 522 | + "[content of section 1]\n" 523 | + "{section_spliter} 2\n" 524 | + "[content of section 2]" 525 | ) 526 | 527 | return self._description_pattern.format( 528 | num_sections=self._num_sections, section_spliter=self._section_spliter 529 | ) 530 | 531 | def get_constraint_args(self): 532 | """Returns the keyword args of `build_description`.""" 533 | return { 534 | "section_spliter": self._section_spliter, 535 | "num_sections": self._num_sections, 536 | } 537 | 538 | def get_constraint_args_keys(self): 539 | """Returns the args keys of `build_description`.""" 540 | return ["section_spliter", "num_sections"] 541 | 542 | def check_following(self, value): 543 | """Checks the response contains multiple sections. 544 | 545 | Args: 546 | value: A string representing the response. The response is expected 547 | to contain multiple sections (number of sections is greater than 1). 548 | A new section starts with `Section 1`, where the number denotes the 549 | section index. 
550 | 551 | Returns: 552 | True if the number of sections in the response is greater than or equal to 553 | the minimum number of sections; otherwise, False. 554 | """ 555 | section_splitter_patten = r"\s?" + self._section_spliter + r"\s?\d+\s?" 556 | sections = re.split(section_splitter_patten, value) 557 | num_sections = len(sections) - 1 558 | return num_sections >= self._num_sections 559 | 560 | 561 | # class ParagraphChecker(Constraint): 562 | # """Checks the paragraphs.""" 563 | 564 | # def build_description(self, *, num_paragraphs=None): 565 | # """Build the Constraint description. 566 | 567 | # Args: 568 | # num_paragraphs: An integer specifying the number of paragraphs. 569 | 570 | # Returns: 571 | # A string representing the Constraint description. 572 | # """ 573 | # self._num_paragraphs = num_paragraphs 574 | # if self._num_paragraphs is None or self._num_paragraphs < 0: 575 | # self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) 576 | 577 | # self._description_pattern = ( 578 | # "There should be {num_paragraphs} paragraphs. " 579 | # + "Paragraphs are separated with the markdown divider: ***" 580 | # ) 581 | 582 | # return self._description_pattern.format(num_paragraphs=self._num_paragraphs) 583 | 584 | # def get_constraint_args(self): 585 | # """Returns the keyword args of `build_description`.""" 586 | # return {"num_paragraphs": self._num_paragraphs} 587 | 588 | # def get_constraint_args_keys(self): 589 | # """Returns the args keys of `build_description`.""" 590 | # return ["num_paragraphs"] 591 | 592 | # def check_following(self, value): 593 | # """Checks the response contains required number of paragraphs. 594 | 595 | # Args: 596 | # value: A string representing the response. The response may contain 597 | # paragraphs that are separated by the markdown divider: `***`. 598 | 599 | # Returns: 600 | # True if the actual number of paragraphs is the same as required; 601 | # otherwise, False. 602 | # """ 603 | # paragraphs = re.split(r"\s?\*\*\*\s?", value) 604 | # num_paragraphs = len(paragraphs) 605 | 606 | # for index, paragraph in enumerate(paragraphs): 607 | # if not paragraph.strip(): 608 | # if index == 0 or index == len(paragraphs) - 1: 609 | # num_paragraphs -= 1 610 | # else: 611 | # return False 612 | 613 | # return num_paragraphs == self._num_paragraphs 614 | 615 | 616 | # class PostscriptChecker(Constraint): 617 | # """Checks the postscript.""" 618 | 619 | # def build_description(self, *, postscript_marker=None): 620 | # """Build the Constraint description. 621 | 622 | # Args: 623 | # postscript_marker: A string containing the keyword that marks the start 624 | # of the postscript section. 625 | 626 | # Returns: 627 | # A string representing the Constraint description. 
628 | # """ 629 | # self._postscript_marker = ( 630 | # postscript_marker.strip() 631 | # if isinstance(postscript_marker, str) 632 | # else postscript_marker 633 | # ) 634 | # if self._postscript_marker is None: 635 | # self._postscript_marker = random.choice(_POSTSCRIPT_MARKER) 636 | 637 | # self._description_pattern = ( 638 | # "At the end of your response, please explicitly add a postscript " 639 | # + "starting with {postscript}" 640 | # ) 641 | 642 | # return self._description_pattern.format(postscript=self._postscript_marker) 643 | 644 | # def get_constraint_args(self): 645 | # """Returns the keyword args of `build_description`.""" 646 | # return {"postscript_marker": self._postscript_marker} 647 | 648 | # def get_constraint_args_keys(self): 649 | # """Returns the args keys of `build_description`.""" 650 | # return ["postscript_marker"] 651 | 652 | # def check_following(self, value): 653 | # """Checks if the response follows the postscript format. 654 | 655 | # Args: 656 | # value: a string representing the response. The response is expected to 657 | # contain a postscript section. 658 | 659 | # Returns: 660 | # True if the response contains a postscript section starting with 661 | # the keyword containing in the `Constraint_args`; otherwise False. 662 | # """ 663 | # value = value.lower() 664 | # if self._postscript_marker == "P.P.S": 665 | # postscript_pattern = r"\s*p\.\s?p\.\s?s.*$" 666 | # elif self._postscript_marker == "P.S.": 667 | # postscript_pattern = r"\s*p\.\s?s\..*$" 668 | # else: 669 | # postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$" 670 | # postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE) 671 | # return True if postscript else False 672 | 673 | 674 | # class RephraseChecker(Constraint): 675 | # """Checks the rephrase.""" 676 | 677 | # def build_description(self, *, original_message): 678 | # """Build the Constraint description. 679 | 680 | # Args: 681 | # original_message: A string representing the original message. The 682 | # rephrased response should only change its words/sentences in between 683 | # its two asterisks, for example, *change me*. Both original and rephrased 684 | # messages should contain the changes in the form of *change me*. 685 | 686 | # Returns: 687 | # A string representing the Constraint description. 688 | # """ 689 | # if not self.is_change(original_message): 690 | # raise ValueError( 691 | # f"Message {original_message} does not contain changes " 692 | # "in the form of *change me*." 693 | # ) 694 | 695 | # self._reference_without_change = original_message 696 | # self._description = ( 697 | # "Rephrasing: Your rephrased response should only" 698 | # + "change the words/sentences in between two asterisks" 699 | # + "such as *change me*." 700 | # ) 701 | # return self._description 702 | 703 | # def get_constraint_args(self): 704 | # """Returns the keyword args of `build_description`.""" 705 | # return {"original_message": self._reference_without_change} 706 | 707 | # def get_constraint_args_keys(self): 708 | # """Returns the args keys of `build_description`.""" 709 | # return ["original_message"] 710 | 711 | # def check_following(self, value): 712 | # r"""Checks if the rephrasing follows the Constraint. 713 | 714 | # Args: 715 | # value: A string representing the response, which is expected to rephras 716 | # the string of `Constraint_args`. 
717 | 718 | # Returns: 719 | # True if `value` and `Constraint_args` only differ by the words/sentences 720 | # in between two asterisks such as *change me*; otherwise, False. 721 | # """ 722 | 723 | # if not self.is_change(value): 724 | # raise ValueError( 725 | # f"value {value} does not contain changes in the form of *change me*." 726 | # ) 727 | 728 | # response_without_changes = self.strip_changes(value) 729 | # reference_without_changes = self.strip_changes(self._reference_without_change) 730 | 731 | # return response_without_changes == reference_without_changes 732 | 733 | # def is_change(self, response): 734 | # """Check if there is change in the response in the form of *change me*.""" 735 | # return re.search(r"\*.*\*", response) 736 | 737 | # def strip_changes(self, response): 738 | # """Strips off the changes.""" 739 | # return re.sub(r"\*.*\*", "", response) 740 | 741 | 742 | class KeywordChecker(Constraint): 743 | """Check the exisitence of certain keywords.""" 744 | 745 | def build_description(self, *, keywords=None): 746 | """Build the Constraint description. 747 | 748 | Args: 749 | keywords: A sequence of strings representing the keywords that are 750 | expected in the response. 751 | 752 | Returns: 753 | A string representing the Constraint description. 754 | """ 755 | 756 | if not keywords: 757 | self._keywords = generate_keywords( 758 | num_keywords=_NUM_KEYWORDS 759 | ) 760 | else: 761 | self._keywords = keywords 762 | self._keywords = sorted(self._keywords) 763 | 764 | self._description_pattern = "Include keywords \"{keywords}\" in the response." 765 | 766 | return self._description_pattern.format(keywords=self._keywords) 767 | 768 | def get_constraint_args(self): 769 | """Returns the keyword args of `build_description`.""" 770 | return {"keywords": self._keywords} 771 | 772 | def get_constraint_args_keys(self): 773 | """Returns the args keys of `build_description`.""" 774 | return ["keywords"] 775 | 776 | def check_following(self, value): 777 | """Check if the response contain the expected keywords.""" 778 | for keyword in self._keywords: 779 | if not re.search(keyword, value, flags=re.IGNORECASE): 780 | return False 781 | return True 782 | 783 | 784 | class KeywordFrequencyChecker(Constraint): 785 | """Check the keyword frequency.""" 786 | 787 | def build_description(self, *, keyword=None, frequency=None, relation=None): 788 | """Build the Constraint description. 789 | 790 | Args: 791 | keyword: A string representing a keyword that is expected in the response. 792 | frequency: An integer specifying the number of times `keyword` is expected 793 | to appear in the response. 794 | relation: A string in (`less than`, `at least`), defining the relational 795 | operator for comparison. 796 | Two relational comparisons are supported for now: 797 | if 'less than', the actual number of occurrences < frequency; 798 | if 'at least', the actual number of occurrences >= frequency. 799 | 800 | Returns: 801 | A string representing the Constraint description. 
802 | """ 803 | if not keyword: 804 | self._keyword = generate_keywords(num_keywords=1)[0] 805 | else: 806 | self._keyword = keyword.strip() 807 | 808 | self._frequency = frequency 809 | if self._frequency is None or self._frequency < 0: 810 | self._frequency = random.randint(1, _KEYWORD_FREQUENCY) 811 | 812 | if relation is None: 813 | self._comparison_relation = random.choice(_COMPARISON_RELATION) 814 | elif relation not in _COMPARISON_RELATION: 815 | raise ValueError( 816 | "The supported relation for comparison must be in " 817 | f"{_COMPARISON_RELATION}, but {relation} is given." 818 | ) 819 | else: 820 | self._comparison_relation = relation 821 | 822 | self._description_pattern = ( 823 | "In your response, the word \"{keyword}\" should appear {relation} " 824 | + "{frequency} times." 825 | ) 826 | 827 | return self._description_pattern.format( 828 | keyword=self._keyword, 829 | relation=self._comparison_relation, 830 | frequency=self._frequency, 831 | ) 832 | 833 | def get_constraint_args(self): 834 | """Returns the keyword args of `build_description`.""" 835 | return { 836 | "keyword": self._keyword, 837 | "frequency": self._frequency, 838 | "relation": self._comparison_relation, 839 | } 840 | 841 | def get_constraint_args_keys(self): 842 | """Returns the args keys of `build_description`.""" 843 | return ["keyword", "frequency", "relation"] 844 | 845 | def check_following(self, value): 846 | """Checks if the response contain the keyword with required frequency.""" 847 | actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) 848 | 849 | if self._comparison_relation == _COMPARISON_RELATION[0]: 850 | return actual_occurrences < self._frequency 851 | elif self._comparison_relation == _COMPARISON_RELATION[1]: 852 | return actual_occurrences >= self._frequency 853 | 854 | 855 | class NumberOfWords(Constraint): 856 | """Checks the number of words.""" 857 | 858 | def build_description(self, *, num_words=None, relation=None): 859 | """Build the Constraint description. 860 | 861 | Args: 862 | num_words: An integer specifying the number of words contained in the 863 | response. 864 | relation: A string in (`less than`, `at least`), defining the relational 865 | operator for comparison. 866 | Two relational comparisons are supported for now: 867 | if 'less than', the actual number of words < num_words; 868 | if 'at least', the actual number of words >= num_words. 869 | 870 | Returns: 871 | A string representing the Constraint description. 872 | """ 873 | 874 | self._num_words = num_words 875 | if self._num_words is None or self._num_words < 0: 876 | self._num_words = random.randint( 877 | NUM_WORDS_LOWER_LIMIT, NUM_WORDS_UPPER_LIMIT 878 | ) 879 | 880 | if relation is None: 881 | self._comparison_relation = random.choice(_COMPARISON_RELATION) 882 | elif relation not in _COMPARISON_RELATION: 883 | raise ValueError( 884 | "The supported relation for comparison must be in " 885 | f"{_COMPARISON_RELATION}, but {relation} is given." 886 | ) 887 | else: 888 | self._comparison_relation = relation 889 | 890 | self._description_pattern = "Answer with {relation} {num_words} words." 
891 | 892 | return self._description_pattern.format( 893 | relation=self._comparison_relation, num_words=self._num_words 894 | ) 895 | 896 | def get_constraint_args(self): 897 | """Returns the keyword args of `build_description`.""" 898 | return {"num_words": self._num_words, "relation": self._comparison_relation} 899 | 900 | def get_constraint_args_keys(self): 901 | """Returns the args keys of `build_description`.""" 902 | return ["num_words", "relation"] 903 | 904 | def check_following(self, value): 905 | """Checks if the response contains the expected number of words.""" 906 | num_words = count_words(value) 907 | 908 | if self._comparison_relation == _COMPARISON_RELATION[0]: 909 | return num_words < self._num_words 910 | elif self._comparison_relation == _COMPARISON_RELATION[1]: 911 | return num_words >= self._num_words 912 | 913 | 914 | # class JsonFormat(Constraint): 915 | # """Check the Json format.""" 916 | 917 | # def build_description(self): 918 | # self._description_pattern = ( 919 | # "Entire output should be wrapped in JSON format. You can use markdown" 920 | # " ticks such as ```." 921 | # ) 922 | # return self._description_pattern 923 | 924 | # def get_constraint_args(self): 925 | # """Returns the keyword args of `build_description`.""" 926 | # return None 927 | 928 | # def get_constraint_args_keys(self): 929 | # """Returns the args keys of `build_description`.""" 930 | # return [] 931 | 932 | # def check_following(self, value): 933 | # value = ( 934 | # value.strip() 935 | # .removeprefix("```json") 936 | # .removeprefix("```Json") 937 | # .removeprefix("```JSON") 938 | # .removeprefix("```") 939 | # .removesuffix("```") 940 | # .strip() 941 | # ) 942 | # try: 943 | # json.loads(value) 944 | # except ValueError: 945 | # return False 946 | # return True 947 | 948 | 949 | # class ParagraphFirstWordCheck(Constraint): 950 | # """Check the paragraph and the first word of the nth paragraph.""" 951 | 952 | # def build_description( 953 | # self, num_paragraphs=None, nth_paragraph=None, first_word=None 954 | # ): 955 | # r"""Build the Constraint description. 956 | 957 | # Args: 958 | # num_paragraphs: An integer indicating the number of paragraphs expected 959 | # in the response. A paragraph is a subset of the string that is 960 | # expected to be separated by '\n\n'. 961 | # nth_paragraph: An integer indicating the paragraph number that we look at. 962 | # Note that n starts from 1. 963 | # first_word: A string that represent the first word of the bth paragraph. 964 | 965 | # Returns: 966 | # A string representing the Constraint description. 967 | # """ 968 | # self._num_paragraphs = num_paragraphs 969 | # if self._num_paragraphs is None or self._num_paragraphs < 0: 970 | # self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) 971 | 972 | # self._nth_paragraph = nth_paragraph 973 | # if ( 974 | # self._nth_paragraph is None 975 | # or self._nth_paragraph <= 0 976 | # or self._nth_paragraph > self._num_paragraphs 977 | # ): 978 | # self._nth_paragraph = random.randint(1, self._num_paragraphs + 1) 979 | 980 | # self._first_word = first_word 981 | # if self._first_word is None: 982 | # self._first_word = generate_keywords(num_keywords=1)[0] 983 | # self._first_word = self._first_word.lower() 984 | 985 | # self._description_pattern = ( 986 | # "There should be {num_paragraphs} paragraphs. " 987 | # + "Paragraphs and only paragraphs are separated with each other by two " 988 | # + "new lines as if it was '\\n\\n' in python. 
" 989 | # + "Paragraph {nth_paragraph} must start with word {first_word}." 990 | # ) 991 | 992 | # return self._description_pattern.format( 993 | # num_paragraphs=self._num_paragraphs, 994 | # nth_paragraph=self._nth_paragraph, 995 | # first_word=self._first_word, 996 | # ) 997 | 998 | # def get_constraint_args(self): 999 | # """Returns the keyword args of `build_description`.""" 1000 | # return { 1001 | # "num_paragraphs": self._num_paragraphs, 1002 | # "nth_paragraph": self._nth_paragraph, 1003 | # "first_word": self._first_word, 1004 | # } 1005 | 1006 | # def get_constraint_args_keys(self): 1007 | # """Returns the args keys of `build_description`.""" 1008 | # return ["num_paragraphs", "nth_paragraph", "first_word"] 1009 | 1010 | # def check_following(self, value): 1011 | # """Checks for required number of paragraphs and correct first word. 1012 | 1013 | # Args: 1014 | # value: a string representing the response. The response may contain 1015 | # paragraphs that are separated by two new lines and the first word of 1016 | # the nth paragraph will have to match a specified word. 1017 | 1018 | # Returns: 1019 | # True if the number of paragraphs is the same as required and the first 1020 | # word of the specified paragraph is the same as required. Otherwise, false. 1021 | # """ 1022 | 1023 | # paragraphs = re.split(r"\n\n", value) 1024 | # num_paragraphs = len(paragraphs) 1025 | 1026 | # for paragraph in paragraphs: 1027 | # if not paragraph.strip(): 1028 | # num_paragraphs -= 1 1029 | 1030 | # # check that index doesn't go out of bounds 1031 | # if self._nth_paragraph <= num_paragraphs: 1032 | # paragraph = paragraphs[self._nth_paragraph - 1].strip() 1033 | # if not paragraph: 1034 | # return False 1035 | # else: 1036 | # return False 1037 | 1038 | # first_word = "" 1039 | # punctuation = {".", ",", "?", "!", "'", '"'} 1040 | 1041 | # # get first word and remove punctuation 1042 | # word = paragraph.split()[0].strip() 1043 | # # TODO(jeffrey): make more complex? 1044 | # word = word.lstrip("'") 1045 | # word = word.lstrip('"') 1046 | 1047 | # for letter in word: 1048 | # if letter in punctuation: 1049 | # break 1050 | # first_word += letter.lower() 1051 | 1052 | # return num_paragraphs == self._num_paragraphs and first_word == self._first_word 1053 | 1054 | 1055 | # # TODO(jeffrey) add relation - at least/at most? 1056 | # class KeySentenceChecker(Constraint): 1057 | # """Check the existence of certain key sentences.""" 1058 | 1059 | # def build_description(self, key_sentences=None, num_sentences=None): 1060 | # """Build the Constraint description. 1061 | 1062 | # Args: 1063 | # key_sentences: A sequences of strings representing the key sentences that 1064 | # are expected in the response. 1065 | # num_sentences: The number of key sentences that are expected to be seen in 1066 | # the response. 1067 | 1068 | # Returns: 1069 | # A string representing the Constraint description. 1070 | # """ 1071 | 1072 | # if not key_sentences: 1073 | # # TODO(jeffrey) make a generate sentences function? 
wonderwords package 1074 | # self._key_sentences = set(["For now, this is fine."]) 1075 | # else: 1076 | # self._key_sentences = key_sentences 1077 | 1078 | # if not num_sentences: 1079 | # self._num_sentences = random.randint(1, len(self._key_sentences)) 1080 | # else: 1081 | # self._num_sentences = num_sentences 1082 | 1083 | # self._description_pattern = ( 1084 | # "Include {num_sentences} of the following sentences {key_sentences}" 1085 | # ) 1086 | 1087 | # return self._description_pattern.format( 1088 | # num_sentences=self._num_sentences, key_sentences=self._key_sentences 1089 | # ) 1090 | 1091 | # def get_constraint_args(self): 1092 | # """Returns the keyword args of `build_description`.""" 1093 | # return { 1094 | # "num_sentences": self._num_sentences, 1095 | # "key_sentences": list(self._key_sentences), 1096 | # } 1097 | 1098 | # def get_constraint_args_keys(self): 1099 | # """Returns the args keys of `build_description`.""" 1100 | # return ["num_sentences", "key_sentences"] 1101 | 1102 | # def check_following(self, value): 1103 | # """Checks if the response contains the expected key sentences.""" 1104 | # count = 0 1105 | # sentences = Constraints_util.split_into_sentences(value) 1106 | # for sentence in self._key_sentences: 1107 | # if sentence in sentences: 1108 | # count += 1 1109 | 1110 | # return count == self._num_sentences 1111 | 1112 | 1113 | class ForbiddenWords(Constraint): 1114 | """Checks that specified words are not used in response.""" 1115 | 1116 | def build_description(self, forbidden_words=None): 1117 | """Build the Constraint description. 1118 | 1119 | Args: 1120 | forbidden_words: A sequences of strings representing words that are not 1121 | allowed in the response. 1122 | 1123 | Returns: 1124 | A string representing the Constraint description. 1125 | """ 1126 | 1127 | if not forbidden_words: 1128 | self._forbidden_words = generate_keywords( 1129 | num_keywords=_NUM_KEYWORDS 1130 | ) 1131 | else: 1132 | self._forbidden_words = list(set(forbidden_words)) 1133 | self._forbidden_words = sorted(self._forbidden_words) 1134 | self._description_pattern = ( 1135 | "Do not include keywords \"{forbidden_words}\" in the response." 1136 | ) 1137 | 1138 | return self._description_pattern.format(forbidden_words=self._forbidden_words) 1139 | 1140 | def get_constraint_args(self): 1141 | """Returns the keyword args of `build_description`.""" 1142 | return {"forbidden_words": self._forbidden_words} 1143 | 1144 | def get_constraint_args_keys(self): 1145 | """Returns the args keys of `build_description`.""" 1146 | return ["forbidden_words"] 1147 | 1148 | def check_following(self, value): 1149 | """Check if the response does not contain the expected keywords.""" 1150 | for word in self._forbidden_words: 1151 | if re.search(r"\b" + word + r"\b", value, flags=re.IGNORECASE): 1152 | return False 1153 | return True 1154 | 1155 | 1156 | # class RephraseParagraph(Constraint): 1157 | # """Checks that the paragraph is rephrased.""" 1158 | 1159 | # def build_description(self, *, original_paragraph, low, high): 1160 | # """Builds the Constraint description. 1161 | 1162 | # Args: 1163 | # original_paragraph: A string presenting the original paragraph. The 1164 | # rephrases response should have betweeb low-high words in common. 1165 | # low: An integer presenting the lower bound of similar words. 1166 | # high: An integer representing the upper bound of similar words. 1167 | 1168 | # Returns: 1169 | # A string representing the Constraint description. 
1170 | # """ 1171 | # # TODO(jeffrey) make more encompassing 1172 | # self._original_paragraph = original_paragraph 1173 | # self._low = low 1174 | # self._high = high 1175 | 1176 | # self._description = ( 1177 | # "Rephrase the following paragraph: " 1178 | # + "{original_paragraph}\nYour response should have " 1179 | # + "between {low} and {high} of the same words. " 1180 | # + "Words are the same if and only if all of the " 1181 | # + "letters, ignoring cases, are the same. For " 1182 | # + "example, 'run' is the same as 'Run' but different " 1183 | # + "to 'ran'." 1184 | # ) 1185 | 1186 | # return self._description.format( 1187 | # original_paragraph=original_paragraph, low=self._low, high=self._high 1188 | # ) 1189 | 1190 | # def get_constraint_args(self): 1191 | # """Returns the keyword args of `build_description`.""" 1192 | # return { 1193 | # "original_paragraph": self._original_paragraph, 1194 | # "low": self._low, 1195 | # "high": self._high, 1196 | # } 1197 | 1198 | # def get_constraint_args_keys(self): 1199 | # """Returns the args keys of `build_description`.""" 1200 | # return ["original_paragraph", "low", "high"] 1201 | 1202 | # def check_following(self, value): 1203 | # val_words = re.findall(r"\w+", value.lower()) 1204 | # original_words = re.findall(r"\w+", self._original_paragraph.lower()) 1205 | # similar_words = 0 1206 | 1207 | # dict_val = collections.Counter(val_words) 1208 | # dict_original = collections.Counter(original_words) 1209 | 1210 | # for word in dict_original: 1211 | # similar_words += min(dict_original[word], dict_val[word]) 1212 | 1213 | # return similar_words >= self._low and similar_words <= self._high 1214 | 1215 | 1216 | # class TwoResponsesChecker(Constraint): 1217 | # """Check that two responses were given.""" 1218 | 1219 | # def build_description(self): 1220 | # """Build the Constraint description.""" 1221 | # self._description_pattern = ( 1222 | # "Give two different responses. Responses and only responses should" 1223 | # " be separated by 6 asterisk symbols: ******." 1224 | # ) 1225 | # return self._description_pattern 1226 | 1227 | # def get_constraint_args(self): 1228 | # """Returns the keyword args of `build_description`.""" 1229 | # return None 1230 | 1231 | # def get_constraint_args_keys(self): 1232 | # """Returns the args keys of `build_description`.""" 1233 | # return [] 1234 | 1235 | # def check_following(self, value): 1236 | # """Checks if the response has two different answers. 1237 | 1238 | # Args: 1239 | # value: A string representing the response. 1240 | 1241 | # Returns: 1242 | # True if two responses are detected and false otherwise. 1243 | # """ 1244 | # valid_responses = list() 1245 | # responses = value.split("******") 1246 | # for index, response in enumerate(responses): 1247 | # if not response.strip(): 1248 | # if index != 0 and index != len(responses) - 1: 1249 | # return False 1250 | # else: 1251 | # valid_responses.append(response) 1252 | # return ( 1253 | # len(valid_responses) == 2 1254 | # and valid_responses[0].strip() != valid_responses[1].strip() 1255 | # ) 1256 | 1257 | 1258 | class RepeatPromptThenAnswer(Constraint): 1259 | """Checks that Prompt is first repeated then answered.""" 1260 | 1261 | def build_description(self, *, prompt_to_repeat=None): 1262 | """Build the Constraint description. 1263 | 1264 | Args: 1265 | prompt_to_repeat: The prompt that is meant to be repeated. 1266 | 1267 | Returns: 1268 | A string representing the Constraint description. 
1269 | """ 1270 | if not prompt_to_repeat: 1271 | raise ValueError("prompt_to_repeat must be set.") 1272 | else: 1273 | self._prompt_to_repeat = prompt_to_repeat 1274 | self._description_pattern = ( 1275 | "First repeat the request word for word without change," 1276 | " then give your answer (1. do not say any words or characters" 1277 | " before repeating the request; 2. the request you need to repeat" 1278 | " does not include this sentence)" 1279 | ) 1280 | return self._description_pattern 1281 | 1282 | def get_constraint_args(self): 1283 | return {"prompt_to_repeat": self._prompt_to_repeat} 1284 | 1285 | def get_constraint_args_keys(self): 1286 | """Returns the args keys of `build_description`.""" 1287 | return ["prompt_to_repeat"] 1288 | 1289 | def check_following(self, value): 1290 | if value.strip().lower().startswith(self._prompt_to_repeat.strip().lower()): 1291 | return True 1292 | return False 1293 | 1294 | 1295 | class EndChecker(Constraint): 1296 | """Checks that the prompt ends with a given phrase.""" 1297 | 1298 | def build_description(self, *, end_phrase=None): 1299 | """Build the Constraint description. 1300 | 1301 | Args: 1302 | end_phrase: A string representing the phrase the response should end with. 1303 | 1304 | Returns: 1305 | A string representing the Constraint description. 1306 | """ 1307 | self._end_phrase = ( 1308 | end_phrase.strip() if isinstance(end_phrase, str) else end_phrase 1309 | ) 1310 | if self._end_phrase is None: 1311 | self._end_phrase = random.choice(_ENDING_OPTIONS) 1312 | self._description_pattern = ( 1313 | "Finish your response with this exact phrase \"{ender}\". " 1314 | "No other words should follow this phrase." 1315 | ) 1316 | return self._description_pattern.format(ender=self._end_phrase) 1317 | 1318 | def get_constraint_args(self): 1319 | return {"end_phrase": self._end_phrase} 1320 | 1321 | def get_constraint_args_keys(self): 1322 | """Returns the args keys of `build_description`.""" 1323 | return ["end_phrase"] 1324 | 1325 | def check_following(self, value): 1326 | """Checks if the response ends with the expected phrase.""" 1327 | value = value.strip().strip('"').lower() 1328 | self._end_phrase = self._end_phrase.strip().lower() 1329 | return value.endswith(self._end_phrase) 1330 | 1331 | 1332 | # class TitleChecker(Constraint): 1333 | # """Checks the response for a title.""" 1334 | 1335 | # def build_description(self): 1336 | # """Build the Constraint description.""" 1337 | # self._description_pattern = ( 1338 | # "Your answer must contain a title, wrapped in double angular brackets," 1339 | # " such as <>." 1340 | # ) 1341 | # return self._description_pattern 1342 | 1343 | # def get_constraint_args(self): 1344 | # return None 1345 | 1346 | # def get_constraint_args_keys(self): 1347 | # """Returns the args keys of `build_description`.""" 1348 | # return [] 1349 | 1350 | # def check_following(self, value): 1351 | # """Checks if the response contains a title.""" 1352 | # pattern = r"<<[^\n]+>>" 1353 | # re_pattern = re.compile(pattern) 1354 | # titles = re.findall(re_pattern, value) 1355 | 1356 | # for title in titles: 1357 | # if title.lstrip("<").rstrip(">").strip(): 1358 | # return True 1359 | # return False 1360 | 1361 | 1362 | # class LetterFrequencyChecker(Constraint): 1363 | # """Checks letter frequency.""" 1364 | 1365 | # def build_description(self, *, letter=None, let_frequency=None, let_relation=None): 1366 | # """Build the Constraint description. 
1367 | 1368 | # Args: 1369 | # letter: A string representing a letter that is expected in the response. 1370 | # let_frequency: An integer specifying the number of times `keyword` is 1371 | # expected to appear in the response. 1372 | # let_relation: A string in (`less than`, `at least`), defining the 1373 | # relational operator for comparison. Two relational comparisons are 1374 | # supported for now; if 'less than', the actual number of 1375 | # occurrences < frequency; if 'at least', the actual number of 1376 | # occurrences >= frequency. 1377 | 1378 | # Returns: 1379 | # A string representing the Constraint description. 1380 | # """ 1381 | # if ( 1382 | # not letter 1383 | # or len(letter) > 1 1384 | # or ord(letter.lower()) < 97 1385 | # or ord(letter.lower()) > 122 1386 | # ): 1387 | # self._letter = random.choice(list(string.ascii_letters)) 1388 | # else: 1389 | # self._letter = letter.strip() 1390 | # self._letter = self._letter.lower() 1391 | 1392 | # self._frequency = let_frequency 1393 | # if self._frequency is None or self._frequency < 0: 1394 | # self._frequency = random.randint(1, _LETTER_FREQUENCY) 1395 | 1396 | # if let_relation is None: 1397 | # self._comparison_relation = random.choice(_COMPARISON_RELATION) 1398 | # elif let_relation not in _COMPARISON_RELATION: 1399 | # raise ValueError( 1400 | # "The supported relation for comparison must be in " 1401 | # f"{_COMPARISON_RELATION}, but {let_relation} is given." 1402 | # ) 1403 | # else: 1404 | # self._comparison_relation = let_relation 1405 | 1406 | # self._description_pattern = ( 1407 | # "In your response, the letter {letter} should appear {let_relation}" 1408 | # " {let_frequency} times." 1409 | # ) 1410 | 1411 | # return self._description_pattern.format( 1412 | # letter=self._letter, 1413 | # let_frequency=self._frequency, 1414 | # let_relation=self._comparison_relation, 1415 | # ) 1416 | 1417 | # def get_constraint_args(self): 1418 | # """Returns the keyword args of build description.""" 1419 | # return { 1420 | # "letter": self._letter, 1421 | # "let_frequency": self._frequency, 1422 | # "let_relation": self._comparison_relation, 1423 | # } 1424 | 1425 | # def get_constraint_args_keys(self): 1426 | # """Returns the args keys of `build_description`.""" 1427 | # return ["letter", "let_frequency", "let_relation"] 1428 | 1429 | # def check_following(self, value): 1430 | # """Checks that the response contains the letter at the right frequency.""" 1431 | # value = value.lower() 1432 | # letters = collections.Counter(value) 1433 | 1434 | # if self._comparison_relation == _COMPARISON_RELATION[0]: 1435 | # return letters[self._letter] < self._frequency 1436 | # else: 1437 | # return letters[self._letter] >= self._frequency 1438 | 1439 | 1440 | class CapitalLettersEnglishChecker(Constraint): 1441 | """Checks that the response is in english and is in all capital letters.""" 1442 | 1443 | def build_description(self): 1444 | """Build the Constraint description.""" 1445 | self._description_pattern = ( 1446 | "Your entire response should be in English, and in all capital letters." 
1447 | ) 1448 | return self._description_pattern 1449 | 1450 | def get_constraint_args(self): 1451 | return None 1452 | 1453 | def get_constraint_args_keys(self): 1454 | """Returns the args keys of `build_description`.""" 1455 | return [] 1456 | 1457 | def check_following(self, value): 1458 | """Checks that the response is in English and in all capital letters.""" 1459 | assert isinstance(value, str) 1460 | 1461 | try: 1462 | return value.isupper() and langdetect.detect(value) == "en" 1463 | except langdetect.LangDetectException as e: 1464 | # Count as Constraint is followed. 1465 | logging.error( 1466 | "Unable to detect language for text %s due to %s", value, e 1467 | ) # refex: disable=pytotw.037 1468 | return True 1469 | 1470 | 1471 | class LowercaseLettersEnglishChecker(Constraint): 1472 | """Checks that the response is in english and is in all lowercase letters.""" 1473 | 1474 | def build_description(self): 1475 | """Build the Constraint description.""" 1476 | self._description_pattern = ( 1477 | "Your entire response should be in English, and in all lowercase" 1478 | " letters. No capital letters are allowed." 1479 | ) 1480 | return self._description_pattern 1481 | 1482 | def get_constraint_args(self): 1483 | return None 1484 | 1485 | def get_constraint_args_keys(self): 1486 | """Returns the args keys of `build_description`.""" 1487 | return [] 1488 | 1489 | def check_following(self, value): 1490 | """Checks that the response is in English and in all lowercase letters.""" 1491 | assert isinstance(value, str) 1492 | 1493 | try: 1494 | return value.islower() and langdetect.detect(value) == "en" 1495 | except langdetect.LangDetectException as e: 1496 | # Count as Constraint is followed. 1497 | logging.error( 1498 | "Unable to detect language for text %s due to %s", value, e 1499 | ) # refex: disable=pytotw.037 1500 | return True 1501 | 1502 | 1503 | class CommaChecker(Constraint): 1504 | """Checks the response for no commas.""" 1505 | 1506 | def build_description(self): 1507 | """Build the Constraint description.""" 1508 | self._description_pattern = ( 1509 | "In your entire response, refrain from the use of any commas." 1510 | ) 1511 | return self._description_pattern 1512 | 1513 | def get_constraint_args(self): 1514 | return None 1515 | 1516 | def get_constraint_args_keys(self): 1517 | """Returns the args keys of `build_description`.""" 1518 | return [] 1519 | 1520 | def check_following(self, value): 1521 | """Checks that the response does not contain commas.""" 1522 | return not re.search(r"\,", value) 1523 | 1524 | 1525 | class CapitalWordFrequencyChecker(Constraint): 1526 | """Checks frequency of words with all capital letters.""" 1527 | 1528 | def build_description( 1529 | self, 1530 | capital_frequency=None, 1531 | capital_relation=None, 1532 | ): 1533 | """Build the Constraint description. 1534 | 1535 | Args: 1536 | capital_frequency: An integer that represents the number of words that 1537 | should be in all capital letters. 1538 | capital_relation: A string that is 'at least' or 'at most' that refers to 1539 | the frequency. 1540 | 1541 | Returns: 1542 | A string representing the Constraint description. 
1543 | """ 1544 | self._frequency = capital_frequency 1545 | if self._frequency is None: 1546 | self._frequency = random.randint(1, _ALL_CAPITAL_WORD_FREQUENCY) 1547 | 1548 | self._comparison_relation = capital_relation 1549 | if capital_relation is None: 1550 | self._comparison_relation = random.choice(_COMPARISON_RELATION) 1551 | elif capital_relation not in _COMPARISON_RELATION: 1552 | raise ValueError( 1553 | "The supported relation for comparison must be in " 1554 | f"{_COMPARISON_RELATION}, but {capital_relation} is given." 1555 | ) 1556 | 1557 | self._description_pattern = ( 1558 | "In your response, words with all capital letters should appear" 1559 | " {relation} {frequency} times." 1560 | ) 1561 | 1562 | return self._description_pattern.format( 1563 | frequency=self._frequency, relation=self._comparison_relation 1564 | ) 1565 | 1566 | def get_constraint_args(self): 1567 | """Returns the keyword args of build description.""" 1568 | return { 1569 | "capital_frequency": self._frequency, 1570 | "capital_relation": self._comparison_relation, 1571 | } 1572 | 1573 | def get_constraint_args_keys(self): 1574 | """Returns the args keys of `build_description`.""" 1575 | return ["capital_frequency", "capital_relation"] 1576 | 1577 | def check_following(self, value): 1578 | """Checks the frequency of words with all capital letters.""" 1579 | # Hyphenated words will count as one word 1580 | words = nltk.word_tokenize(value) 1581 | capital_words = [word for word in words if word.isupper()] 1582 | 1583 | capital_words = len(capital_words) 1584 | 1585 | if self._comparison_relation == _COMPARISON_RELATION[0]: 1586 | return capital_words < self._frequency 1587 | else: 1588 | return capital_words >= self._frequency 1589 | 1590 | 1591 | class QuotationChecker(Constraint): 1592 | """Checks response is wrapped with double quotation marks.""" 1593 | 1594 | def build_description(self): 1595 | """Build the Constraint description.""" 1596 | self._description_pattern = ( 1597 | "Wrap your entire response with double quotation marks. " 1598 | ) 1599 | return self._description_pattern 1600 | 1601 | def get_constraint_args(self): 1602 | """Returns the keyword args of build description.""" 1603 | return None 1604 | 1605 | def get_constraint_args_keys(self): 1606 | """Returns the args keys of `build_description`.""" 1607 | return [] 1608 | 1609 | def check_following(self, value): 1610 | """Checks if the response is wrapped with double quotation marks.""" 1611 | value = value.strip() 1612 | return len(value) > 1 and value[0] == '"' and value[-1] == '"' 1613 | -------------------------------------------------------------------------------- /code/constraint_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Registry of all constraint_checker.""" 16 | 17 | import constraint_checker 18 | 19 | 20 | _KEYWORD = "keywords:" 21 | 22 | _LANGUAGE = "language:" 23 | 24 | _LENGTH = "length_constraint_checkers:" 25 | 26 | _CONTENT = "detectable_content:" 27 | 28 | _FORMAT = "detectable_format:" 29 | 30 | _MULTITURN = "multi-turn:" 31 | 32 | _COMBINATION = "combination:" 33 | 34 | _STARTEND = "startend:" 35 | 36 | _CHANGE_CASES = "change_case:" 37 | 38 | _PUNCTUATION = "punctuation:" 39 | 40 | INSTRUCTION_DICT = { 41 | _KEYWORD + "existence": constraint_checker.KeywordChecker, 42 | _KEYWORD + "frequency": constraint_checker.KeywordFrequencyChecker, 43 | # TODO(jeffreyzhou): make a proper set of sentences to choose from 44 | # _KEYWORD + "key_sentences": constraint_checker.KeySentenceChecker, 45 | _KEYWORD + "forbidden_words": constraint_checker.ForbiddenWords, 46 | # _KEYWORD + "letter_frequency": constraint_checker.LetterFrequencyChecker, 47 | _LANGUAGE + "response_language": constraint_checker.ResponseLanguageChecker, 48 | # _LENGTH + "number_sentences": constraint_checker.NumberOfSentences, 49 | # _LENGTH + "number_paragraphs": constraint_checker.ParagraphChecker, 50 | _LENGTH + "number_words": constraint_checker.NumberOfWords, 51 | # _LENGTH + "nth_paragraph_first_word": constraint_checker.ParagraphFirstWordCheck, 52 | # _CONTENT + "number_placeholders": constraint_checker.PlaceholderChecker, 53 | # _CONTENT + "postscript": constraint_checker.PostscriptChecker, 54 | _FORMAT + "number_bullet_lists": constraint_checker.BulletListChecker, 55 | # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace 56 | # _CONTENT + "rephrase_paragraph": constraint_checker.RephraseParagraph, 57 | # _FORMAT + "constrained_response": constraint_checker.ConstrainedResponseChecker, 58 | _FORMAT + "number_highlighted_sections": (constraint_checker.HighlightSectionChecker), 59 | _FORMAT + "multiple_sections": constraint_checker.SectionChecker, 60 | # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message. 61 | # _FORMAT + "rephrase": constraint_checker.RephraseChecker, 62 | # _FORMAT + "json_format": constraint_checker.JsonFormat, 63 | # _FORMAT + "title": constraint_checker.TitleChecker, 64 | # TODO(tianjianlu): Re-enable with specific prompts. 
65 | # _MULTITURN + "constrained_start": constraint_checker.ConstrainedStartChecker, 66 | # _COMBINATION + "two_responses": constraint_checker.TwoResponsesChecker, 67 | _COMBINATION + "repeat_prompt": constraint_checker.RepeatPromptThenAnswer, 68 | _STARTEND + "end_checker": constraint_checker.EndChecker, 69 | _STARTEND + "quotation": constraint_checker.QuotationChecker, 70 | _CHANGE_CASES + "capital_word_frequency": constraint_checker.CapitalWordFrequencyChecker, 71 | _CHANGE_CASES + "english_capital": constraint_checker.CapitalLettersEnglishChecker, 72 | _CHANGE_CASES + "english_lowercase": constraint_checker.LowercaseLettersEnglishChecker, 73 | _PUNCTUATION + "no_comma": constraint_checker.CommaChecker, 74 | } 75 | 76 | INSTRUCTION_CONFLICTS = { 77 | _KEYWORD + "existence": {_KEYWORD + "existence"}, 78 | _KEYWORD + "frequency": {_KEYWORD + "frequency"}, 79 | # TODO(jeffreyzhou): make a proper set of sentences to choose from 80 | # _KEYWORD + "key_sentences": constraint_checker.KeySentenceChecker, 81 | _KEYWORD + "forbidden_words": {_KEYWORD + "forbidden_words"}, 82 | #_KEYWORD + "letter_frequency": {_KEYWORD + "letter_frequency"}, 83 | _LANGUAGE + "response_language": { 84 | _LANGUAGE + "response_language", 85 | _FORMAT + "multiple_sections", 86 | _KEYWORD + "existence", 87 | _KEYWORD + "frequency", 88 | _KEYWORD + "forbidden_words", 89 | _STARTEND + "end_checker", 90 | _CHANGE_CASES + "english_capital", 91 | _CHANGE_CASES + "english_lowercase", 92 | }, 93 | # _LENGTH + "number_sentences": {_LENGTH + "number_sentences"}, 94 | # _LENGTH + "number_paragraphs": { 95 | # _LENGTH + "number_paragraphs", 96 | # _LENGTH + "nth_paragraph_first_word", 97 | # _LENGTH + "number_sentences", 98 | # _LENGTH + "nth_paragraph_first_word", 99 | # }, 100 | _LENGTH + "number_words": {_LENGTH + "number_words"}, 101 | # _LENGTH + "nth_paragraph_first_word": { 102 | # _LENGTH + "nth_paragraph_first_word", 103 | # _LENGTH + "number_paragraphs", 104 | # }, 105 | # _CONTENT + "number_placeholders": {_CONTENT + "number_placeholders"}, 106 | # _CONTENT + "postscript": {_CONTENT + "postscript"}, 107 | _FORMAT + "number_bullet_lists": {_FORMAT + "number_bullet_lists"}, 108 | # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace 109 | # _CONTENT + "rephrase_paragraph": constraint_checker.RephraseParagraph, 110 | # _FORMAT + "constrained_response": set(INSTRUCTION_DICT.keys()), 111 | _FORMAT + "number_highlighted_sections": {_FORMAT + "number_highlighted_sections"}, 112 | _FORMAT + "multiple_sections": { 113 | _FORMAT + "multiple_sections", 114 | _LANGUAGE + "response_language", 115 | _FORMAT + "number_highlighted_sections", 116 | }, 117 | # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message. 118 | # _FORMAT + "rephrase": constraint_checker.RephraseChecker, 119 | # _FORMAT + "json_format": set(INSTRUCTION_DICT.keys()).difference( 120 | # {_KEYWORD + "forbidden_words", _KEYWORD + "existence"} 121 | # ), 122 | # _FORMAT + "title": {_FORMAT + "title"}, 123 | # TODO(tianjianlu): Re-enable with specific prompts. 
124 | # _MULTITURN + "constrained_start": constraint_checker.ConstrainedStartChecker, 125 | # _COMBINATION + "two_responses": set(INSTRUCTION_DICT.keys()).difference( 126 | # { 127 | # _KEYWORD + "forbidden_words", 128 | # _KEYWORD + "existence", 129 | # _LANGUAGE + "response_language", 130 | # _FORMAT + "title", 131 | # _PUNCTUATION + "no_comma", 132 | # } 133 | # ), 134 | _COMBINATION + "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference( 135 | { 136 | _KEYWORD + "existence", 137 | # _FORMAT + "title", 138 | _PUNCTUATION + "no_comma"} 139 | ), 140 | _STARTEND + "end_checker": {_STARTEND + "end_checker"}, 141 | _CHANGE_CASES + "capital_word_frequency": { 142 | _CHANGE_CASES + "capital_word_frequency", 143 | _CHANGE_CASES + "english_lowercase", 144 | _CHANGE_CASES + "english_capital", 145 | }, 146 | _CHANGE_CASES + "english_capital": {_CHANGE_CASES + "english_capital"}, 147 | _CHANGE_CASES + "english_lowercase": { 148 | _CHANGE_CASES + "english_lowercase", 149 | _CHANGE_CASES + "english_capital", 150 | }, 151 | _PUNCTUATION + "no_comma": {_PUNCTUATION + "no_comma"}, 152 | _STARTEND + "quotation": { 153 | _STARTEND + "quotation", 154 | # _FORMAT + "title" 155 | }, 156 | } 157 | 158 | 159 | def conflict_make(conflicts): 160 | """Makes sure if A conflicts with B, B will conflict with A. 161 | 162 | Args: 163 | conflicts: Dictionary of potential conflicts where key is instruction id 164 | and value is set of instruction ids that it conflicts with. 165 | 166 | Returns: 167 | Revised version of the dictionary. All constraint_checker conflict with 168 | themselves. If A conflicts with B, B will conflict with A. 169 | """ 170 | for key in conflicts: 171 | for k in conflicts[key]: 172 | conflicts[k].add(key) 173 | conflicts[key].add(key) 174 | return conflicts 175 | 176 | DOUBLE_CONSTRAINT = [ 177 | ('combination:repeat_prompt', 'keywords:forbidden_words'), 178 | ('change_case:english_lowercase', 'keywords:frequency'), 179 | ('change_case:english_lowercase', 'keywords:existence'), 180 | # ('detectable_format:multiple_sections', 'detectable_format:number_bullet_lists'), 181 | ('language:response_language', 'startend:quotation'), 182 | ('keywords:existence', 'punctuation:no_comma'), 183 | ('change_case:english_lowercase', 'detectable_format:number_highlighted_sections'), 184 | ('keywords:forbidden_words', 'startend:quotation'), 185 | # ('change_case:capital_word_frequency', 'change_case:capital_word_frequency'), 186 | ('change_case:english_capital', 'keywords:forbidden_words'), 187 | ('language:response_language', 'punctuation:no_comma'), 188 | ('change_case:capital_word_frequency', 'punctuation:no_comma'), 189 | ('keywords:forbidden_words', 'punctuation:no_comma'), 190 | ('keywords:forbidden_words', 'keywords:frequency'), 191 | ('detectable_format:number_highlighted_sections', 'startend:quotation'), 192 | ('change_case:english_lowercase', 'detectable_format:number_bullet_lists'), 193 | ('detectable_format:number_bullet_lists', 'keywords:forbidden_words'), 194 | ('detectable_format:multiple_sections', 'startend:quotation'), 195 | ('keywords:frequency', 'keywords:frequency'), 196 | ('combination:repeat_prompt', 'keywords:existence'), 197 | ('change_case:english_lowercase', 'keywords:forbidden_words'), 198 | ('keywords:frequency', 'punctuation:no_comma'), 199 | ('combination:repeat_prompt', 'keywords:frequency'), 200 | ('combination:repeat_prompt', 'punctuation:no_comma'), 201 | ('keywords:existence', 'keywords:forbidden_words'), 202 | ('detectable_format:number_highlighted_sections', 
'keywords:frequency'), 203 | ('detectable_format:number_highlighted_sections', 'punctuation:no_comma'), 204 | ('detectable_format:number_highlighted_sections', 'keywords:existence'), 205 | ('keywords:frequency', 'startend:end_checker'), 206 | ('punctuation:no_comma', 'startend:end_checker'), 207 | ('detectable_format:number_bullet_lists', 'startend:quotation'), 208 | ('change_case:english_capital', 'punctuation:no_comma'), 209 | # ('change_case:english_capital', 'detectable_format:multiple_sections'), 210 | ('change_case:english_capital', 'keywords:existence'), 211 | ('detectable_format:multiple_sections', 'startend:end_checker'), 212 | ('keywords:existence', 'startend:quotation'), 213 | ('change_case:english_lowercase', 'startend:quotation'), 214 | ('change_case:english_capital', 'detectable_format:number_highlighted_sections'), 215 | ('change_case:english_capital', 'startend:end_checker'), 216 | ('detectable_format:number_bullet_lists', 'punctuation:no_comma') 217 | ] 218 | 219 | TRIPLE_CONSTRAINT = [ 220 | ('change_case:capital_word_frequency', 'length_constraint_checkers:number_words', 'keywords:frequency'), 221 | ('combination:repeat_prompt', 'keywords:existence', 'keywords:forbidden_words'), 222 | ('detectable_format:number_highlighted_sections', 'punctuation:no_comma', 'startend:end_checker'), 223 | ('detectable_format:number_highlighted_sections', 'keywords:existence', 'punctuation:no_comma'), 224 | ('change_case:capital_word_frequency', 'keywords:frequency', 'punctuation:no_comma'), 225 | ('keywords:existence', 'keywords:frequency', 'keywords:frequency'), 226 | ('detectable_format:number_highlighted_sections', 'keywords:existence', 'keywords:frequency'), 227 | ('detectable_format:number_bullet_lists', 'keywords:forbidden_words', 'keywords:frequency'), 228 | ('detectable_format:multiple_sections', 'keywords:existence', 'keywords:frequency'), 229 | ('detectable_format:number_bullet_lists', 'keywords:existence', 'length_constraint_checkers:number_words'), 230 | ('combination:repeat_prompt', 'length_constraint_checkers:number_words', 'keywords:forbidden_words'), 231 | ('change_case:english_lowercase', 'keywords:existence', 'keywords:frequency'), 232 | ('keywords:existence', 'length_constraint_checkers:number_words', 'startend:quotation'), 233 | ('combination:repeat_prompt', 'detectable_format:number_highlighted_sections', 'length_constraint_checkers:number_words'), 234 | ('change_case:english_capital', 'punctuation:no_comma', 'startend:quotation') 235 | ] 236 | -------------------------------------------------------------------------------- /code/constraint_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Utility library of instructions.""" 16 | 17 | import functools 18 | import os 19 | import random 20 | import re 21 | from importlib.metadata import version 22 | 23 | import immutabledict 24 | import nltk 25 | from packaging.version import parse as parse_version 26 | 27 | 28 | # Downloading 'punkt' with nltk<3.9 has a remote code vuln. 29 | # see https://github.com/EleutherAI/lm-evaluation-harness/issues/2210 30 | # and https://github.com/nltk/nltk/issues/3266 31 | # for more information. 32 | NLTK_MIN_VERSION = "3.9.1" 33 | RANK = os.environ.get("LOCAL_RANK", "0") 34 | 35 | 36 | def download_nltk_resources(): 37 | """Download 'punkt' if not already installed""" 38 | assert (nltk_version := parse_version(version("nltk"))) >= parse_version( 39 | NLTK_MIN_VERSION 40 | ), ( 41 | f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability." 42 | ) 43 | 44 | try: 45 | nltk.data.find("tokenizers/punkt_tab") 46 | except LookupError: 47 | if RANK == "0": 48 | nltk.download("punkt_tab") 49 | print("Downloaded punkt_tab on rank 0") 50 | 51 | 52 | download_nltk_resources() 53 | 54 | #['center', 'Plugging', 'xy', 'list', '92', 'stars', 'bar', 'items', '2da', '1000a', '100b', 'Using', 'tangent', 'blues', 'simply', 'To', 'mid', 'integer', '17k', 'eta', 'similar', 'way', 'lottery', 'circles', 'IL', '60', '10c', 'rows', 'Taking', 'BCD', '81', 'H', 'using', 'blue', 'red', 'pairs', 'more', 'monotonic', 'intersecting', 'positive', 'consider', '63', 'S', '104', 'congruent', 'parallel', 'lambda', 'would', 'triangles', 'black', 'candy', 'hearts', 'OI', 'occupy', '144', '96', 'segment', 'cong', 'but', 'casework', 'no', 'every', 'four', 'work', 'b_3', 'sum_', 'out', 'Equation', 'area', 'perpendicular', 'inradius', 'coordinates', 'make', 'calculate', 'multiply', 'strategy', 'winning', 'prize', 'DE', 'bi', '108b', '468', '34', 'x_1', 'y_1', 'who', 'Pi_', 'Substituting', '180', 'denote', 'lengths', 'Hence', 'Consider', 'count', 'neq', 'intersection', 'notice', 'through', 'function', 'Theorem', 'hline', 'form', 'least', 'tetrahedron', 'volume', 'note', 'perp', 'rm', 'terms', 'vertical', 'maximize', 'radius', '657', '4r', 'mathcal', '324', 'alpha', '30', 'digit', 'white', 'columns', '18', '45', 'M', 'lines', '33', 'cycle', 'satisfy', 'does', 'prove', 'b_1', 'them', 'q', 'g', 'they', 'remainder', 'pm', 'tfrac', 'OC', 'sphere', 'those', '46', 'lw', 'just', 'final', '404', 'together', 'like', 'different', 'drawn', '210', 'T', 'ordered', '3a', '2b', '4e', 'residents', '234', '2R', 'P_A', 'P_', 'rectangle', 'median', 'abcd', '1000', 'chip', 'Point', 'midpoint', 'AM', 'E', 'EF', 'Also', 'Furthermore', 'law', '39', 'Notice', 'question', 'go', 'import', 'void', '256', 'Technodoggo', 'geq', 'WLOG', 'b_2', 'symmetry', 'b_4', 'configurations', 'unique', 'Adding', 'giving', 'call', 'region', 'compute', 'gcd', 'divided', 'Solution', 'becomes', 'plane', 'ABCD', 'comhttps', 'AR', '2A', 'incenter', 'box', '2a', 'remaining', 'G', 'wins', 'under', 'last', 'integers', 'follows', 'CE', 'Draw', 'horizontal', '117i', '432', 'ge', '1190', '4046', '192', '900', '437', 'bag', 'CL', 'BL', 'ab', 'sqrt3', 'axis', 'x_C', 'db', 'occupied', 'cell', 'x_m', 'y_n', 'hours', '240', 'divide', 'Finally', '113', 'altitude', 'DAC', '4x', 'OD', 'reds', 'vertex', 'Longrightarrow', 'configuration', 'here', 'drawing', 'functions', 'slope', 'whose', 'array', 'bmod', '51', 'pmatrix', 'height', 'due', 'Pythagorean', 'obtain', '189', 'User', 
'V', '405', 'coordinate', 'its', 'base', 'polynomial', 'yz', 'First', 'distinct', 'variable', '2r', 'simplifies', 'coin', 'A_i', 'player', 'move', 'their', 'start', 'sets', 'And', 'identical', 'even', 'OH', 'cap', 'path', 'row', '75a', '117b', '75b', '4a', 'phi', 'write', '480', 'greater', '3c', 'label'] 55 | 56 | WORD_LIST = [ 57 | "align", 58 | "number", 59 | "find", 60 | "therefore", 61 | "equation", 62 | "answer", 63 | "must", 64 | "now", 65 | "same", 66 | "imply", 67 | "because", 68 | "solution", 69 | "since", 70 | "where", 71 | "choose", 72 | "between", 73 | "length", 74 | "side", 75 | "follow", 76 | "case", 77 | "when", 78 | "value", 79 | "point", 80 | "because", 81 | "total", 82 | "denote", 83 | "see", 84 | "equal", 85 | "possible", 86 | "problem", 87 | "draw", 88 | "formula", 89 | "expression", 90 | "given", 91 | "adjacent", 92 | "note", 93 | "function", 94 | "above", 95 | "win", 96 | "than", 97 | "maximum", 98 | "root", 99 | "bar", 100 | "yield", 101 | "condition", 102 | "theorem", 103 | "respectively", 104 | "valid", 105 | "simply", 106 | "similar", 107 | "strategy", 108 | "function", 109 | "furthermore", 110 | "question", 111 | "configuration", 112 | "identical" 113 | ] # pylint: disable=line-too-long 114 | 115 | # ISO 639-1 codes to language names. 116 | LANGUAGE_CODES = immutabledict.immutabledict( 117 | { 118 | "en": "English", 119 | "es": "Spanish", 120 | "pt": "Portuguese", 121 | "ar": "Arabic", 122 | "hi": "Hindi", 123 | "fr": "French", 124 | "ru": "Russian", 125 | "de": "German", 126 | "ja": "Japanese", 127 | "it": "Italian", 128 | "bn": "Bengali", 129 | "uk": "Ukrainian", 130 | "th": "Thai", 131 | "ur": "Urdu", 132 | "ta": "Tamil", 133 | "te": "Telugu", 134 | "bg": "Bulgarian", 135 | "ko": "Korean", 136 | "pl": "Polish", 137 | "he": "Hebrew", 138 | "fa": "Persian", 139 | "vi": "Vietnamese", 140 | "ne": "Nepali", 141 | "sw": "Swahili", 142 | "kn": "Kannada", 143 | "mr": "Marathi", 144 | "gu": "Gujarati", 145 | "pa": "Punjabi", 146 | "ml": "Malayalam", 147 | "fi": "Finnish", 148 | } 149 | ) 150 | 151 | _ALPHABETS = "([A-Za-z])" 152 | _PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]" 153 | _SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)" 154 | _STARTERS = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)" 155 | _ACRONYMS = "([A-Z][.][A-Z][.](?:[A-Z][.])?)" 156 | _WEBSITES = "[.](com|net|org|io|gov|edu|me)" 157 | _DIGITS = "([0-9])" 158 | _MULTIPLE_DOTS = r"\.{2,}" 159 | 160 | 161 | def split_into_sentences(text): 162 | """Split the text into sentences. 163 | 164 | Args: 165 | text: A string that consists of more than or equal to one sentences. 166 | 167 | Returns: 168 | A list of strings where each string is a sentence. 169 | """ 170 | text = " " + text + " " 171 | text = text.replace("\n", " ") 172 | text = re.sub(_PREFIXES, "\\1", text) 173 | text = re.sub(_WEBSITES, "\\1", text) 174 | text = re.sub(_DIGITS + "[.]" + _DIGITS, "\\1\\2", text) 175 | text = re.sub( 176 | _MULTIPLE_DOTS, 177 | lambda match: "" * len(match.group(0)) + "", 178 | text, 179 | ) 180 | if "Ph.D" in text: 181 | text = text.replace("Ph.D.", "PhD") 182 | text = re.sub(r"\s" + _ALPHABETS + "[.] ", " \\1 ", text) 183 | text = re.sub(_ACRONYMS + " " + _STARTERS, "\\1 \\2", text) 184 | text = re.sub( 185 | _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]", 186 | "\\1\\2\\3", 187 | text, 188 | ) 189 | text = re.sub(_ALPHABETS + "[.]" + _ALPHABETS + "[.]", "\\1\\2", text) 190 | text = re.sub(" " + _SUFFIXES + "[.] 
" + _STARTERS, " \\1 \\2", text) 191 | text = re.sub(" " + _SUFFIXES + "[.]", " \\1", text) 192 | text = re.sub(" " + _ALPHABETS + "[.]", " \\1", text) 193 | if "”" in text: 194 | text = text.replace(".”", "”.") 195 | if '"' in text: 196 | text = text.replace('."', '".') 197 | if "!" in text: 198 | text = text.replace('!"', '"!') 199 | if "?" in text: 200 | text = text.replace('?"', '"?') 201 | text = text.replace(".", ".") 202 | text = text.replace("?", "?") 203 | text = text.replace("!", "!") 204 | text = text.replace("", ".") 205 | sentences = text.split("") 206 | sentences = [s.strip() for s in sentences] 207 | if sentences and not sentences[-1]: 208 | sentences = sentences[:-1] 209 | return sentences 210 | 211 | 212 | def count_words(text): 213 | """Counts the number of words.""" 214 | tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+") 215 | tokens = tokenizer.tokenize(text) 216 | num_words = len(tokens) 217 | return num_words 218 | 219 | 220 | @functools.lru_cache(maxsize=None) 221 | def _get_sentence_tokenizer(): 222 | return nltk.data.load("nltk:tokenizers/punkt/english.pickle") 223 | 224 | 225 | def count_sentences(text): 226 | """Count the number of sentences.""" 227 | tokenizer = _get_sentence_tokenizer() 228 | tokenized_sentences = tokenizer.tokenize(text) 229 | return len(tokenized_sentences) 230 | 231 | 232 | def generate_keywords(num_keywords): 233 | """Randomly generates a few keywords.""" 234 | return random.sample(WORD_LIST, k=num_keywords) 235 | -------------------------------------------------------------------------------- /code/eval_if.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import json 5 | 6 | from constraint_registry import INSTRUCTION_DICT 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--hypothesis_path', type=str) 10 | parser.add_argument('--data_path', type=str) 11 | parser.add_argument("--delimiter", type=str, default = "") 12 | args = parser.parse_args() 13 | 14 | 15 | def test_instruction_following_strict( 16 | instruction_id_list, 17 | response, 18 | parameters, 19 | prompt, 20 | ): 21 | """Tests response to see if instructions are followed.""" 22 | 23 | is_following_list = [] 24 | for index, instruction_id in enumerate(instruction_id_list): 25 | try: 26 | instruction_cls = INSTRUCTION_DICT[instruction_id] 27 | except: 28 | import pdb 29 | pdb.set_trace() 30 | instruction = instruction_cls(instruction_id) 31 | 32 | # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method. 
33 |         if parameters[index]:
34 |             kwargs = {n: p for n, p in parameters[index].items() if p}
35 |         else:
36 |             kwargs = {}
37 |         instruction.build_description(**kwargs)
38 |         args = instruction.get_constraint_args()
39 |         if args and "prompt" in args:
40 |             instruction.build_description(prompt=prompt)
41 |         try:
42 |             if response.strip() and instruction.check_following(response):
43 |                 is_following_list.append(True)
44 |             else:
45 |                 is_following_list.append(False)
46 |         except:
47 |             import pdb
48 |             pdb.set_trace()
49 | 
50 |     return is_following_list
51 | 
52 | 
53 | strict = []
54 | loose = []
55 | 
56 | 
57 | 
58 | for line1,line2 in zip(open(args.hypothesis_path).readlines(), open(args.data_path).readlines()):
59 |     try:
60 |         hypothesis = json.loads(line1)["output"]
61 |     except:
62 |         hypothesis = json.loads(line1)["response"]
63 |     if isinstance(hypothesis,list):
64 |         hypothesis = hypothesis[0]
65 |     has_end_think = '</think>' in hypothesis
66 |     has_start_think = '<think>' in hypothesis
67 | 
68 |     think = hypothesis
69 |     if has_end_think:
70 |         think = think.split("</think>")[0]
71 |     if '<think>' in think:
72 |         think = think.split('<think>')[1]
73 | 
74 |     if '</think>' in hypothesis:
75 |         hypothesis = hypothesis.split('</think>')[1]
76 |     if '<think>' in hypothesis:
77 |         hypothesis = hypothesis.split('<think>')[1]
78 |     if '<answer>' in hypothesis:
79 |         hypothesis = hypothesis.split('<answer>')[1]
80 |     if '</answer>' in hypothesis:
81 |         hypothesis = hypothesis.split('</answer>')[0]
82 | 
83 | 
84 |     data = json.loads(line2)
85 |     if not ("noconstraint" in args.hypothesis_path):
86 |         is_follow_list = test_instruction_following_strict(
87 |             data["constraint_name"],
88 |             hypothesis,
89 |             data["constraint_args"],
90 |             data["question"],
91 |         )
92 |         strict.append(all(is_follow_list))
93 |         loose.append(sum(is_follow_list)/len(is_follow_list))
94 |     else:
95 |         # only place holder
96 |         strict.append(1)
97 |         loose.append(1)
98 | 
99 | 
100 | 
101 | print(sum(strict)/len(strict))
102 | print(sum(loose)/len(loose))
--------------------------------------------------------------------------------
/code/requirement.txt:
--------------------------------------------------------------------------------
1 | vllm==0.7.3
2 | torchdata==0.11.0
3 | tensordict==0.6.0
4 | hydra-core
5 | wheel
6 | flash-attn==2.7.4.post1
7 | accelerate==1.6.0
8 | pylatexenc
9 | codetiming
10 | nltk
11 | langdetect
12 | immutabledict
13 | datasets
--------------------------------------------------------------------------------
/code/script/download.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ./data
2 | for source in gsm8k math500 minerva olympiad aime
3 | do
4 | for constraint in single double triple
5 | do
6 | wget https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/${source}_${constraint}.jsonl -O data/${source}_${constraint}.jsonl
7 | done
8 | done
--------------------------------------------------------------------------------
/code/script/eval_if.sh:
--------------------------------------------------------------------------------
1 | model=deepseek-ai_DeepSeek-R1-Distill-Qwen-1.5B
2 | for dataset in gsm8k math500 minerva olympiad aime
3 | do
4 | for constraint in single double triple
5 | do
6 | echo ${model}_${dataset}_${constraint}
7 | python3 -u code/eval_if.py \
8 |     --data_path data/${dataset}_${constraint}.jsonl \
9 |     --hypothesis_path output/${model}_${dataset}_${constraint}_t1.0p0.95max16384seedNone.jsonl
10 | 
11 | echo ${model}_${dataset}_${constraint}_noconstraint
12 | python3 -u code/eval_if.py \
13 |     --data_path data/${dataset}_${constraint}.jsonl \
14 |     --hypothesis_path
output/${model}_${dataset}_${constraint}_t1.0p0.95max16384seedNone_noconstraint.jsonl 15 | done 16 | done 17 | 18 | 19 | -------------------------------------------------------------------------------- /code/script/vllm_if.sh: -------------------------------------------------------------------------------- 1 | for model in deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B 2 | do 3 | for dataset in gsm8k math500 minerva olympiad aime 4 | do 5 | for constraint in single double triple 6 | do 7 | python3 -u code/vllm_core.py \ 8 | --test_file data/${dataset}_${constraint}.jsonl \ 9 | --model_name_or_path ${model} \ 10 | --top_p 0.95 \ 11 | --temperature 1.0 \ 12 | --max_token 16384 \ 13 | 14 | python3 -u code/vllm_core.py \ 15 | --test_file data/${dataset}_${constraint}.jsonl \ 16 | --model_name_or_path ${model} \ 17 | --top_p 0.95 \ 18 | --temperature 1.0 \ 19 | --max_token 16384 \ 20 | --no_constraint 21 | 22 | done 23 | done 24 | done 25 | 26 | -------------------------------------------------------------------------------- /code/vllm_core.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import load_dataset 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | # from peft import PeftModel 5 | import os 6 | import torch 7 | from vllm import LLM, SamplingParams 8 | import json 9 | import sys 10 | from pathlib import Path 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model_name_or_path', type=str, default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B") 15 | parser.add_argument('--tokenizer_name_or_path', type=str, default=None) 16 | parser.add_argument("--test_file",type=str,default=None) 17 | parser.add_argument("--temperature",type=float,default=1.0) 18 | parser.add_argument("--top_p",type=float,default=0.95) 19 | parser.add_argument("--max_token",type=int,default=16384) 20 | parser.add_argument("--n_sample",type=int,default=1) 21 | parser.add_argument("--seed",type=int,default=None) 22 | parser.add_argument("--no_constraint",action="store_true") 23 | args = parser.parse_args() 24 | 25 | 26 | 27 | 28 | 29 | def vllm_inference(prompt): 30 | num_gpus = torch.cuda.device_count() 31 | llm = LLM(model = args.model_name_or_path, 32 | tokenizer = args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path, 33 | dtype='bfloat16', 34 | tensor_parallel_size = 4,#num_gpus, 35 | trust_remote_code=True, 36 | ) 37 | print('>>>>>> model loaded') 38 | 39 | sampling_params = SamplingParams(temperature = args.temperature, top_p = args.top_p, max_tokens = args.max_token, seed = args.seed, n = args.n_sample) 40 | outputs = llm.generate(prompt, sampling_params) 41 | sorted_outputs = sorted(outputs, key=lambda output: int(output.request_id)) 42 | print('>>>>>> generation done') 43 | 44 | return sorted_outputs 45 | 46 | 47 | 48 | ds = load_dataset('json' if 'json' in args.test_file else 'parquet', data_files=args.test_file, split='train') 49 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path) 50 | prompt = [] 51 | cot = "Let's think step by step and output the final answer within \\boxed{}." 
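# Note: each user message below is built by direct string concatenation: the raw
# question, then the chain-of-thought cue above, then (unless --no_constraint is
# given) the space-joined constraint descriptions. For the first data/toy.jsonl
# entry the content is roughly:
#   "Simplify $\tan 100^\circ + 4 \sin 100^\circ.$Let's think step by step and
#    output the final answer within \boxed{}.Include keywords "['find', 'must']"
#    in the response."
# before apply_chat_template wraps it in the model's chat format.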
52 | for i in range(len(ds)): 53 | if 'question' in ds.column_names: 54 | prompt.append(tokenizer.apply_chat_template([{ 55 | "role": "user", 56 | "content": ds[i]['question']+ cot + (" ".join(ds[i]["constraint_desc"]) if not args.no_constraint else ""), 57 | }],tokenize=False, add_generation_prompt=True)) 58 | elif "prompt" in ds.column_names: 59 | prompt.append(tokenizer.apply_chat_template(ds[i]["prompt"], tokenize=False, add_generation_prompt=True)) 60 | 61 | print(">>>>>>>>>>>>>>>>>>>>>>>>") 62 | print(prompt[0]) 63 | print(">>>>>>>>>>>>>>>>>>>>>>>>") 64 | 65 | 66 | 67 | output_path = "output/{}_{}_t{}p{}max{}seed{}.jsonl".format(args.model_name_or_path.replace("/","_"), args.test_file.split('/')[-1].split('.')[0], args.temperature, args.top_p, args.max_token, args.seed) 68 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 69 | output_path = output_path.replace(".jsonl", "_noconstraint.jsonl") if args.no_constraint else output_path 70 | 71 | 72 | output = vllm_inference(prompt) 73 | fout = open(output_path,'w', encoding='utf8') 74 | for i in range(len(prompt)): 75 | fout.write(json.dumps({"output": [ output[i].outputs[j].text for j in range(args.n_sample)]}, ensure_ascii=False)+'\n') 76 | fout.close() 77 | 78 | -------------------------------------------------------------------------------- /data/aime_double.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:46-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/aime_double.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:17630 (17K) [text/plain] 6 | 正在保存至: “aime_double.jsonl” 7 | 8 | 0K .......... ....... 100% 118M=0s 9 | 10 | 2025-05-20 16:10:46 (118 MB/s) - 已保存 “aime_double.jsonl” [17630/17630]) 11 | 12 | -------------------------------------------------------------------------------- /data/aime_single.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:46-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/aime_single.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:12535 (12K) [text/plain] 6 | 正在保存至: “aime_single.jsonl” 7 | 8 | 0K .......... .. 100% 11.7G=0s 9 | 10 | 2025-05-20 16:10:46 (11.7 GB/s) - 已保存 “aime_single.jsonl” [12535/12535]) 11 | 12 | -------------------------------------------------------------------------------- /data/aime_triple.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:46-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/aime_triple.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:19088 (19K) [text/plain] 6 | 正在保存至: “aime_triple.jsonl” 7 | 8 | 0K .......... ........ 
100% 150M=0s 9 | 10 | 2025-05-20 16:10:47 (150 MB/s) - 已保存 “aime_triple.jsonl” [19088/19088]) 11 | 12 | -------------------------------------------------------------------------------- /data/gsm8k_double.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:39-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/gsm8k_double.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.103, 54.230.71.28, 54.230.71.56, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.103|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:25342 (25K) [text/plain] 6 | 正在保存至: “gsm8k_double.jsonl” 7 | 8 | 0K .......... .......... .... 100% 47.9M=0.001s 9 | 10 | 2025-05-20 16:10:40 (47.9 MB/s) - 已保存 “gsm8k_double.jsonl” [25342/25342]) 11 | 12 | -------------------------------------------------------------------------------- /data/gsm8k_single.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:37-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/gsm8k_single.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:18960 (19K) [text/plain] 6 | 正在保存至: “gsm8k_single.jsonl” 7 | 8 | 0K .......... ........ 100% 700K=0.03s 9 | 10 | 2025-05-20 16:10:39 (700 KB/s) - 已保存 “gsm8k_single.jsonl” [18960/18960]) 11 | 12 | -------------------------------------------------------------------------------- /data/gsm8k_triple.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:40-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/gsm8k_triple.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.28, 54.230.71.56, 54.230.71.2, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.28|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:29140 (28K) [text/plain] 6 | 正在保存至: “gsm8k_triple.jsonl” 7 | 8 | 0K .......... .......... ........ 100% 90.5M=0s 9 | 10 | 2025-05-20 16:10:41 (90.5 MB/s) - 已保存 “gsm8k_triple.jsonl” [29140/29140]) 11 | 12 | -------------------------------------------------------------------------------- /data/math500_double.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:41-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/math500_double.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:20441 (20K) [text/plain] 6 | 正在保存至: “math500_double.jsonl” 7 | 8 | 0K .......... ......... 100% 105M=0s 9 | 10 | 2025-05-20 16:10:42 (105 MB/s) - 已保存 “math500_double.jsonl” [20441/20441]) 11 | 12 | -------------------------------------------------------------------------------- /data/math500_single.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:41-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/math500_single.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.56, 54.230.71.2, 54.230.71.103, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.56|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:15679 (15K) [text/plain] 6 | 正在保存至: “math500_single.jsonl” 7 | 8 | 0K .......... ..... 
100% 67.7M=0s 9 | 10 | 2025-05-20 16:10:41 (67.7 MB/s) - 已保存 “math500_single.jsonl” [15679/15679]) 11 | 12 | -------------------------------------------------------------------------------- /data/math500_triple.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:42-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/math500_triple.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:26727 (26K) [text/plain] 6 | 正在保存至: “math500_triple.jsonl” 7 | 8 | 0K .......... .......... ...... 100% 13.6M=0.002s 9 | 10 | 2025-05-20 16:10:42 (13.6 MB/s) - 已保存 “math500_triple.jsonl” [26727/26727]) 11 | 12 | -------------------------------------------------------------------------------- /data/minerva_double.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:43-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/minerva_double.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:27121 (26K) [text/plain] 6 | 正在保存至: “minerva_double.jsonl” 7 | 8 | 0K .......... .......... ...... 100% 26.6M=0.001s 9 | 10 | 2025-05-20 16:10:44 (26.6 MB/s) - 已保存 “minerva_double.jsonl” [27121/27121]) 11 | 12 | -------------------------------------------------------------------------------- /data/minerva_single.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:42-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/minerva_single.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:23996 (23K) [text/plain] 6 | 正在保存至: “minerva_single.jsonl” 7 | 8 | 0K .......... .......... ... 100% 48.5M=0s 9 | 10 | 2025-05-20 16:10:43 (48.5 MB/s) - 已保存 “minerva_single.jsonl” [23996/23996]) 11 | 12 | -------------------------------------------------------------------------------- /data/minerva_triple.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:44-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/minerva_triple.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:34264 (33K) [text/plain] 6 | 正在保存至: “minerva_triple.jsonl” 7 | 8 | 0K .......... .......... .......... ... 100% 1.93M=0.02s 9 | 10 | 2025-05-20 16:10:44 (1.93 MB/s) - 已保存 “minerva_triple.jsonl” [34264/34264]) 11 | 12 | -------------------------------------------------------------------------------- /data/olympiad_double.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:45-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/olympiad_double.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:27628 (27K) [text/plain] 6 | 正在保存至: “olympiad_double.jsonl” 7 | 8 | 0K .......... .......... 
...... 100% 2.63M=0.01s 9 | 10 | 2025-05-20 16:10:45 (2.63 MB/s) - 已保存 “olympiad_double.jsonl” [27628/27628]) 11 | 12 | -------------------------------------------------------------------------------- /data/olympiad_single.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:44-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/olympiad_single.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:19297 (19K) [text/plain] 6 | 正在保存至: “olympiad_single.jsonl” 7 | 8 | 0K .......... ........ 100% 52.3M=0s 9 | 10 | 2025-05-20 16:10:45 (52.3 MB/s) - 已保存 “olympiad_single.jsonl” [19297/19297]) 11 | 12 | -------------------------------------------------------------------------------- /data/olympiad_triple.jsonl: -------------------------------------------------------------------------------- 1 | --2025-05-20 16:10:45-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/olympiad_triple.jsonl 2 | 正在解析主机 huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ... 3 | 正在连接 huggingface.co (huggingface.co)|54.230.71.2|:443... 已连接。 4 | 已发出 HTTP 请求,正在等待回应... 200 OK 5 | 长度:27840 (27K) [text/plain] 6 | 正在保存至: “olympiad_triple.jsonl” 7 | 8 | 0K .......... .......... ....... 100% 1.16M=0.02s 9 | 10 | 2025-05-20 16:10:46 (1.16 MB/s) - 已保存 “olympiad_triple.jsonl” [27840/27840]) 11 | 12 | -------------------------------------------------------------------------------- /data/toy.jsonl: -------------------------------------------------------------------------------- 1 | {"source": "HuggingFaceH4/MATH-500", "id": "level2-single-0", "question": "Simplify $\\tan 100^\\circ + 4 \\sin 100^\\circ.$", "answer": "-\\sqrt{3}", "constraint_desc": ["Include keywords \"['find', 'must']\" in the response."], "constraint_name": ["keywords:existence"], "constraint_args": [{"keywords": ["find", "must"]}]} 2 | {"source": "HuggingFaceH4/MATH-500", "id": "level2-single-2", "question": "In right triangle $ABC$ with $\\angle B = 90^\\circ$, we have $\\sin A = 2\\cos A$. What is $\\tan A$?", "answer": "2", "constraint_desc": ["Do not include keywords \"['respectively', 'strategy']\" in the response."], "constraint_name": ["keywords:forbidden_words"], "constraint_args": [{"forbidden_words": ["respectively", "strategy"]}]} --------------------------------------------------------------------------------
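For quick reference, the snippet below is a minimal sketch of how a single `data/toy.jsonl` entry can be checked outside of `code/eval_if.py`, mirroring the registry lookup and kwargs filtering used there. It assumes the command is run from the repository root with `code/` on `PYTHONPATH`; the `response` string is an invented example rather than real model output.

```python
import json

from constraint_registry import INSTRUCTION_DICT  # requires code/ on PYTHONPATH

# Invented response for the first toy.jsonl problem (illustration only).
response = "We must find the value; therefore the answer is \\boxed{-\\sqrt{3}}."

with open("data/toy.jsonl", encoding="utf8") as f:
    example = json.loads(f.readline())

followed = []
for name, args in zip(example["constraint_name"], example["constraint_args"]):
    checker = INSTRUCTION_DICT[name](name)  # e.g. KeywordChecker for "keywords:existence"
    checker.build_description(**{k: v for k, v in args.items() if v})  # drop empty args
    followed.append(checker.check_following(response))

print("hard (all constraints met):", all(followed))
print("soft (fraction met):", sum(followed) / len(followed))
```

Averaging these two quantities over every line of a JSONL file gives the two numbers that `eval_if.py` prints.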