├── .DS_Store
├── README.md
├── code
│   ├── __init__.py
│   ├── constraint_checker.py
│   ├── constraint_registry.py
│   ├── constraint_util.py
│   ├── eval_if.py
│   ├── requirement.txt
│   ├── script
│   │   ├── download.sh
│   │   ├── eval_if.sh
│   │   └── vllm_if.sh
│   └── vllm_core.py
└── data
    ├── aime_double.jsonl
    ├── aime_single.jsonl
    ├── aime_triple.jsonl
    ├── gsm8k_double.jsonl
    ├── gsm8k_single.jsonl
    ├── gsm8k_triple.jsonl
    ├── math500_double.jsonl
    ├── math500_single.jsonl
    ├── math500_triple.jsonl
    ├── minerva_double.jsonl
    ├── minerva_single.jsonl
    ├── minerva_triple.jsonl
    ├── olympiad_double.jsonl
    ├── olympiad_single.jsonl
    ├── olympiad_triple.jsonl
    └── toy.jsonl
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # MathIF: Instruction-Following Benchmark for Large Reasoning Models
3 |
4 |
6 |
7 | MathIF is a dedicated benchmark for evaluating the instruction-following capabilities of large reasoning models (LRMs) on mathematical reasoning tasks. It exposes a fundamental trade-off between a model’s problem-solving strength and its ability to comply with user-specified constraints.
8 |
9 |
10 |
27 |
28 |
29 | # 📖Features
30 |
31 | - **Compositional Constraints**
32 | 15 Python-verifiable constraint types in four categories (length, lexical, format, affix), combined into single, double, and triple constraint settings.
33 |
34 | - **Diverse Math Sources**
35 | Problems drawn from GSM8K, MATH-500, Minerva, Olympiad, and AIME, totaling 420 high-quality evaluation samples.
36 |
37 | - **Fine-Grained Metrics**
38 | - **Hard Accuracy (HAcc):** fraction of examples satisfying _all_ constraints
39 | - **Soft Accuracy (SAcc):** average fraction of satisfied constraints per example
40 |
41 | - **vLLM-Powered Inference**
42 | Efficient decoding with nucleus sampling (T=1.0, p=0.95) and generation of up to 16k tokens.
43 |
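The two metrics reduce to a few lines. A minimal sketch, assuming each example's verification results arrive as a list of booleans, one per constraint (the function names here are illustrative, not the repo's API):

```python
def hard_accuracy(results):
    # HAcc: fraction of examples that satisfy every constraint
    return sum(all(flags) for flags in results) / len(results)

def soft_accuracy(results):
    # SAcc: mean per-example fraction of satisfied constraints
    return sum(sum(flags) / len(flags) for flags in results) / len(results)

# Three examples with two constraints each
results = [[True, True], [True, False], [False, False]]
print(round(hard_accuracy(results), 4))  # 0.3333
print(round(soft_accuracy(results), 4))  # 0.5
```

Note that HAcc can never exceed SAcc: an example contributes to HAcc only when its per-example fraction is already 1.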
44 | # ✨Getting Started
45 |
46 | ## Prerequisites
47 |
48 | - Python 3.9 or later
49 | - CUDA 12.4
50 | - `git`, `bash`
51 |
52 | ## Installation
53 |
54 | ```bash
55 | git clone https://github.com/TingchenFu/MathIF.git
56 | cd MathIF
57 |
58 | # Create and activate virtual environment
59 | python -m venv venv
60 | source venv/bin/activate # Windows: venv\Scripts\activate
61 |
62 | # Install dependencies
63 | pip install -r code/requirement.txt
64 | ```
65 |
66 |
67 |
68 | # 🔧Usage
69 |
70 | ## Inference
71 |
72 | ```bash
73 | bash code/script/vllm_if.sh
74 | ```
75 |
76 | ## Evaluation
77 |
78 | ```bash
79 | bash code/script/eval_if.sh
80 | ```
81 |
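Each checker in `code/constraint_checker.py` verifies its constraint with plain regular expressions. A self-contained sketch mirroring the bullet-point check in `BulletListChecker.check_following` (the standalone function name is ours):

```python
import re

def count_bullets(text):
    # Count markdown bullet lines starting with "*" or "-",
    # as BulletListChecker.check_following does.
    star_bullets = re.findall(r"^\s*\*[^\*].*$", text, flags=re.MULTILINE)
    dash_bullets = re.findall(r"^\s*-.*$", text, flags=re.MULTILINE)
    return len(star_bullets) + len(dash_bullets)

answer = "* This is point 1.\n* This is point 2"
print(count_bullets(answer))  # 2
```

The constraint is satisfied only when this count exactly equals the `num_bullets` value carried in `constraint_args`.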
82 | ## Dataset Format
83 |
84 | Each line in the JSONL file contains:
85 |
86 | | Field | Description |
87 | | ----------------- | --------------------------------- |
88 | | `source` | Original data source |
89 | | `id` | Unique example identifier |
90 | | `question` | Math problem statement |
91 | | `answer` | Ground-truth solution |
92 | | `constraint_desc` | Human-readable constraint summary |
93 | | `constraint_name` | Constraint category |
94 | | `constraint_args` | Arguments used for verification |
95 |
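A minimal sketch of one such JSONL line; only the field names come from the table above, while every value is invented for illustration:

```python
import json

# One hypothetical record in the documented schema (values are made up)
line = json.dumps({
    "source": "gsm8k",
    "id": "gsm8k-0001",
    "question": "Natalia sold 48 clips in April and half as many in May. How many in total?",
    "answer": "72",
    "constraint_desc": "Your answer must contain exactly 2 bullet points.",
    "constraint_name": "bullet_list",
    "constraint_args": {"num_bullets": 2},
})

example = json.loads(line)
# constraint_args holds the keyword arguments the checker is rebuilt with
print(example["constraint_name"], example["constraint_args"])
```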
96 | ## Project Structure
97 |
98 | ```
99 | .
100 | ├── data/                # MathIF JSONL files
101 | ├── code/
102 | │   ├── script/          # Inference & evaluation scripts
103 | │   ├── requirement.txt  # Python dependencies
104 | │   └── ...              # Model wrappers and utilities
105 | ├── output/              # Generated predictions & logs
106 | └── README.md            # This overview
107 | ```
108 |
109 |
128 |
130 |
131 |
132 | # 📊Leaderboard
133 | 📢 **Showcase Your Model’s Instruction-Following Capability**
134 |
135 | Feel free to contribute results from your own models; we welcome community submissions!
136 | We can evaluate newly added models on our platform. To be included on the leaderboard, please provide a Hugging Face model link for verification and testing.
137 |
138 | ## **≤ 4B Models**
139 |
140 | | Model | HAcc | SAcc | Correctness (w/ const.) |
141 | | ----------------------------- | ----- | ----- | ----------------------- |
142 | | [Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B) | 44.05 | 61.43 | 58.57 |
143 | | [Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) | 30.24 | 50.24 | 51.19 |
144 | | [Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) | 27.86 | 50.44 | 32.14 |
145 | | [L1-Qwen-1.5B-Exact](https://huggingface.co/l3lab/L1-Qwen-1.5B-Exact) | 19.76 | 39.60 | 42.86 |
146 | | [L1-Qwen-1.5B-Max](https://huggingface.co/l3lab/L1-Qwen-1.5B-Max) | 19.76 | 39.40 | 45.71 |
147 | | [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) | 17.14 | 36.62 | 31.67 |
148 | | [DeepScaler-1.5B-Preview](https://huggingface.co/agentica-org/DeepScaleR-1.5B-Preview) | 14.52 | 34.52 | 36.19 |
149 | | [Qwen2.5-1.5B-SimpleRL-Zoo](https://huggingface.co/hkust-nlp/Qwen-2.5-1.5B-SimpleRL-Zoo) | 9.05 | 24.33 | 22.38 |
150 | | [Qwen2.5-Math-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B-Instruct) | 7.62 | 21.39 | 44.29 |
151 |
152 | ## **7B–14B Models**
153 | | Model | HAcc | SAcc | Correctness (w/ const.) |
154 | | ----------------------------- | ----- | ----- | ----------------------- |
155 | | [Qwen3-14B](https://huggingface.co/Qwen/Qwen3-14B) | 50.71 | 67.06 | 64.29 |
156 | | [DeepSeek-R1-Distill-Qwen-14B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B) | 39.28 | 60.55 | 50.95 |
157 | | [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) | 37.86 | 57.34 | 66.43 |
158 | | [DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 26.43 | 44.96 | 48.57 |
159 | | [DeepSeek-R1-Distill-Llama-8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) | 22.14 | 44.04 | 36.43 |
160 | | [Open-Reasoner-Zero-7B](https://huggingface.co/Open-Reasoner-Zero/Open-Reasoner-Zero-7B) | 13.57 | 32.26 | 51.90 |
161 | | [Qwen2.5-Math-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Math-7B-Instruct) | 9.05 | 25.60 | 37.14 |
162 |
163 |
164 | ## **≥ 32B Models**
165 | | Model | HAcc | SAcc | Correctness (w/ const.) |
166 | | ----------------------------- | ----- | ----- | ----------------------- |
167 | | [Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) | 43.81 | 62.82 | 70.00 |
168 | | [DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | 42.62 | 60.91 | 57.62 |
169 | | [DeepSeek-R1-Distill-Llama-70B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B) | 41.43 | 61.07 | 54.05 |
170 | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | 40.24 | 59.99 | 68.81 |
171 | | [OlympicCoder-32B](https://huggingface.co/open-r1/OlympicCoder-32B) | 35.95 | 57.97 | 54.52 |
172 | | [s1-32B](https://huggingface.co/simplescaling/s1-32B) | 20.95 | 41.78 | 60.95 |
173 | | [Open-Reasoner-Zero-32B](https://huggingface.co/Open-Reasoner-Zero/Open-Reasoner-Zero-32B) | 15.47 | 35.52 | 67.62 |
174 |
175 |
176 |
177 |
178 | # 🌻Acknowledgements
179 |
180 | MathIF is inspired by prior work on [IFEval](https://huggingface.co/datasets/google/IFEval) and [ComplexBench](https://github.com/thu-coai/ComplexBench), and leverages [vLLM](https://github.com/vllm-project/vllm) for efficient inference.
181 |
182 | # 📖Citation
183 |
184 | ```bibtex
185 | @article{fu2025scaling,
186 | title={Scaling Reasoning, Losing Control: Evaluating Instruction Following in Large Reasoning Models},
187 | author={Fu, Tingchen and Gu, Jiawei and Li, Yafu and Qu, Xiaoye and Cheng, Yu},
188 | journal={arXiv preprint arXiv:2505.14810},
189 | year={2025}
190 | }
191 | ```
192 |
193 | # 📬Contact
194 |
195 | For questions, feedback, or collaboration inquiries, please contact:
196 | - **Tingchen Fu**: lucas.futingchen@gmail.com
197 | - **Yafu Li**: yafuly@gmail.com
198 |
--------------------------------------------------------------------------------
/code/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TingchenFu/MathIF/e09c04f4ab40cec0f65105f7b78fcc47a59cf8d2/code/__init__.py
--------------------------------------------------------------------------------
/code/constraint_checker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library of Constraints."""
16 |
17 | import collections
18 | import json
19 | import logging
20 | import random
21 | import re
22 | import string
23 | from typing import Dict, Optional, Sequence, Union
24 |
25 | import langdetect
26 | import nltk
27 |
28 | from constraint_util import LANGUAGE_CODES, generate_keywords, count_words
29 |
30 |
31 |
32 |
33 |
34 | _LANGUAGES = LANGUAGE_CODES
35 |
36 | # The relational operation for comparison.
37 | _COMPARISON_RELATION = ("less than", "at least")
38 |
39 | # The maximum number of sentences.
40 | _MAX_NUM_SENTENCES = 20
41 |
42 | # The number of placeholders.
43 | _NUM_PLACEHOLDERS = 4
44 |
45 | # The number of bullet lists.
46 | _NUM_BULLETS = 5
47 |
48 | # The options of constrained response.
49 | _CONSTRAINED_RESPONSE_OPTIONS = (
50 | "My answer is yes.",
51 | "My answer is no.",
52 | "My answer is maybe.",
53 | )
54 |
55 | # The options of starter keywords.
56 | _STARTER_OPTIONS = (
57 | "I would say",
58 | "My answer is",
59 | "I believe",
60 | "In my opinion",
61 | "I think",
62 | "I reckon",
63 | "I feel",
64 | "From my perspective",
65 | "As I see it",
66 | "According to me",
67 | "As far as I'm concerned",
68 | "To my understanding",
69 | "In my view",
70 | "My take on it is",
71 | "As per my perception",
72 | )
73 |
74 | # The options of ending keywords.
75 | # TODO(jeffreyzhou) add more ending options
76 | _ENDING_OPTIONS = ("Any other questions?", "Is there anything else I can help with?")
77 |
78 | # The number of highlighted sections.
79 | _NUM_HIGHLIGHTED_SECTIONS = 4
80 |
81 | # The section splitter.
82 | _SECTION_SPLITER = ("Section", "SECTION")
83 |
84 | # The number of sections.
85 | _NUM_SECTIONS = 5
86 |
87 | # The number of paragraphs.
88 | _NUM_PARAGRAPHS = 5
89 |
90 | # The postscript marker.
91 | _POSTSCRIPT_MARKER = ("P.S.", "P.P.S")
92 |
93 | # The number of keywords.
94 | _NUM_KEYWORDS = 2
95 |
96 | # The occurrences of a single keyword.
97 | _KEYWORD_FREQUENCY = 3
98 |
99 | # The occurrences of a single letter.
100 | _LETTER_FREQUENCY = 10
101 |
102 | # The occurrences of words with all capital letters.
103 | _ALL_CAPITAL_WORD_FREQUENCY = 20
104 |
105 | # The number of words in the response.
106 | # LEVEL_1_NUM_WORDS_LOWER_LIMIT = 100
107 | # LEVEL_1_NUM_WORDS_UPPER_LIMIT = 500
108 |
109 | NUM_WORDS_LOWER_LIMIT = 128
110 | NUM_WORDS_UPPER_LIMIT = 1024
111 |
112 |
113 |
114 | class Constraint:
115 |     """A Constraint template."""
116 |
117 | def __init__(self, constraint_id=0):
118 | self.id = constraint_id
119 |
120 | def build_description(self, **kwargs):
121 | raise NotImplementedError("`build_description` not implemented.")
122 |
123 | def get_constraint_args(self):
124 | raise NotImplementedError("`get_constraint_args` not implemented.")
125 |
126 | def get_constraint_args_keys(self):
127 | raise NotImplementedError("`get_constraint_args_keys` not implemented.")
128 |
129 | def check_following(self, value):
130 | raise NotImplementedError("`check_following` not implemented.")
131 |
132 |
133 | class ResponseLanguageChecker(Constraint):
134 | """Check the language of the entire response."""
135 |
136 | def build_description(self, *, language=None):
137 | """Build the Constraint description.
138 |
139 | Args:
140 | language: A string representing the expected language of the response. The
141 |         language has to comply with the 97 types defined in
142 | `langid.py` (https://pypi.org/project/langid/1.1.5/), which follows
143 | ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes);
144 | for example, `en` for English, `zh` for Chinese, `fr` for French.
145 |
146 | Returns:
147 | A string representing the Constraint description.
148 | """
149 | self._language = language
150 | if self._language is None:
151 | self._language = random.choice(list(_LANGUAGES.keys()))
152 |
153 | self._description_pattern = (
154 | "Your answer should be in {language} language, no other language is allowed. "
155 | )
156 | return self._description_pattern.format(language=_LANGUAGES[self._language])
157 |
158 | def get_constraint_args(self):
159 | """Returns the keyword args of `build_description`."""
160 | return {"language": self._language}
161 |
162 |
163 | def check_following(self, value):
164 | """Check if the language of the entire response follows the Constraint.
165 |
166 | Args:
167 | value: A string representing the response.
168 |
169 | Returns:
170 | True if the language of `value` follows Constraint; otherwise False.
171 | """
172 | assert isinstance(value, str)
173 |
174 | try:
175 | return langdetect.detect(value) == self._language
176 | except langdetect.LangDetectException as e:
177 |             # Treat detection failure as the Constraint being followed.
178 | logging.error(
179 | "Unable to detect language for text %s due to %s", value, e
180 | ) # refex: disable=pytotw.037
181 | return True
182 |
183 |
184 | # class NumberOfSentences(Constraint):
185 | # """Check the number of sentences."""
186 |
187 | # def build_description(self, *, num_sentences=None, relation=None):
188 | # """Build the Constraint description.
189 |
190 | # Args:
191 | # num_sentences: An integer specifying the number of sentences as a
192 | # threshold.
193 | # relation: A string in (`less than`, `at least`), defining the relational
194 | # operator for comparison.
195 | # Two relational comparisons are supported for now:
196 | # if 'less than', the actual number of sentences < the threshold;
197 | # if 'at least', the actual number of sentences >= the threshold.
198 |
199 | # Returns:
200 | # A string representing the Constraint description.
201 | # """
202 | # # The number of sentences as a threshold for comparison.
203 | # self._num_sentences_threshold = num_sentences
204 | # if self._num_sentences_threshold is None or self._num_sentences_threshold < 0:
205 | # self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES)
206 |
207 | # if relation is None:
208 | # self._comparison_relation = random.choice(_COMPARISON_RELATION)
209 | # elif relation not in _COMPARISON_RELATION:
210 | # raise ValueError(
211 | # "The supported relation for comparison must be in "
212 | # f"{_COMPARISON_RELATION}, but {relation} is given."
213 | # )
214 | # else:
215 | # self._comparison_relation = relation
216 |
217 | # self._description_pattern = (
218 | # "Your response should contain {relation} {num_sentences} sentences."
219 | # )
220 | # return self._description_pattern.format(
221 | # relation=self._comparison_relation,
222 | # num_sentences=self._num_sentences_threshold,
223 | # )
224 |
225 | # def get_constraint_args(self):
226 | # """Returns the keyword args of `build_description`."""
227 | # return {
228 | # "num_sentences": self._num_sentences_threshold,
229 | # "relation": self._comparison_relation,
230 | # }
231 |
232 | # def get_constraint_args_keys(self):
233 | # """Returns the args keys of `build_description`."""
234 | # return ["num_sentences", "relation"]
235 |
236 | # def check_following(self, value):
237 | # """Check if the number of sentences follows the Constraint.
238 |
239 | # Args:
240 | # value: A string representing the response.
241 |
242 | # Returns:
243 | # True if the response follows the Constraint.
244 |
245 | # Raise:
246 | # ValueError if the string in `Constraint_args` is not in
247 | # [`less_than`, `at_least`].
248 | # """
249 | # num_sentences = Constraints_util.count_sentences(value)
250 | # if self._comparison_relation == _COMPARISON_RELATION[0]:
251 | # return num_sentences < self._num_sentences_threshold
252 | # elif self._comparison_relation == _COMPARISON_RELATION[1]:
253 | # return num_sentences >= self._num_sentences_threshold
254 |
255 |
256 | # class PlaceholderChecker(Constraint):
257 | # """Check the placeholders in template writing."""
258 |
259 | # def build_description(self, *, num_placeholders=None):
260 | # """Build the Constraint description.
261 |
262 | # Args:
263 | # num_placeholders: An integer denoting the minimum number of
264 | # placeholders required in the response.
265 |
266 | # Returns:
267 | # A string representing the Constraint description.
268 | # """
269 | # self._num_placeholders = num_placeholders
270 | # if self._num_placeholders is None or self._num_placeholders < 0:
271 | # self._num_placeholders = random.randint(1, _NUM_PLACEHOLDERS)
272 | # self._description_pattern = (
273 | # "The response must contain at least {num_placeholders} placeholders "
274 | # + "represented by square brackets, such as [address]."
275 | # )
276 | # return self._description_pattern.format(num_placeholders=self._num_placeholders)
277 |
278 | # def get_constraint_args(self):
279 | # """Returns the keyword args of `build_description`."""
280 | # return {"num_placeholders": self._num_placeholders}
281 |
282 | # def get_constraint_args_keys(self):
283 | # """Returns the args keys of `build_description`."""
284 | # return ["num_placeholders"]
285 |
286 | # def check_following(self, value):
287 | # """Check if the number of placeholders follows the Constraint.
288 |
289 | # Args:
290 | # value: A string representing the response.
291 |
292 | # Returns:
293 | # True if the actual number of placeholders in the response is greater than
294 | # or equal to `num_placeholders`; otherwise, False.
295 | # """
296 | # placeholders = re.findall(r"\[.*?\]", value)
297 | # num_placeholders = len(placeholders)
298 | # return num_placeholders >= self._num_placeholders
299 |
300 |
301 | class BulletListChecker(Constraint):
302 |     """Checks the bullet list in the response."""
303 |
304 | def build_description(self, *, num_bullets=None):
305 | """Build the Constraint description.
306 |
307 | Args:
308 | num_bullets: An integer specifying the exact number of bullet lists
309 | that is required to appear in the response.
310 |
311 | Returns:
312 | A string representing the Constraint description.
313 | """
314 | self._num_bullets = num_bullets
315 | if self._num_bullets is None or self._num_bullets < 0:
316 | self._num_bullets = random.randint(1, _NUM_BULLETS)
317 | self._description_pattern = (
318 | "Your answer must contain exactly {num_bullets} bullet points. "
319 | + "Use the markdown bullet points such as:\n"
320 | + "* This is point 1. \n"
321 | + "* This is point 2"
322 | )
323 | return self._description_pattern.format(num_bullets=self._num_bullets)
324 |
325 | def get_constraint_args(self):
326 | """Returns the keyword args of `build_description`."""
327 | return {"num_bullets": self._num_bullets}
328 |
329 | def get_constraint_args_keys(self):
330 | """Returns the args keys of `build_description`."""
331 | return ["num_bullets"]
332 |
333 | def check_following(self, value):
334 | r"""Check if the number of bullet lists meets the requirement.
335 |
336 | Args:
337 | value: A string representing the response. The response is expected to
338 | contain some bullet lists that start with `\*`.
339 |
340 | Returns:
341 | True if the actual number of bullet lists in the response meets the
342 | requirement.
343 | """
344 | bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE)
345 | bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE)
346 | num_bullet_lists = len(bullet_lists) + len(bullet_lists_2)
347 | return num_bullet_lists == self._num_bullets
348 |
349 |
350 | # class ConstrainedResponseChecker(Constraint):
351 | # """Checks the constrained response."""
352 |
353 | # def build_description(self):
354 | # """Build the Constraint description."""
355 | # # A sequence of string(s) representing the options of the expected response.
356 | # self._constrained_responses = _CONSTRAINED_RESPONSE_OPTIONS
357 | # self._description_pattern = (
358 | # "Answer with one of the following options: {response_options}"
359 | # )
360 | # return self._description_pattern.format(
361 | # response_options=self._constrained_responses
362 | # )
363 |
364 | # def get_constraint_args(self):
365 | # """Returns the keyword args of `build_description`."""
366 | # return None
367 |
368 | # def get_constraint_args_keys(self):
369 | # """Returns the args keys of `build_description`."""
370 | # return []
371 |
372 | # def check_following(self, value):
373 | # """Checks if the response matches the constrained options.
374 |
375 | # Args:
376 | # value: A string representing the response.
377 |
378 | # Returns:
379 | # True if the actual response contains one of the options in the constrained
380 | # responses; otherwise False.
381 | # """
382 | # value = value.strip()
383 | # for constrained_response in self._constrained_responses:
384 | # if constrained_response in value:
385 | # return True
386 | # return False
387 |
388 |
389 | # class ConstrainedStartChecker(Constraint):
390 | # """Checks the response start."""
391 |
392 | # def build_description(self, *, starter=None):
393 | # """Build the Constraint description.
394 |
395 | # Args:
396 | # starter: A string representing the keyword that the response should start
397 | # with.
398 |
399 | # Returns:
400 | # A string representing the Constraint description.
401 | # """
402 | # self._starter = starter.strip() if isinstance(starter, str) else starter
403 | # if self._starter is None:
404 | # self._starter = random.choice(_STARTER_OPTIONS)
405 | # self._description_pattern = (
406 | # "During the conversation, when it is your turn, "
407 | # + "please always start with {starter}"
408 | # )
409 | # return self._description_pattern.format(starter=self._starter)
410 |
411 | # def get_constraint_args(self):
412 | # """Returns the keyword args of `build_description`."""
413 | # return {"starter": self._starter}
414 |
415 | # def get_constraint_args_keys(self):
416 | # """Returns the args keys of `build_description`."""
417 | # return ["starter"]
418 |
419 | # def check_following(self, value):
420 | # """Checks if the response starts with the constrained keyword or phrase.
421 |
422 | # Args:
423 | # value: A string representing the response.
424 |
425 | # Returns:
426 | # True if the response starts with the given phrase or keyword that is
427 | # contained in `Constraint_args`; otherwise, False.
428 | # """
429 | # response_pattern = r"^\s*" + self._starter + r".*$"
430 | # response_with_constrained_start = re.search(
431 | # response_pattern, value, flags=re.MULTILINE
432 | # )
433 | # return True if response_with_constrained_start else False
434 |
435 |
436 | class HighlightSectionChecker(Constraint):
437 | """Checks the highlighted section."""
438 |
439 | def build_description(self, *, num_highlights=None):
440 | """Build the Constraint description.
441 |
442 | Args:
443 | num_highlights: An integer specifying the minimum number of highlighted
444 | sections.
445 |
446 | Returns:
447 | A string representing the Constraint description.
448 | """
449 | self._num_highlights = num_highlights
450 | if self._num_highlights is None or self._num_highlights < 0:
451 | self._num_highlights = random.randint(1, _NUM_HIGHLIGHTED_SECTIONS)
452 |
453 | self._description_pattern = (
454 | "Highlight at least {num_highlights} sections in your answer with "
455 | + "markdown, i.e. *highlighted section*."
456 | )
457 |
458 | return self._description_pattern.format(num_highlights=self._num_highlights)
459 |
460 | def get_constraint_args(self):
461 | """Returns the keyword args of `build_description`."""
462 | return {"num_highlights": self._num_highlights}
463 |
464 | def get_constraint_args_keys(self):
465 | """Returns the args keys of `build_description`."""
466 | return ["num_highlights"]
467 |
468 | def check_following(self, value):
469 | """Checks if the number of highlighted sections meets the requirement.
470 |
471 | Args:
472 | value: a string representing the response. The response is expected to
473 | contain highlighted sections in the format of *highlighted*.
474 |
475 | Returns:
476 | True if the actual number of highlighted sections in the format of
477 | *highlighted sections* meets the minimum requirement; otherwise False.
478 | """
479 | num_highlights = 0
480 | highlights = re.findall(r"\*[^\n\*]*\*", value)
481 | double_highlights = re.findall(r"\*\*[^\n\*]*\*\*", value)
482 | for highlight in highlights:
483 | if highlight.strip("*").strip():
484 | num_highlights += 1
485 | for highlight in double_highlights:
486 | if highlight.removeprefix("**").removesuffix("**").strip():
487 | num_highlights += 1
488 |
489 | return num_highlights >= self._num_highlights
490 |
491 |
492 | class SectionChecker(Constraint):
493 | """Checks the sections."""
494 |
495 | def build_description(self, *, section_spliter=None, num_sections=None):
496 | """Build the Constraint description.
497 |
498 | Args:
499 | section_spliter: A string represents the section spliter keyword that
500 | marks a new section, i.e., `Section` or `SECTION`.
501 | num_sections: An integer specifying the number of sections.
502 |
503 | Returns:
504 | A string representing the Constraint description.
505 | """
506 | self._section_spliter = (
507 | section_spliter.strip()
508 | if isinstance(section_spliter, str)
509 | else section_spliter
510 | )
511 | if self._section_spliter is None:
512 | self._section_spliter = random.choice(_SECTION_SPLITER)
513 |
514 | self._num_sections = num_sections
515 | if self._num_sections is None or self._num_sections < 0:
516 | self._num_sections = random.randint(1, _NUM_SECTIONS)
517 |
518 | self._description_pattern = (
519 | "Your response must have {num_sections} sections. Mark the beginning "
520 | + "of each section with {section_spliter} X, such as:\n"
521 | + "{section_spliter} 1\n"
522 | + "[content of section 1]\n"
523 | + "{section_spliter} 2\n"
524 | + "[content of section 2]"
525 | )
526 |
527 | return self._description_pattern.format(
528 | num_sections=self._num_sections, section_spliter=self._section_spliter
529 | )
530 |
531 | def get_constraint_args(self):
532 | """Returns the keyword args of `build_description`."""
533 | return {
534 | "section_spliter": self._section_spliter,
535 | "num_sections": self._num_sections,
536 | }
537 |
538 | def get_constraint_args_keys(self):
539 | """Returns the args keys of `build_description`."""
540 | return ["section_spliter", "num_sections"]
541 |
542 | def check_following(self, value):
543 | """Checks the response contains multiple sections.
544 |
545 | Args:
546 | value: A string representing the response. The response is expected
547 | to contain multiple sections (number of sections is greater than 1).
548 | A new section starts with `Section 1`, where the number denotes the
549 | section index.
550 |
551 | Returns:
552 | True if the number of sections in the response is greater than or equal to
553 | the minimum number of sections; otherwise, False.
554 | """
555 |         section_splitter_pattern = r"\s?" + self._section_spliter + r"\s?\d+\s?"
556 |         sections = re.split(section_splitter_pattern, value)
557 | num_sections = len(sections) - 1
558 | return num_sections >= self._num_sections
559 |
560 |
561 | # class ParagraphChecker(Constraint):
562 | # """Checks the paragraphs."""
563 |
564 | # def build_description(self, *, num_paragraphs=None):
565 | # """Build the Constraint description.
566 |
567 | # Args:
568 | # num_paragraphs: An integer specifying the number of paragraphs.
569 |
570 | # Returns:
571 | # A string representing the Constraint description.
572 | # """
573 | # self._num_paragraphs = num_paragraphs
574 | # if self._num_paragraphs is None or self._num_paragraphs < 0:
575 | # self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS)
576 |
577 | # self._description_pattern = (
578 | # "There should be {num_paragraphs} paragraphs. "
579 | # + "Paragraphs are separated with the markdown divider: ***"
580 | # )
581 |
582 | # return self._description_pattern.format(num_paragraphs=self._num_paragraphs)
583 |
584 | # def get_constraint_args(self):
585 | # """Returns the keyword args of `build_description`."""
586 | # return {"num_paragraphs": self._num_paragraphs}
587 |
588 | # def get_constraint_args_keys(self):
589 | # """Returns the args keys of `build_description`."""
590 | # return ["num_paragraphs"]
591 |
592 | # def check_following(self, value):
593 | # """Checks the response contains required number of paragraphs.
594 |
595 | # Args:
596 | # value: A string representing the response. The response may contain
597 | # paragraphs that are separated by the markdown divider: `***`.
598 |
599 | # Returns:
600 | # True if the actual number of paragraphs is the same as required;
601 | # otherwise, False.
602 | # """
603 | # paragraphs = re.split(r"\s?\*\*\*\s?", value)
604 | # num_paragraphs = len(paragraphs)
605 |
606 | # for index, paragraph in enumerate(paragraphs):
607 | # if not paragraph.strip():
608 | # if index == 0 or index == len(paragraphs) - 1:
609 | # num_paragraphs -= 1
610 | # else:
611 | # return False
612 |
613 | # return num_paragraphs == self._num_paragraphs
614 |
615 |
616 | # class PostscriptChecker(Constraint):
617 | # """Checks the postscript."""
618 |
619 | # def build_description(self, *, postscript_marker=None):
620 | # """Build the Constraint description.
621 |
622 | # Args:
623 | # postscript_marker: A string containing the keyword that marks the start
624 | # of the postscript section.
625 |
626 | # Returns:
627 | # A string representing the Constraint description.
628 | # """
629 | # self._postscript_marker = (
630 | # postscript_marker.strip()
631 | # if isinstance(postscript_marker, str)
632 | # else postscript_marker
633 | # )
634 | # if self._postscript_marker is None:
635 | # self._postscript_marker = random.choice(_POSTSCRIPT_MARKER)
636 |
637 | # self._description_pattern = (
638 | # "At the end of your response, please explicitly add a postscript "
639 | # + "starting with {postscript}"
640 | # )
641 |
642 | # return self._description_pattern.format(postscript=self._postscript_marker)
643 |
644 | # def get_constraint_args(self):
645 | # """Returns the keyword args of `build_description`."""
646 | # return {"postscript_marker": self._postscript_marker}
647 |
648 | # def get_constraint_args_keys(self):
649 | # """Returns the args keys of `build_description`."""
650 | # return ["postscript_marker"]
651 |
652 | # def check_following(self, value):
653 | # """Checks if the response follows the postscript format.
654 |
655 | # Args:
656 | # value: a string representing the response. The response is expected to
657 | # contain a postscript section.
658 |
659 | # Returns:
660 | # True if the response contains a postscript section starting with
661 | # the keyword containing in the `Constraint_args`; otherwise False.
662 | # """
663 | # value = value.lower()
664 | # if self._postscript_marker == "P.P.S":
665 | # postscript_pattern = r"\s*p\.\s?p\.\s?s.*$"
666 | # elif self._postscript_marker == "P.S.":
667 | # postscript_pattern = r"\s*p\.\s?s\..*$"
668 | # else:
669 | # postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$"
670 | # postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE)
671 | # return True if postscript else False
672 |
673 |
674 | # class RephraseChecker(Constraint):
675 | # """Checks the rephrase."""
676 |
677 | # def build_description(self, *, original_message):
678 | # """Build the Constraint description.
679 |
680 | # Args:
681 | # original_message: A string representing the original message. The
682 | # rephrased response should only change its words/sentences in between
683 | # its two asterisks, for example, *change me*. Both original and rephrased
684 | # messages should contain the changes in the form of *change me*.
685 |
686 | # Returns:
687 | # A string representing the Constraint description.
688 | # """
689 | # if not self.is_change(original_message):
690 | # raise ValueError(
691 | # f"Message {original_message} does not contain changes "
692 | # "in the form of *change me*."
693 | # )
694 |
695 | # self._reference_without_change = original_message
696 | # self._description = (
697 | # "Rephrasing: Your rephrased response should only"
698 | # + "change the words/sentences in between two asterisks"
699 | # + "such as *change me*."
700 | # )
701 | # return self._description
702 |
703 | # def get_constraint_args(self):
704 | # """Returns the keyword args of `build_description`."""
705 | # return {"original_message": self._reference_without_change}
706 |
707 | # def get_constraint_args_keys(self):
708 | # """Returns the args keys of `build_description`."""
709 | # return ["original_message"]
710 |
711 | # def check_following(self, value):
712 | # r"""Checks if the rephrasing follows the Constraint.
713 |
714 | # Args:
715 | #       value: A string representing the response, which is expected to rephrase
716 | # the string of `Constraint_args`.
717 |
718 | # Returns:
719 | # True if `value` and `Constraint_args` only differ by the words/sentences
720 | # in between two asterisks such as *change me*; otherwise, False.
721 | # """
722 |
723 | # if not self.is_change(value):
724 | # raise ValueError(
725 | # f"value {value} does not contain changes in the form of *change me*."
726 | # )
727 |
728 | # response_without_changes = self.strip_changes(value)
729 | # reference_without_changes = self.strip_changes(self._reference_without_change)
730 |
731 | # return response_without_changes == reference_without_changes
732 |
733 | # def is_change(self, response):
734 | # """Check if there is change in the response in the form of *change me*."""
735 | # return re.search(r"\*.*\*", response)
736 |
737 | # def strip_changes(self, response):
738 | # """Strips off the changes."""
739 | # return re.sub(r"\*.*\*", "", response)
740 |
741 |
742 | class KeywordChecker(Constraint):
743 |   """Check the existence of certain keywords."""
744 |
745 | def build_description(self, *, keywords=None):
746 | """Build the Constraint description.
747 |
748 | Args:
749 | keywords: A sequence of strings representing the keywords that are
750 | expected in the response.
751 |
752 | Returns:
753 | A string representing the Constraint description.
754 | """
755 |
756 | if not keywords:
757 | self._keywords = generate_keywords(
758 | num_keywords=_NUM_KEYWORDS
759 | )
760 | else:
761 | self._keywords = keywords
762 | self._keywords = sorted(self._keywords)
763 |
764 | self._description_pattern = "Include keywords \"{keywords}\" in the response."
765 |
766 | return self._description_pattern.format(keywords=self._keywords)
767 |
768 | def get_constraint_args(self):
769 | """Returns the keyword args of `build_description`."""
770 | return {"keywords": self._keywords}
771 |
772 | def get_constraint_args_keys(self):
773 | """Returns the args keys of `build_description`."""
774 | return ["keywords"]
775 |
776 |   def check_following(self, value):
777 |     """Check if the response contains the expected keywords."""
778 |     for keyword in self._keywords:
779 |       if not re.search(re.escape(keyword), value, flags=re.IGNORECASE):
780 |         return False
781 |     return True
782 |
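For reference, the keyword check above boils down to one case-insensitive regex search per keyword. A minimal standalone sketch (the helper name `contains_all_keywords` is illustrative, not part of this module):

```python
import re

def contains_all_keywords(text, keywords):
    """Return True iff every keyword occurs in `text`, case-insensitively."""
    # re.escape keeps keywords containing regex metacharacters literal.
    return all(
        re.search(re.escape(keyword), text, flags=re.IGNORECASE)
        for keyword in keywords
    )
```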
783 |
784 | class KeywordFrequencyChecker(Constraint):
785 | """Check the keyword frequency."""
786 |
787 | def build_description(self, *, keyword=None, frequency=None, relation=None):
788 | """Build the Constraint description.
789 |
790 | Args:
791 | keyword: A string representing a keyword that is expected in the response.
792 | frequency: An integer specifying the number of times `keyword` is expected
793 | to appear in the response.
794 | relation: A string in (`less than`, `at least`), defining the relational
795 | operator for comparison.
796 | Two relational comparisons are supported for now:
797 | if 'less than', the actual number of occurrences < frequency;
798 | if 'at least', the actual number of occurrences >= frequency.
799 |
800 | Returns:
801 | A string representing the Constraint description.
802 | """
803 | if not keyword:
804 | self._keyword = generate_keywords(num_keywords=1)[0]
805 | else:
806 | self._keyword = keyword.strip()
807 |
808 | self._frequency = frequency
809 | if self._frequency is None or self._frequency < 0:
810 | self._frequency = random.randint(1, _KEYWORD_FREQUENCY)
811 |
812 | if relation is None:
813 | self._comparison_relation = random.choice(_COMPARISON_RELATION)
814 | elif relation not in _COMPARISON_RELATION:
815 | raise ValueError(
816 | "The supported relation for comparison must be in "
817 | f"{_COMPARISON_RELATION}, but {relation} is given."
818 | )
819 | else:
820 | self._comparison_relation = relation
821 |
822 | self._description_pattern = (
823 | "In your response, the word \"{keyword}\" should appear {relation} "
824 | + "{frequency} times."
825 | )
826 |
827 | return self._description_pattern.format(
828 | keyword=self._keyword,
829 | relation=self._comparison_relation,
830 | frequency=self._frequency,
831 | )
832 |
833 | def get_constraint_args(self):
834 | """Returns the keyword args of `build_description`."""
835 | return {
836 | "keyword": self._keyword,
837 | "frequency": self._frequency,
838 | "relation": self._comparison_relation,
839 | }
840 |
841 | def get_constraint_args_keys(self):
842 | """Returns the args keys of `build_description`."""
843 | return ["keyword", "frequency", "relation"]
844 |
845 |   def check_following(self, value):
846 |     """Checks if the response contains the keyword with the required frequency."""
847 |     actual_occurrences = len(re.findall(re.escape(self._keyword), value, flags=re.IGNORECASE))
848 | 
849 |     if self._comparison_relation == _COMPARISON_RELATION[0]:
850 |       return actual_occurrences < self._frequency
851 |     elif self._comparison_relation == _COMPARISON_RELATION[1]:
852 |       return actual_occurrences >= self._frequency
853 |
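The frequency check counts non-overlapping, case-insensitive matches. A standalone sketch (illustrative helper name):

```python
import re

def keyword_occurrences(text, keyword):
    """Count non-overlapping, case-insensitive occurrences of `keyword`."""
    # Matches inside longer words also count ("cat" matches in "CATALOG"),
    # mirroring the plain re.findall behaviour used by the checker.
    return len(re.findall(re.escape(keyword), text, flags=re.IGNORECASE))
```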
854 |
855 | class NumberOfWords(Constraint):
856 | """Checks the number of words."""
857 |
858 | def build_description(self, *, num_words=None, relation=None):
859 | """Build the Constraint description.
860 |
861 | Args:
862 | num_words: An integer specifying the number of words contained in the
863 | response.
864 | relation: A string in (`less than`, `at least`), defining the relational
865 | operator for comparison.
866 | Two relational comparisons are supported for now:
867 | if 'less than', the actual number of words < num_words;
868 | if 'at least', the actual number of words >= num_words.
869 |
870 | Returns:
871 | A string representing the Constraint description.
872 | """
873 |
874 | self._num_words = num_words
875 | if self._num_words is None or self._num_words < 0:
876 | self._num_words = random.randint(
877 | NUM_WORDS_LOWER_LIMIT, NUM_WORDS_UPPER_LIMIT
878 | )
879 |
880 | if relation is None:
881 | self._comparison_relation = random.choice(_COMPARISON_RELATION)
882 | elif relation not in _COMPARISON_RELATION:
883 | raise ValueError(
884 | "The supported relation for comparison must be in "
885 | f"{_COMPARISON_RELATION}, but {relation} is given."
886 | )
887 | else:
888 | self._comparison_relation = relation
889 |
890 | self._description_pattern = "Answer with {relation} {num_words} words."
891 |
892 | return self._description_pattern.format(
893 | relation=self._comparison_relation, num_words=self._num_words
894 | )
895 |
896 | def get_constraint_args(self):
897 | """Returns the keyword args of `build_description`."""
898 | return {"num_words": self._num_words, "relation": self._comparison_relation}
899 |
900 | def get_constraint_args_keys(self):
901 | """Returns the args keys of `build_description`."""
902 | return ["num_words", "relation"]
903 |
904 | def check_following(self, value):
905 | """Checks if the response contains the expected number of words."""
906 | num_words = count_words(value)
907 |
908 | if self._comparison_relation == _COMPARISON_RELATION[0]:
909 | return num_words < self._num_words
910 | elif self._comparison_relation == _COMPARISON_RELATION[1]:
911 | return num_words >= self._num_words
912 |
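`NumberOfWords` delegates counting to `count_words` from `constraint_util` (not shown here); the comparison logic itself is simple. A sketch using a naive whitespace tokenizer, which may differ from the repo's tokenizer:

```python
def check_num_words(text, num_words, relation):
    """Compare the word count of `text` against `num_words` under `relation`."""
    # Naive whitespace tokenization; constraint_util.count_words may differ.
    count = len(text.split())
    if relation == "less than":
        return count < num_words
    if relation == "at least":
        return count >= num_words
    raise ValueError(f"Unsupported relation: {relation}")
```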
913 |
914 | # class JsonFormat(Constraint):
915 | # """Check the Json format."""
916 |
917 | # def build_description(self):
918 | # self._description_pattern = (
919 | # "Entire output should be wrapped in JSON format. You can use markdown"
920 | # " ticks such as ```."
921 | # )
922 | # return self._description_pattern
923 |
924 | # def get_constraint_args(self):
925 | # """Returns the keyword args of `build_description`."""
926 | # return None
927 |
928 | # def get_constraint_args_keys(self):
929 | # """Returns the args keys of `build_description`."""
930 | # return []
931 |
932 | # def check_following(self, value):
933 | # value = (
934 | # value.strip()
935 | # .removeprefix("```json")
936 | # .removeprefix("```Json")
937 | # .removeprefix("```JSON")
938 | # .removeprefix("```")
939 | # .removesuffix("```")
940 | # .strip()
941 | # )
942 | # try:
943 | # json.loads(value)
944 | # except ValueError:
945 | # return False
946 | # return True
947 |
948 |
949 | # class ParagraphFirstWordCheck(Constraint):
950 | # """Check the paragraph and the first word of the nth paragraph."""
951 |
952 | # def build_description(
953 | # self, num_paragraphs=None, nth_paragraph=None, first_word=None
954 | # ):
955 | # r"""Build the Constraint description.
956 |
957 | # Args:
958 | # num_paragraphs: An integer indicating the number of paragraphs expected
959 | # in the response. A paragraph is a subset of the string that is
960 | # expected to be separated by '\n\n'.
961 | # nth_paragraph: An integer indicating the paragraph number that we look at.
962 | # Note that n starts from 1.
963 | #       first_word: A string that represents the first word of the nth paragraph.
964 |
965 | # Returns:
966 | # A string representing the Constraint description.
967 | # """
968 | # self._num_paragraphs = num_paragraphs
969 | # if self._num_paragraphs is None or self._num_paragraphs < 0:
970 | # self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS)
971 |
972 | # self._nth_paragraph = nth_paragraph
973 | # if (
974 | # self._nth_paragraph is None
975 | # or self._nth_paragraph <= 0
976 | # or self._nth_paragraph > self._num_paragraphs
977 | # ):
978 | # self._nth_paragraph = random.randint(1, self._num_paragraphs + 1)
979 |
980 | # self._first_word = first_word
981 | # if self._first_word is None:
982 | # self._first_word = generate_keywords(num_keywords=1)[0]
983 | # self._first_word = self._first_word.lower()
984 |
985 | # self._description_pattern = (
986 | # "There should be {num_paragraphs} paragraphs. "
987 | # + "Paragraphs and only paragraphs are separated with each other by two "
988 | # + "new lines as if it was '\\n\\n' in python. "
989 | # + "Paragraph {nth_paragraph} must start with word {first_word}."
990 | # )
991 |
992 | # return self._description_pattern.format(
993 | # num_paragraphs=self._num_paragraphs,
994 | # nth_paragraph=self._nth_paragraph,
995 | # first_word=self._first_word,
996 | # )
997 |
998 | # def get_constraint_args(self):
999 | # """Returns the keyword args of `build_description`."""
1000 | # return {
1001 | # "num_paragraphs": self._num_paragraphs,
1002 | # "nth_paragraph": self._nth_paragraph,
1003 | # "first_word": self._first_word,
1004 | # }
1005 |
1006 | # def get_constraint_args_keys(self):
1007 | # """Returns the args keys of `build_description`."""
1008 | # return ["num_paragraphs", "nth_paragraph", "first_word"]
1009 |
1010 | # def check_following(self, value):
1011 | # """Checks for required number of paragraphs and correct first word.
1012 |
1013 | # Args:
1014 | # value: a string representing the response. The response may contain
1015 | # paragraphs that are separated by two new lines and the first word of
1016 | # the nth paragraph will have to match a specified word.
1017 |
1018 | # Returns:
1019 | # True if the number of paragraphs is the same as required and the first
1020 | # word of the specified paragraph is the same as required. Otherwise, false.
1021 | # """
1022 |
1023 | # paragraphs = re.split(r"\n\n", value)
1024 | # num_paragraphs = len(paragraphs)
1025 |
1026 | # for paragraph in paragraphs:
1027 | # if not paragraph.strip():
1028 | # num_paragraphs -= 1
1029 |
1030 | # # check that index doesn't go out of bounds
1031 | # if self._nth_paragraph <= num_paragraphs:
1032 | # paragraph = paragraphs[self._nth_paragraph - 1].strip()
1033 | # if not paragraph:
1034 | # return False
1035 | # else:
1036 | # return False
1037 |
1038 | # first_word = ""
1039 | # punctuation = {".", ",", "?", "!", "'", '"'}
1040 |
1041 | # # get first word and remove punctuation
1042 | # word = paragraph.split()[0].strip()
1043 | # # TODO(jeffrey): make more complex?
1044 | # word = word.lstrip("'")
1045 | # word = word.lstrip('"')
1046 |
1047 | # for letter in word:
1048 | # if letter in punctuation:
1049 | # break
1050 | # first_word += letter.lower()
1051 |
1052 | # return num_paragraphs == self._num_paragraphs and first_word == self._first_word
1053 |
1054 |
1055 | # # TODO(jeffrey) add relation - at least/at most?
1056 | # class KeySentenceChecker(Constraint):
1057 | # """Check the existence of certain key sentences."""
1058 |
1059 | # def build_description(self, key_sentences=None, num_sentences=None):
1060 | # """Build the Constraint description.
1061 |
1062 | # Args:
1063 | # key_sentences: A sequences of strings representing the key sentences that
1064 | # are expected in the response.
1065 | # num_sentences: The number of key sentences that are expected to be seen in
1066 | # the response.
1067 |
1068 | # Returns:
1069 | # A string representing the Constraint description.
1070 | # """
1071 |
1072 | # if not key_sentences:
1073 | # # TODO(jeffrey) make a generate sentences function? wonderwords package
1074 | # self._key_sentences = set(["For now, this is fine."])
1075 | # else:
1076 | # self._key_sentences = key_sentences
1077 |
1078 | # if not num_sentences:
1079 | # self._num_sentences = random.randint(1, len(self._key_sentences))
1080 | # else:
1081 | # self._num_sentences = num_sentences
1082 |
1083 | # self._description_pattern = (
1084 | # "Include {num_sentences} of the following sentences {key_sentences}"
1085 | # )
1086 |
1087 | # return self._description_pattern.format(
1088 | # num_sentences=self._num_sentences, key_sentences=self._key_sentences
1089 | # )
1090 |
1091 | # def get_constraint_args(self):
1092 | # """Returns the keyword args of `build_description`."""
1093 | # return {
1094 | # "num_sentences": self._num_sentences,
1095 | # "key_sentences": list(self._key_sentences),
1096 | # }
1097 |
1098 | # def get_constraint_args_keys(self):
1099 | # """Returns the args keys of `build_description`."""
1100 | # return ["num_sentences", "key_sentences"]
1101 |
1102 | # def check_following(self, value):
1103 | # """Checks if the response contains the expected key sentences."""
1104 | # count = 0
1105 | # sentences = Constraints_util.split_into_sentences(value)
1106 | # for sentence in self._key_sentences:
1107 | # if sentence in sentences:
1108 | # count += 1
1109 |
1110 | # return count == self._num_sentences
1111 |
1112 |
1113 | class ForbiddenWords(Constraint):
1114 | """Checks that specified words are not used in response."""
1115 |
1116 | def build_description(self, forbidden_words=None):
1117 | """Build the Constraint description.
1118 |
1119 | Args:
1120 | forbidden_words: A sequences of strings representing words that are not
1121 | allowed in the response.
1122 |
1123 | Returns:
1124 | A string representing the Constraint description.
1125 | """
1126 |
1127 | if not forbidden_words:
1128 | self._forbidden_words = generate_keywords(
1129 | num_keywords=_NUM_KEYWORDS
1130 | )
1131 | else:
1132 | self._forbidden_words = list(set(forbidden_words))
1133 | self._forbidden_words = sorted(self._forbidden_words)
1134 | self._description_pattern = (
1135 | "Do not include keywords \"{forbidden_words}\" in the response."
1136 | )
1137 |
1138 | return self._description_pattern.format(forbidden_words=self._forbidden_words)
1139 |
1140 | def get_constraint_args(self):
1141 | """Returns the keyword args of `build_description`."""
1142 | return {"forbidden_words": self._forbidden_words}
1143 |
1144 | def get_constraint_args_keys(self):
1145 | """Returns the args keys of `build_description`."""
1146 | return ["forbidden_words"]
1147 |
1148 |   def check_following(self, value):
1149 |     """Check if the response avoids the forbidden words."""
1150 |     for word in self._forbidden_words:
1151 |       if re.search(r"\b" + re.escape(word) + r"\b", value, flags=re.IGNORECASE):
1152 |         return False
1153 |     return True
1154 |
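Unlike `KeywordChecker`, `ForbiddenWords` anchors the pattern with `\b` word boundaries, so a forbidden word embedded inside a longer word does not trigger a failure. A standalone sketch (illustrative helper name):

```python
import re

def contains_forbidden_word(text, forbidden_words):
    """Return True if any forbidden word appears as a whole word."""
    # \b boundaries stop "cat" from matching inside "concatenate".
    return any(
        re.search(r"\b" + re.escape(word) + r"\b", text, flags=re.IGNORECASE)
        for word in forbidden_words
    )
```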
1155 |
1156 | # class RephraseParagraph(Constraint):
1157 | # """Checks that the paragraph is rephrased."""
1158 |
1159 | # def build_description(self, *, original_paragraph, low, high):
1160 | # """Builds the Constraint description.
1161 |
1162 | # Args:
1163 | #       original_paragraph: A string representing the original paragraph. The
1164 | #         rephrased response should have between `low` and `high` words in common.
1165 | #       low: An integer representing the lower bound of similar words.
1166 | # high: An integer representing the upper bound of similar words.
1167 |
1168 | # Returns:
1169 | # A string representing the Constraint description.
1170 | # """
1171 | # # TODO(jeffrey) make more encompassing
1172 | # self._original_paragraph = original_paragraph
1173 | # self._low = low
1174 | # self._high = high
1175 |
1176 | # self._description = (
1177 | # "Rephrase the following paragraph: "
1178 | # + "{original_paragraph}\nYour response should have "
1179 | # + "between {low} and {high} of the same words. "
1180 | # + "Words are the same if and only if all of the "
1181 | # + "letters, ignoring cases, are the same. For "
1182 | # + "example, 'run' is the same as 'Run' but different "
1183 | # + "to 'ran'."
1184 | # )
1185 |
1186 | # return self._description.format(
1187 | # original_paragraph=original_paragraph, low=self._low, high=self._high
1188 | # )
1189 |
1190 | # def get_constraint_args(self):
1191 | # """Returns the keyword args of `build_description`."""
1192 | # return {
1193 | # "original_paragraph": self._original_paragraph,
1194 | # "low": self._low,
1195 | # "high": self._high,
1196 | # }
1197 |
1198 | # def get_constraint_args_keys(self):
1199 | # """Returns the args keys of `build_description`."""
1200 | # return ["original_paragraph", "low", "high"]
1201 |
1202 | # def check_following(self, value):
1203 | # val_words = re.findall(r"\w+", value.lower())
1204 | # original_words = re.findall(r"\w+", self._original_paragraph.lower())
1205 | # similar_words = 0
1206 |
1207 | # dict_val = collections.Counter(val_words)
1208 | # dict_original = collections.Counter(original_words)
1209 |
1210 | # for word in dict_original:
1211 | # similar_words += min(dict_original[word], dict_val[word])
1212 |
1213 | # return similar_words >= self._low and similar_words <= self._high
1214 |
1215 |
1216 | # class TwoResponsesChecker(Constraint):
1217 | # """Check that two responses were given."""
1218 |
1219 | # def build_description(self):
1220 | # """Build the Constraint description."""
1221 | # self._description_pattern = (
1222 | # "Give two different responses. Responses and only responses should"
1223 | # " be separated by 6 asterisk symbols: ******."
1224 | # )
1225 | # return self._description_pattern
1226 |
1227 | # def get_constraint_args(self):
1228 | # """Returns the keyword args of `build_description`."""
1229 | # return None
1230 |
1231 | # def get_constraint_args_keys(self):
1232 | # """Returns the args keys of `build_description`."""
1233 | # return []
1234 |
1235 | # def check_following(self, value):
1236 | # """Checks if the response has two different answers.
1237 |
1238 | # Args:
1239 | # value: A string representing the response.
1240 |
1241 | # Returns:
1242 | # True if two responses are detected and false otherwise.
1243 | # """
1244 | # valid_responses = list()
1245 | # responses = value.split("******")
1246 | # for index, response in enumerate(responses):
1247 | # if not response.strip():
1248 | # if index != 0 and index != len(responses) - 1:
1249 | # return False
1250 | # else:
1251 | # valid_responses.append(response)
1252 | # return (
1253 | # len(valid_responses) == 2
1254 | # and valid_responses[0].strip() != valid_responses[1].strip()
1255 | # )
1256 |
1257 |
1258 | class RepeatPromptThenAnswer(Constraint):
1259 |   """Checks that the prompt is first repeated and then answered."""
1260 |
1261 | def build_description(self, *, prompt_to_repeat=None):
1262 | """Build the Constraint description.
1263 |
1264 | Args:
1265 | prompt_to_repeat: The prompt that is meant to be repeated.
1266 |
1267 | Returns:
1268 | A string representing the Constraint description.
1269 | """
1270 | if not prompt_to_repeat:
1271 | raise ValueError("prompt_to_repeat must be set.")
1272 | else:
1273 | self._prompt_to_repeat = prompt_to_repeat
1274 | self._description_pattern = (
1275 | "First repeat the request word for word without change,"
1276 | " then give your answer (1. do not say any words or characters"
1277 | " before repeating the request; 2. the request you need to repeat"
1278 | " does not include this sentence)"
1279 | )
1280 | return self._description_pattern
1281 |
1282 | def get_constraint_args(self):
1283 | return {"prompt_to_repeat": self._prompt_to_repeat}
1284 |
1285 | def get_constraint_args_keys(self):
1286 | """Returns the args keys of `build_description`."""
1287 | return ["prompt_to_repeat"]
1288 |
1289 |   def check_following(self, value):
1290 |     """Checks if the response starts by repeating the prompt (case-insensitive)."""
1291 |     value = value.strip().lower()
1292 |     return value.startswith(self._prompt_to_repeat.strip().lower())
1293 |
1294 |
1295 | class EndChecker(Constraint):
1296 |   """Checks that the response ends with a given phrase."""
1297 |
1298 | def build_description(self, *, end_phrase=None):
1299 | """Build the Constraint description.
1300 |
1301 | Args:
1302 | end_phrase: A string representing the phrase the response should end with.
1303 |
1304 | Returns:
1305 | A string representing the Constraint description.
1306 | """
1307 | self._end_phrase = (
1308 | end_phrase.strip() if isinstance(end_phrase, str) else end_phrase
1309 | )
1310 | if self._end_phrase is None:
1311 | self._end_phrase = random.choice(_ENDING_OPTIONS)
1312 | self._description_pattern = (
1313 | "Finish your response with this exact phrase \"{ender}\". "
1314 | "No other words should follow this phrase."
1315 | )
1316 | return self._description_pattern.format(ender=self._end_phrase)
1317 |
1318 | def get_constraint_args(self):
1319 | return {"end_phrase": self._end_phrase}
1320 |
1321 | def get_constraint_args_keys(self):
1322 | """Returns the args keys of `build_description`."""
1323 | return ["end_phrase"]
1324 |
1325 |   def check_following(self, value):
1326 |     """Checks if the response ends with the expected phrase."""
1327 |     value = value.strip().strip('"').lower()
1328 |     end_phrase = self._end_phrase.strip().lower()  # avoid mutating checker state
1329 |     return value.endswith(end_phrase)
1330 |
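The end-phrase check normalizes both sides (trim whitespace, strip outer double quotes from the response, lowercase) before comparing. As a standalone sketch (illustrative helper name):

```python
def ends_with_phrase(response, end_phrase):
    """Case-insensitive check that `response` ends with `end_phrase`."""
    # Same normalization as EndChecker: trim, drop outer quotes, lowercase.
    response = response.strip().strip('"').lower()
    return response.endswith(end_phrase.strip().lower())
```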
1331 |
1332 | # class TitleChecker(Constraint):
1333 | # """Checks the response for a title."""
1334 |
1335 | # def build_description(self):
1336 | # """Build the Constraint description."""
1337 | # self._description_pattern = (
1338 | # "Your answer must contain a title, wrapped in double angular brackets,"
1339 | #         " such as <<poem of joy>>."
1340 | # )
1341 | # return self._description_pattern
1342 |
1343 | # def get_constraint_args(self):
1344 | # return None
1345 |
1346 | # def get_constraint_args_keys(self):
1347 | # """Returns the args keys of `build_description`."""
1348 | # return []
1349 |
1350 | # def check_following(self, value):
1351 | # """Checks if the response contains a title."""
1352 | # pattern = r"<<[^\n]+>>"
1353 | # re_pattern = re.compile(pattern)
1354 | # titles = re.findall(re_pattern, value)
1355 |
1356 | # for title in titles:
1357 | # if title.lstrip("<").rstrip(">").strip():
1358 | # return True
1359 | # return False
1360 |
1361 |
1362 | # class LetterFrequencyChecker(Constraint):
1363 | # """Checks letter frequency."""
1364 |
1365 | # def build_description(self, *, letter=None, let_frequency=None, let_relation=None):
1366 | # """Build the Constraint description.
1367 |
1368 | # Args:
1369 | # letter: A string representing a letter that is expected in the response.
1370 | # let_frequency: An integer specifying the number of times `keyword` is
1371 | # expected to appear in the response.
1372 | # let_relation: A string in (`less than`, `at least`), defining the
1373 | # relational operator for comparison. Two relational comparisons are
1374 | # supported for now; if 'less than', the actual number of
1375 | # occurrences < frequency; if 'at least', the actual number of
1376 | # occurrences >= frequency.
1377 |
1378 | # Returns:
1379 | # A string representing the Constraint description.
1380 | # """
1381 | # if (
1382 | # not letter
1383 | # or len(letter) > 1
1384 | # or ord(letter.lower()) < 97
1385 | # or ord(letter.lower()) > 122
1386 | # ):
1387 | # self._letter = random.choice(list(string.ascii_letters))
1388 | # else:
1389 | # self._letter = letter.strip()
1390 | # self._letter = self._letter.lower()
1391 |
1392 | # self._frequency = let_frequency
1393 | # if self._frequency is None or self._frequency < 0:
1394 | # self._frequency = random.randint(1, _LETTER_FREQUENCY)
1395 |
1396 | # if let_relation is None:
1397 | # self._comparison_relation = random.choice(_COMPARISON_RELATION)
1398 | # elif let_relation not in _COMPARISON_RELATION:
1399 | # raise ValueError(
1400 | # "The supported relation for comparison must be in "
1401 | # f"{_COMPARISON_RELATION}, but {let_relation} is given."
1402 | # )
1403 | # else:
1404 | # self._comparison_relation = let_relation
1405 |
1406 | # self._description_pattern = (
1407 | # "In your response, the letter {letter} should appear {let_relation}"
1408 | # " {let_frequency} times."
1409 | # )
1410 |
1411 | # return self._description_pattern.format(
1412 | # letter=self._letter,
1413 | # let_frequency=self._frequency,
1414 | # let_relation=self._comparison_relation,
1415 | # )
1416 |
1417 | # def get_constraint_args(self):
1418 | # """Returns the keyword args of build description."""
1419 | # return {
1420 | # "letter": self._letter,
1421 | # "let_frequency": self._frequency,
1422 | # "let_relation": self._comparison_relation,
1423 | # }
1424 |
1425 | # def get_constraint_args_keys(self):
1426 | # """Returns the args keys of `build_description`."""
1427 | # return ["letter", "let_frequency", "let_relation"]
1428 |
1429 | # def check_following(self, value):
1430 | # """Checks that the response contains the letter at the right frequency."""
1431 | # value = value.lower()
1432 | # letters = collections.Counter(value)
1433 |
1434 | # if self._comparison_relation == _COMPARISON_RELATION[0]:
1435 | # return letters[self._letter] < self._frequency
1436 | # else:
1437 | # return letters[self._letter] >= self._frequency
1438 |
1439 |
1440 | class CapitalLettersEnglishChecker(Constraint):
1441 |   """Checks that the response is in English and in all capital letters."""
1442 |
1443 | def build_description(self):
1444 | """Build the Constraint description."""
1445 | self._description_pattern = (
1446 | "Your entire response should be in English, and in all capital letters."
1447 | )
1448 | return self._description_pattern
1449 |
1450 | def get_constraint_args(self):
1451 | return None
1452 |
1453 | def get_constraint_args_keys(self):
1454 | """Returns the args keys of `build_description`."""
1455 | return []
1456 |
1457 | def check_following(self, value):
1458 | """Checks that the response is in English and in all capital letters."""
1459 | assert isinstance(value, str)
1460 |
1461 | try:
1462 | return value.isupper() and langdetect.detect(value) == "en"
1463 | except langdetect.LangDetectException as e:
1464 |       # Count the Constraint as followed.
1465 | logging.error(
1466 | "Unable to detect language for text %s due to %s", value, e
1467 | ) # refex: disable=pytotw.037
1468 | return True
1469 |
1470 |
1471 | class LowercaseLettersEnglishChecker(Constraint):
1472 |   """Checks that the response is in English and in all lowercase letters."""
1473 |
1474 | def build_description(self):
1475 | """Build the Constraint description."""
1476 | self._description_pattern = (
1477 | "Your entire response should be in English, and in all lowercase"
1478 | " letters. No capital letters are allowed."
1479 | )
1480 | return self._description_pattern
1481 |
1482 | def get_constraint_args(self):
1483 | return None
1484 |
1485 | def get_constraint_args_keys(self):
1486 | """Returns the args keys of `build_description`."""
1487 | return []
1488 |
1489 | def check_following(self, value):
1490 | """Checks that the response is in English and in all lowercase letters."""
1491 | assert isinstance(value, str)
1492 |
1493 | try:
1494 | return value.islower() and langdetect.detect(value) == "en"
1495 | except langdetect.LangDetectException as e:
1496 |       # Count the Constraint as followed.
1497 | logging.error(
1498 | "Unable to detect language for text %s due to %s", value, e
1499 | ) # refex: disable=pytotw.037
1500 | return True
1501 |
1502 |
1503 | class CommaChecker(Constraint):
1504 | """Checks the response for no commas."""
1505 |
1506 | def build_description(self):
1507 | """Build the Constraint description."""
1508 | self._description_pattern = (
1509 | "In your entire response, refrain from the use of any commas."
1510 | )
1511 | return self._description_pattern
1512 |
1513 | def get_constraint_args(self):
1514 | return None
1515 |
1516 | def get_constraint_args_keys(self):
1517 | """Returns the args keys of `build_description`."""
1518 | return []
1519 |
1520 | def check_following(self, value):
1521 | """Checks that the response does not contain commas."""
1522 |     return "," not in value
1523 |
1524 |
1525 | class CapitalWordFrequencyChecker(Constraint):
1526 | """Checks frequency of words with all capital letters."""
1527 |
1528 | def build_description(
1529 | self,
1530 | capital_frequency=None,
1531 | capital_relation=None,
1532 | ):
1533 | """Build the Constraint description.
1534 |
1535 | Args:
1536 | capital_frequency: An integer that represents the number of words that
1537 | should be in all capital letters.
1538 | capital_relation: A string that is 'at least' or 'at most' that refers to
1539 | the frequency.
1540 |
1541 | Returns:
1542 | A string representing the Constraint description.
1543 | """
1544 | self._frequency = capital_frequency
1545 | if self._frequency is None:
1546 | self._frequency = random.randint(1, _ALL_CAPITAL_WORD_FREQUENCY)
1547 |
1548 | self._comparison_relation = capital_relation
1549 | if capital_relation is None:
1550 | self._comparison_relation = random.choice(_COMPARISON_RELATION)
1551 | elif capital_relation not in _COMPARISON_RELATION:
1552 | raise ValueError(
1553 | "The supported relation for comparison must be in "
1554 | f"{_COMPARISON_RELATION}, but {capital_relation} is given."
1555 | )
1556 |
1557 | self._description_pattern = (
1558 | "In your response, words with all capital letters should appear"
1559 | " {relation} {frequency} times."
1560 | )
1561 |
1562 | return self._description_pattern.format(
1563 | frequency=self._frequency, relation=self._comparison_relation
1564 | )
1565 |
1566 | def get_constraint_args(self):
1567 | """Returns the keyword args of build description."""
1568 | return {
1569 | "capital_frequency": self._frequency,
1570 | "capital_relation": self._comparison_relation,
1571 | }
1572 |
1573 | def get_constraint_args_keys(self):
1574 | """Returns the args keys of `build_description`."""
1575 | return ["capital_frequency", "capital_relation"]
1576 |
1577 | def check_following(self, value):
1578 | """Checks the frequency of words with all capital letters."""
1579 | # Hyphenated words will count as one word
1580 | words = nltk.word_tokenize(value)
1581 | capital_words = [word for word in words if word.isupper()]
1582 |
1583 | capital_words = len(capital_words)
1584 |
1585 | if self._comparison_relation == _COMPARISON_RELATION[0]:
1586 | return capital_words < self._frequency
1587 | else:
1588 | return capital_words >= self._frequency
1589 |
1590 |
1591 | class QuotationChecker(Constraint):
1592 | """Checks response is wrapped with double quotation marks."""
1593 |
1594 | def build_description(self):
1595 | """Build the Constraint description."""
1596 | self._description_pattern = (
1597 | "Wrap your entire response with double quotation marks. "
1598 | )
1599 | return self._description_pattern
1600 |
1601 | def get_constraint_args(self):
1602 | """Returns the keyword args of build description."""
1603 | return None
1604 |
1605 | def get_constraint_args_keys(self):
1606 | """Returns the args keys of `build_description`."""
1607 | return []
1608 |
1609 | def check_following(self, value):
1610 | """Checks if the response is wrapped with double quotation marks."""
1611 | value = value.strip()
1612 | return len(value) > 1 and value[0] == '"' and value[-1] == '"'
1613 |
--------------------------------------------------------------------------------
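The `QuotationChecker.check_following` logic above is simple enough to sanity-check in isolation; a minimal sketch (the standalone helper name is mine, not part of the repo):

```python
def check_quotation(value: str) -> bool:
    # Mirrors QuotationChecker.check_following: strip surrounding whitespace,
    # then require the response to start and end with a double quote.
    value = value.strip()
    return len(value) > 1 and value[0] == '"' and value[-1] == '"'

print(check_quotation('  "The answer is 42."  '))  # True
print(check_quotation('"'))                        # False: a lone quote is too short
print(check_quotation('No quotes at all'))         # False
```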
/code/constraint_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Registry of all constraint checkers."""
16 |
17 | import constraint_checker
18 |
19 |
20 | _KEYWORD = "keywords:"
21 |
22 | _LANGUAGE = "language:"
23 |
24 | _LENGTH = "length_constraint_checkers:"
25 |
26 | _CONTENT = "detectable_content:"
27 |
28 | _FORMAT = "detectable_format:"
29 |
30 | _MULTITURN = "multi-turn:"
31 |
32 | _COMBINATION = "combination:"
33 |
34 | _STARTEND = "startend:"
35 |
36 | _CHANGE_CASES = "change_case:"
37 |
38 | _PUNCTUATION = "punctuation:"
39 |
40 | INSTRUCTION_DICT = {
41 | _KEYWORD + "existence": constraint_checker.KeywordChecker,
42 | _KEYWORD + "frequency": constraint_checker.KeywordFrequencyChecker,
43 | # TODO(jeffreyzhou): make a proper set of sentences to choose from
44 | # _KEYWORD + "key_sentences": constraint_checker.KeySentenceChecker,
45 | _KEYWORD + "forbidden_words": constraint_checker.ForbiddenWords,
46 | # _KEYWORD + "letter_frequency": constraint_checker.LetterFrequencyChecker,
47 | _LANGUAGE + "response_language": constraint_checker.ResponseLanguageChecker,
48 | # _LENGTH + "number_sentences": constraint_checker.NumberOfSentences,
49 | # _LENGTH + "number_paragraphs": constraint_checker.ParagraphChecker,
50 | _LENGTH + "number_words": constraint_checker.NumberOfWords,
51 | # _LENGTH + "nth_paragraph_first_word": constraint_checker.ParagraphFirstWordCheck,
52 | # _CONTENT + "number_placeholders": constraint_checker.PlaceholderChecker,
53 | # _CONTENT + "postscript": constraint_checker.PostscriptChecker,
54 | _FORMAT + "number_bullet_lists": constraint_checker.BulletListChecker,
55 | # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace
56 | # _CONTENT + "rephrase_paragraph": constraint_checker.RephraseParagraph,
57 | # _FORMAT + "constrained_response": constraint_checker.ConstrainedResponseChecker,
58 | _FORMAT + "number_highlighted_sections": (constraint_checker.HighlightSectionChecker),
59 | _FORMAT + "multiple_sections": constraint_checker.SectionChecker,
60 | # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
61 | # _FORMAT + "rephrase": constraint_checker.RephraseChecker,
62 | # _FORMAT + "json_format": constraint_checker.JsonFormat,
63 | # _FORMAT + "title": constraint_checker.TitleChecker,
64 | # TODO(tianjianlu): Re-enable with specific prompts.
65 | # _MULTITURN + "constrained_start": constraint_checker.ConstrainedStartChecker,
66 | # _COMBINATION + "two_responses": constraint_checker.TwoResponsesChecker,
67 | _COMBINATION + "repeat_prompt": constraint_checker.RepeatPromptThenAnswer,
68 | _STARTEND + "end_checker": constraint_checker.EndChecker,
69 | _STARTEND + "quotation": constraint_checker.QuotationChecker,
70 | _CHANGE_CASES + "capital_word_frequency": constraint_checker.CapitalWordFrequencyChecker,
71 | _CHANGE_CASES + "english_capital": constraint_checker.CapitalLettersEnglishChecker,
72 | _CHANGE_CASES + "english_lowercase": constraint_checker.LowercaseLettersEnglishChecker,
73 | _PUNCTUATION + "no_comma": constraint_checker.CommaChecker,
74 | }
75 |
76 | INSTRUCTION_CONFLICTS = {
77 | _KEYWORD + "existence": {_KEYWORD + "existence"},
78 | _KEYWORD + "frequency": {_KEYWORD + "frequency"},
79 | # TODO(jeffreyzhou): make a proper set of sentences to choose from
80 | # _KEYWORD + "key_sentences": constraint_checker.KeySentenceChecker,
81 | _KEYWORD + "forbidden_words": {_KEYWORD + "forbidden_words"},
82 | #_KEYWORD + "letter_frequency": {_KEYWORD + "letter_frequency"},
83 | _LANGUAGE + "response_language": {
84 | _LANGUAGE + "response_language",
85 | _FORMAT + "multiple_sections",
86 | _KEYWORD + "existence",
87 | _KEYWORD + "frequency",
88 | _KEYWORD + "forbidden_words",
89 | _STARTEND + "end_checker",
90 | _CHANGE_CASES + "english_capital",
91 | _CHANGE_CASES + "english_lowercase",
92 | },
93 | # _LENGTH + "number_sentences": {_LENGTH + "number_sentences"},
94 | # _LENGTH + "number_paragraphs": {
95 | # _LENGTH + "number_paragraphs",
96 | # _LENGTH + "nth_paragraph_first_word",
97 | # _LENGTH + "number_sentences",
98 | # _LENGTH + "nth_paragraph_first_word",
99 | # },
100 | _LENGTH + "number_words": {_LENGTH + "number_words"},
101 | # _LENGTH + "nth_paragraph_first_word": {
102 | # _LENGTH + "nth_paragraph_first_word",
103 | # _LENGTH + "number_paragraphs",
104 | # },
105 | # _CONTENT + "number_placeholders": {_CONTENT + "number_placeholders"},
106 | # _CONTENT + "postscript": {_CONTENT + "postscript"},
107 | _FORMAT + "number_bullet_lists": {_FORMAT + "number_bullet_lists"},
108 | # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace
109 | # _CONTENT + "rephrase_paragraph": constraint_checker.RephraseParagraph,
110 | # _FORMAT + "constrained_response": set(INSTRUCTION_DICT.keys()),
111 | _FORMAT + "number_highlighted_sections": {_FORMAT + "number_highlighted_sections"},
112 | _FORMAT + "multiple_sections": {
113 | _FORMAT + "multiple_sections",
114 | _LANGUAGE + "response_language",
115 | _FORMAT + "number_highlighted_sections",
116 | },
117 | # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
118 | # _FORMAT + "rephrase": constraint_checker.RephraseChecker,
119 | # _FORMAT + "json_format": set(INSTRUCTION_DICT.keys()).difference(
120 | # {_KEYWORD + "forbidden_words", _KEYWORD + "existence"}
121 | # ),
122 | # _FORMAT + "title": {_FORMAT + "title"},
123 | # TODO(tianjianlu): Re-enable with specific prompts.
124 | # _MULTITURN + "constrained_start": constraint_checker.ConstrainedStartChecker,
125 | # _COMBINATION + "two_responses": set(INSTRUCTION_DICT.keys()).difference(
126 | # {
127 | # _KEYWORD + "forbidden_words",
128 | # _KEYWORD + "existence",
129 | # _LANGUAGE + "response_language",
130 | # _FORMAT + "title",
131 | # _PUNCTUATION + "no_comma",
132 | # }
133 | # ),
134 | _COMBINATION + "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference(
135 | {
136 | _KEYWORD + "existence",
137 | # _FORMAT + "title",
138 | _PUNCTUATION + "no_comma"}
139 | ),
140 | _STARTEND + "end_checker": {_STARTEND + "end_checker"},
141 | _CHANGE_CASES + "capital_word_frequency": {
142 | _CHANGE_CASES + "capital_word_frequency",
143 | _CHANGE_CASES + "english_lowercase",
144 | _CHANGE_CASES + "english_capital",
145 | },
146 | _CHANGE_CASES + "english_capital": {_CHANGE_CASES + "english_capital"},
147 | _CHANGE_CASES + "english_lowercase": {
148 | _CHANGE_CASES + "english_lowercase",
149 | _CHANGE_CASES + "english_capital",
150 | },
151 | _PUNCTUATION + "no_comma": {_PUNCTUATION + "no_comma"},
152 | _STARTEND + "quotation": {
153 | _STARTEND + "quotation",
154 | # _FORMAT + "title"
155 | },
156 | }
157 |
158 |
159 | def conflict_make(conflicts):
160 | """Makes sure if A conflicts with B, B will conflict with A.
161 |
162 | Args:
163 | conflicts: Dictionary of potential conflicts where key is instruction id
164 | and value is set of instruction ids that it conflicts with.
165 |
166 | Returns:
167 | Revised version of the dictionary. All constraint_checker conflict with
168 | themselves. If A conflicts with B, B will conflict with A.
169 | """
170 | for key in conflicts:
171 | for k in conflicts[key]:
172 | conflicts[k].add(key)
173 | conflicts[key].add(key)
174 | return conflicts
175 |
176 | DOUBLE_CONSTRAINT = [
177 | ('combination:repeat_prompt', 'keywords:forbidden_words'),
178 | ('change_case:english_lowercase', 'keywords:frequency'),
179 | ('change_case:english_lowercase', 'keywords:existence'),
180 | # ('detectable_format:multiple_sections', 'detectable_format:number_bullet_lists'),
181 | ('language:response_language', 'startend:quotation'),
182 | ('keywords:existence', 'punctuation:no_comma'),
183 | ('change_case:english_lowercase', 'detectable_format:number_highlighted_sections'),
184 | ('keywords:forbidden_words', 'startend:quotation'),
185 | # ('change_case:capital_word_frequency', 'change_case:capital_word_frequency'),
186 | ('change_case:english_capital', 'keywords:forbidden_words'),
187 | ('language:response_language', 'punctuation:no_comma'),
188 | ('change_case:capital_word_frequency', 'punctuation:no_comma'),
189 | ('keywords:forbidden_words', 'punctuation:no_comma'),
190 | ('keywords:forbidden_words', 'keywords:frequency'),
191 | ('detectable_format:number_highlighted_sections', 'startend:quotation'),
192 | ('change_case:english_lowercase', 'detectable_format:number_bullet_lists'),
193 | ('detectable_format:number_bullet_lists', 'keywords:forbidden_words'),
194 | ('detectable_format:multiple_sections', 'startend:quotation'),
195 | ('keywords:frequency', 'keywords:frequency'),
196 | ('combination:repeat_prompt', 'keywords:existence'),
197 | ('change_case:english_lowercase', 'keywords:forbidden_words'),
198 | ('keywords:frequency', 'punctuation:no_comma'),
199 | ('combination:repeat_prompt', 'keywords:frequency'),
200 | ('combination:repeat_prompt', 'punctuation:no_comma'),
201 | ('keywords:existence', 'keywords:forbidden_words'),
202 | ('detectable_format:number_highlighted_sections', 'keywords:frequency'),
203 | ('detectable_format:number_highlighted_sections', 'punctuation:no_comma'),
204 | ('detectable_format:number_highlighted_sections', 'keywords:existence'),
205 | ('keywords:frequency', 'startend:end_checker'),
206 | ('punctuation:no_comma', 'startend:end_checker'),
207 | ('detectable_format:number_bullet_lists', 'startend:quotation'),
208 | ('change_case:english_capital', 'punctuation:no_comma'),
209 | # ('change_case:english_capital', 'detectable_format:multiple_sections'),
210 | ('change_case:english_capital', 'keywords:existence'),
211 | ('detectable_format:multiple_sections', 'startend:end_checker'),
212 | ('keywords:existence', 'startend:quotation'),
213 | ('change_case:english_lowercase', 'startend:quotation'),
214 | ('change_case:english_capital', 'detectable_format:number_highlighted_sections'),
215 | ('change_case:english_capital', 'startend:end_checker'),
216 | ('detectable_format:number_bullet_lists', 'punctuation:no_comma')
217 | ]
218 |
219 | TRIPLE_CONSTRAINT = [
220 | ('change_case:capital_word_frequency', 'length_constraint_checkers:number_words', 'keywords:frequency'),
221 | ('combination:repeat_prompt', 'keywords:existence', 'keywords:forbidden_words'),
222 | ('detectable_format:number_highlighted_sections', 'punctuation:no_comma', 'startend:end_checker'),
223 | ('detectable_format:number_highlighted_sections', 'keywords:existence', 'punctuation:no_comma'),
224 | ('change_case:capital_word_frequency', 'keywords:frequency', 'punctuation:no_comma'),
225 | ('keywords:existence', 'keywords:frequency', 'keywords:frequency'),
226 | ('detectable_format:number_highlighted_sections', 'keywords:existence', 'keywords:frequency'),
227 | ('detectable_format:number_bullet_lists', 'keywords:forbidden_words', 'keywords:frequency'),
228 | ('detectable_format:multiple_sections', 'keywords:existence', 'keywords:frequency'),
229 | ('detectable_format:number_bullet_lists', 'keywords:existence', 'length_constraint_checkers:number_words'),
230 | ('combination:repeat_prompt', 'length_constraint_checkers:number_words', 'keywords:forbidden_words'),
231 | ('change_case:english_lowercase', 'keywords:existence', 'keywords:frequency'),
232 | ('keywords:existence', 'length_constraint_checkers:number_words', 'startend:quotation'),
233 | ('combination:repeat_prompt', 'detectable_format:number_highlighted_sections', 'length_constraint_checkers:number_words'),
234 | ('change_case:english_capital', 'punctuation:no_comma', 'startend:quotation')
235 | ]
236 |
--------------------------------------------------------------------------------
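`conflict_make` above only adds edges to the table; a toy run (hypothetical constraint ids, not the real registry) shows how it symmetrizes the conflicts and adds self-conflicts:

```python
def conflict_make(conflicts):
    # Same body as in constraint_registry.py: mirror every conflict edge
    # and make each id conflict with itself.
    for key in conflicts:
        for k in conflicts[key]:
            conflicts[k].add(key)
        conflicts[key].add(key)
    return conflicts

toy = {"a": {"b"}, "b": set(), "c": set()}
conflict_make(toy)
print({k: sorted(v) for k, v in toy.items()})
# {'a': ['a', 'b'], 'b': ['a', 'b'], 'c': ['c']}
```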
/code/constraint_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility library of instructions."""
16 |
17 | import functools
18 | import os
19 | import random
20 | import re
21 | from importlib.metadata import version
22 |
23 | import immutabledict
24 | import nltk
25 | from packaging.version import parse as parse_version
26 |
27 |
28 | # Downloading 'punkt' with nltk<3.9 has a remote code vuln.
29 | # see https://github.com/EleutherAI/lm-evaluation-harness/issues/2210
30 | # and https://github.com/nltk/nltk/issues/3266
31 | # for more information.
32 | NLTK_MIN_VERSION = "3.9.1"
33 | RANK = os.environ.get("LOCAL_RANK", "0")
34 |
35 |
36 | def download_nltk_resources():
37 | """Download the 'punkt_tab' tokenizer data if not already present."""
38 | assert (nltk_version := parse_version(version("nltk"))) >= parse_version(
39 | NLTK_MIN_VERSION
40 | ), (
41 | f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability."
42 | )
43 |
44 | try:
45 | nltk.data.find("tokenizers/punkt_tab")
46 | except LookupError:
47 | if RANK == "0":
48 | nltk.download("punkt_tab")
49 | print("Downloaded punkt_tab on rank 0")
50 |
51 |
52 | download_nltk_resources()
53 |
54 | #['center', 'Plugging', 'xy', 'list', '92', 'stars', 'bar', 'items', '2da', '1000a', '100b', 'Using', 'tangent', 'blues', 'simply', 'To', 'mid', 'integer', '17k', 'eta', 'similar', 'way', 'lottery', 'circles', 'IL', '60', '10c', 'rows', 'Taking', 'BCD', '81', 'H', 'using', 'blue', 'red', 'pairs', 'more', 'monotonic', 'intersecting', 'positive', 'consider', '63', 'S', '104', 'congruent', 'parallel', 'lambda', 'would', 'triangles', 'black', 'candy', 'hearts', 'OI', 'occupy', '144', '96', 'segment', 'cong', 'but', 'casework', 'no', 'every', 'four', 'work', 'b_3', 'sum_', 'out', 'Equation', 'area', 'perpendicular', 'inradius', 'coordinates', 'make', 'calculate', 'multiply', 'strategy', 'winning', 'prize', 'DE', 'bi', '108b', '468', '34', 'x_1', 'y_1', 'who', 'Pi_', 'Substituting', '180', 'denote', 'lengths', 'Hence', 'Consider', 'count', 'neq', 'intersection', 'notice', 'through', 'function', 'Theorem', 'hline', 'form', 'least', 'tetrahedron', 'volume', 'note', 'perp', 'rm', 'terms', 'vertical', 'maximize', 'radius', '657', '4r', 'mathcal', '324', 'alpha', '30', 'digit', 'white', 'columns', '18', '45', 'M', 'lines', '33', 'cycle', 'satisfy', 'does', 'prove', 'b_1', 'them', 'q', 'g', 'they', 'remainder', 'pm', 'tfrac', 'OC', 'sphere', 'those', '46', 'lw', 'just', 'final', '404', 'together', 'like', 'different', 'drawn', '210', 'T', 'ordered', '3a', '2b', '4e', 'residents', '234', '2R', 'P_A', 'P_', 'rectangle', 'median', 'abcd', '1000', 'chip', 'Point', 'midpoint', 'AM', 'E', 'EF', 'Also', 'Furthermore', 'law', '39', 'Notice', 'question', 'go', 'import', 'void', '256', 'Technodoggo', 'geq', 'WLOG', 'b_2', 'symmetry', 'b_4', 'configurations', 'unique', 'Adding', 'giving', 'call', 'region', 'compute', 'gcd', 'divided', 'Solution', 'becomes', 'plane', 'ABCD', 'comhttps', 'AR', '2A', 'incenter', 'box', '2a', 'remaining', 'G', 'wins', 'under', 'last', 'integers', 'follows', 'CE', 'Draw', 'horizontal', '117i', '432', 'ge', '1190', '4046', '192', '900', '437', 'bag', 'CL', 'BL', 'ab', 'sqrt3', 'axis', 'x_C', 'db', 'occupied', 'cell', 'x_m', 'y_n', 'hours', '240', 'divide', 'Finally', '113', 'altitude', 'DAC', '4x', 'OD', 'reds', 'vertex', 'Longrightarrow', 'configuration', 'here', 'drawing', 'functions', 'slope', 'whose', 'array', 'bmod', '51', 'pmatrix', 'height', 'due', 'Pythagorean', 'obtain', '189', 'User', 'V', '405', 'coordinate', 'its', 'base', 'polynomial', 'yz', 'First', 'distinct', 'variable', '2r', 'simplifies', 'coin', 'A_i', 'player', 'move', 'their', 'start', 'sets', 'And', 'identical', 'even', 'OH', 'cap', 'path', 'row', '75a', '117b', '75b', '4a', 'phi', 'write', '480', 'greater', '3c', 'label']
55 |
56 | WORD_LIST = [
57 | "align",
58 | "number",
59 | "find",
60 | "therefore",
61 | "equation",
62 | "answer",
63 | "must",
64 | "now",
65 | "same",
66 | "imply",
67 | "because",
68 | "solution",
69 | "since",
70 | "where",
71 | "choose",
72 | "between",
73 | "length",
74 | "side",
75 | "follow",
76 | "case",
77 | "when",
78 | "value",
79 | "point",
80 | "because",
81 | "total",
82 | "denote",
83 | "see",
84 | "equal",
85 | "possible",
86 | "problem",
87 | "draw",
88 | "formula",
89 | "expression",
90 | "given",
91 | "adjacent",
92 | "note",
93 | "function",
94 | "above",
95 | "win",
96 | "than",
97 | "maximum",
98 | "root",
99 | "bar",
100 | "yield",
101 | "condition",
102 | "theorem",
103 | "respectively",
104 | "valid",
105 | "simply",
106 | "similar",
107 | "strategy",
108 | "function",
109 | "furthermore",
110 | "question",
111 | "configuration",
112 | "identical"
113 | ] # pylint: disable=line-too-long
114 |
115 | # ISO 639-1 codes to language names.
116 | LANGUAGE_CODES = immutabledict.immutabledict(
117 | {
118 | "en": "English",
119 | "es": "Spanish",
120 | "pt": "Portuguese",
121 | "ar": "Arabic",
122 | "hi": "Hindi",
123 | "fr": "French",
124 | "ru": "Russian",
125 | "de": "German",
126 | "ja": "Japanese",
127 | "it": "Italian",
128 | "bn": "Bengali",
129 | "uk": "Ukrainian",
130 | "th": "Thai",
131 | "ur": "Urdu",
132 | "ta": "Tamil",
133 | "te": "Telugu",
134 | "bg": "Bulgarian",
135 | "ko": "Korean",
136 | "pl": "Polish",
137 | "he": "Hebrew",
138 | "fa": "Persian",
139 | "vi": "Vietnamese",
140 | "ne": "Nepali",
141 | "sw": "Swahili",
142 | "kn": "Kannada",
143 | "mr": "Marathi",
144 | "gu": "Gujarati",
145 | "pa": "Punjabi",
146 | "ml": "Malayalam",
147 | "fi": "Finnish",
148 | }
149 | )
150 |
151 | _ALPHABETS = "([A-Za-z])"
152 | _PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]"
153 | _SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)"
154 | _STARTERS = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
155 | _ACRONYMS = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
156 | _WEBSITES = "[.](com|net|org|io|gov|edu|me)"
157 | _DIGITS = "([0-9])"
158 | _MULTIPLE_DOTS = r"\.{2,}"
159 |
160 |
161 | def split_into_sentences(text):
162 | """Split the text into sentences.
163 |
164 | Args:
text: A string that consists of one or more sentences.
166 |
167 | Returns:
168 | A list of strings where each string is a sentence.
169 | """
170 | text = " " + text + " "
171 | text = text.replace("\n", " ")
172 | text = re.sub(_PREFIXES, "\\1<prd>", text)
173 | text = re.sub(_WEBSITES, "<prd>\\1", text)
174 | text = re.sub(_DIGITS + "[.]" + _DIGITS, "\\1<prd>\\2", text)
175 | text = re.sub(
176 | _MULTIPLE_DOTS,
177 | lambda match: "<prd>" * len(match.group(0)) + "<stop>",
178 | text,
179 | )
180 | if "Ph.D" in text:
181 | text = text.replace("Ph.D.", "Ph<prd>D<prd>")
182 | text = re.sub(r"\s" + _ALPHABETS + "[.] ", " \\1<prd> ", text)
183 | text = re.sub(_ACRONYMS + " " + _STARTERS, "\\1<stop> \\2", text)
184 | text = re.sub(
185 | _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]",
186 | "\\1<prd>\\2<prd>\\3<prd>",
187 | text,
188 | )
189 | text = re.sub(_ALPHABETS + "[.]" + _ALPHABETS + "[.]", "\\1<prd>\\2<prd>", text)
190 | text = re.sub(" " + _SUFFIXES + "[.] " + _STARTERS, " \\1<stop> \\2", text)
191 | text = re.sub(" " + _SUFFIXES + "[.]", " \\1<prd>", text)
192 | text = re.sub(" " + _ALPHABETS + "[.]", " \\1<prd>", text)
193 | if "”" in text:
194 | text = text.replace(".”", "”.")
195 | if '"' in text:
196 | text = text.replace('."', '".')
197 | if "!" in text:
198 | text = text.replace('!"', '"!')
199 | if "?" in text:
200 | text = text.replace('?"', '"?')
201 | text = text.replace(".", ".<stop>")
202 | text = text.replace("?", "?<stop>")
203 | text = text.replace("!", "!<stop>")
204 | text = text.replace("<prd>", ".")
205 | sentences = text.split("<stop>")
206 | sentences = [s.strip() for s in sentences]
207 | if sentences and not sentences[-1]:
208 | sentences = sentences[:-1]
209 | return sentences
210 |
211 |
212 | def count_words(text):
213 | """Counts the number of words."""
214 | tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+")
215 | tokens = tokenizer.tokenize(text)
216 | num_words = len(tokens)
217 | return num_words
218 |
219 |
220 | @functools.lru_cache(maxsize=None)
221 | def _get_sentence_tokenizer():
222 | return nltk.data.load("nltk:tokenizers/punkt/english.pickle")
223 |
224 |
225 | def count_sentences(text):
226 | """Count the number of sentences."""
227 | tokenizer = _get_sentence_tokenizer()
228 | tokenized_sentences = tokenizer.tokenize(text)
229 | return len(tokenized_sentences)
230 |
231 |
232 | def generate_keywords(num_keywords):
233 | """Randomly generates a few keywords."""
234 | return random.sample(WORD_LIST, k=num_keywords)
235 |
--------------------------------------------------------------------------------
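`count_words` above depends on nltk's `RegexpTokenizer`; the same token definition can be sketched with the standard library alone (a dependency-free stand-in for illustration, not the repo's code):

```python
import re

def count_words(text: str) -> int:
    # nltk.tokenize.RegexpTokenizer(r"\w+") tokenizes on runs of word
    # characters, so re.findall with the same pattern counts identically.
    return len(re.findall(r"\w+", text))

print(count_words("Let x = 2, then x+1 = 3."))  # 7: Let, x, 2, then, x, 1, 3
print(count_words(""))                          # 0
```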
/code/eval_if.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import json
5 |
6 | from constraint_registry import INSTRUCTION_DICT
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--hypothesis_path', type=str)
10 | parser.add_argument('--data_path', type=str)
11 | parser.add_argument("--delimiter", type=str, default = "")
12 | args = parser.parse_args()
13 |
14 |
15 | def test_instruction_following_strict(
16 | instruction_id_list,
17 | response,
18 | parameters,
19 | prompt,
20 | ):
21 | """Tests response to see if instructions are followed."""
22 |
23 | is_following_list = []
24 | for index, instruction_id in enumerate(instruction_id_list):
25 | try:
26 | instruction_cls = INSTRUCTION_DICT[instruction_id]
27 | except:
28 | import pdb
29 | pdb.set_trace()
30 | instruction = instruction_cls(instruction_id)
31 |
32 | # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
33 | if parameters[index]:
34 | kwargs = {n: p for n, p in parameters[index].items() if p}
35 | else:
36 | kwargs = {}
37 | instruction.build_description(**kwargs)
38 | args = instruction.get_constraint_args()
39 | if args and "prompt" in args:
40 | instruction.build_description(prompt=prompt)
41 | try:
42 | if response.strip() and instruction.check_following(response):
43 | is_following_list.append(True)
44 | else:
45 | is_following_list.append(False)
46 | except:
47 | import pdb
48 | pdb.set_trace()
49 |
50 | return is_following_list
51 |
52 |
53 | strict = []
54 | loose = []
55 |
56 |
57 |
58 | for line1,line2 in zip(open(args.hypothesis_path).readlines(), open(args.data_path).readlines()):
59 | try:
60 | hypothesis = json.loads(line1)["output"]
61 | except:
62 | hypothesis = json.loads(line1)["response"]
63 | if isinstance(hypothesis,list):
64 | hypothesis = hypothesis[0]
65 | has_end_think = '</think>' in hypothesis
66 | has_start_think = '<think>' in hypothesis
67 |
68 | think = hypothesis
69 | if has_end_think:
70 | think = think.split("</think>")[0]
71 | if '<think>' in think:
72 | think = think.split('<think>')[1]
73 |
74 | if '</think>' in hypothesis:
75 | hypothesis = hypothesis.split('</think>')[1]
76 | if '<think>' in hypothesis:
77 | hypothesis = hypothesis.split('<think>')[1]
78 | if '<answer>' in hypothesis:
79 | hypothesis = hypothesis.split('<answer>')[1]
80 | if '</answer>' in hypothesis:
81 | hypothesis = hypothesis.split('</answer>')[0]
82 |
83 |
84 | data = json.loads(line2)
85 | if not ("noconstraint" in args.hypothesis_path):
86 | is_follow_list = test_instruction_following_strict(
87 | data["constraint_name"],
88 | hypothesis,
89 | data["constraint_args"],
90 | data["question"],
91 | )
92 | strict.append(all(is_follow_list))
93 | loose.append(sum(is_follow_list)/len(is_follow_list))
94 | else:
95 | # only place holder
96 | strict.append(1)
97 | loose.append(1)
98 |
99 |
100 |
101 | print(sum(strict)/len(strict))
102 | print(sum(loose)/len(loose))
--------------------------------------------------------------------------------
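The two numbers eval_if.py prints are the benchmark's Hard Accuracy (HAcc) and Soft Accuracy (SAcc); a toy computation with made-up per-constraint results shows the difference:

```python
# Each inner list: whether each constraint was satisfied for one example.
results = [[True, True], [True, False], [False, False]]

# HAcc: fraction of examples where *all* constraints hold (the `strict` list).
hacc = sum(all(r) for r in results) / len(results)
# SAcc: mean fraction of satisfied constraints per example (the `loose` list).
sacc = sum(sum(r) / len(r) for r in results) / len(results)

print(hacc)  # 0.333...: only the first example satisfies everything
print(sacc)  # 0.5: (1.0 + 0.5 + 0.0) / 3
```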
/code/requirement.txt:
--------------------------------------------------------------------------------
1 | vllm==0.7.3
2 | torchdata==0.11.0
3 | tensordict==0.6.0
4 | hydra-core
5 | wheel
6 | flash-attn==2.7.4.post1
7 | accelerate==1.6.0
8 | pylatexenc
9 | codetiming
10 | nltk
11 | langdetect
12 | immutabledict
13 | datasets
--------------------------------------------------------------------------------
/code/script/download.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ./data
2 | for source in gsm8k math500 minerva olympiad aime
3 | do
4 | for constraint in single double triple
5 | do
6 | wget https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/${source}_${constraint}.jsonl -O data/${source}_${constraint}.jsonl
7 | done
8 | done
--------------------------------------------------------------------------------
/code/script/eval_if.sh:
--------------------------------------------------------------------------------
1 | model=deepseek-ai_DeepSeek-R1-Distill-Qwen-1.5B
2 | for dataset in gsm8k math500 minerva olympiad aime
3 | do
4 | for constraint in single double triple
5 | do
6 | echo ${model}_${dataset}_${constraint}
7 | python3 -u code/eval_if.py \
8 | --data_path data/${dataset}_${constraint}.jsonl \
9 | --hypothesis_path output/${model}_${dataset}_${constraint}_t1.0p0.95max16384seedNone.jsonl
10 |
11 | echo ${model}_${dataset}_${constraint}_noconstraint
12 | python3 -u code/eval_if.py \
13 | --data_path data/${dataset}_${constraint}.jsonl \
14 | --hypothesis_path output/${model}_${dataset}_${constraint}_t1.0p0.95max16384seedNone_noconstraint.jsonl
15 | done
16 | done
17 |
18 |
19 |
--------------------------------------------------------------------------------
/code/script/vllm_if.sh:
--------------------------------------------------------------------------------
1 | for model in deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
2 | do
3 | for dataset in gsm8k math500 minerva olympiad aime
4 | do
5 | for constraint in single double triple
6 | do
7 | python3 -u code/vllm_core.py \
8 | --test_file data/${dataset}_${constraint}.jsonl \
9 | --model_name_or_path ${model} \
10 | --top_p 0.95 \
11 | --temperature 1.0 \
12 | --max_token 16384
13 |
14 | python3 -u code/vllm_core.py \
15 | --test_file data/${dataset}_${constraint}.jsonl \
16 | --model_name_or_path ${model} \
17 | --top_p 0.95 \
18 | --temperature 1.0 \
19 | --max_token 16384 \
20 | --no_constraint
21 |
22 | done
23 | done
24 | done
25 |
26 |
--------------------------------------------------------------------------------
/code/vllm_core.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datasets import load_dataset
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | # from peft import PeftModel
5 | import os
6 | import torch
7 | from vllm import LLM, SamplingParams
8 | import json
9 | import sys
10 | from pathlib import Path
11 | import argparse
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--model_name_or_path', type=str, default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
15 | parser.add_argument('--tokenizer_name_or_path', type=str, default=None)
16 | parser.add_argument("--test_file",type=str,default=None)
17 | parser.add_argument("--temperature",type=float,default=1.0)
18 | parser.add_argument("--top_p",type=float,default=0.95)
19 | parser.add_argument("--max_token",type=int,default=16384)
20 | parser.add_argument("--n_sample",type=int,default=1)
21 | parser.add_argument("--seed",type=int,default=None)
22 | parser.add_argument("--no_constraint",action="store_true")
23 | args = parser.parse_args()
24 |
25 |
26 |
27 |
28 |
29 | def vllm_inference(prompt):
30 | num_gpus = torch.cuda.device_count()
31 | llm = LLM(model = args.model_name_or_path,
32 | tokenizer = args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
33 | dtype='bfloat16',
34 | tensor_parallel_size = 4,#num_gpus,
35 | trust_remote_code=True,
36 | )
37 | print('>>>>>> model loaded')
38 |
39 | sampling_params = SamplingParams(temperature = args.temperature, top_p = args.top_p, max_tokens = args.max_token, seed = args.seed, n = args.n_sample)
40 | outputs = llm.generate(prompt, sampling_params)
41 | sorted_outputs = sorted(outputs, key=lambda output: int(output.request_id))
42 | print('>>>>>> generation done')
43 |
44 | return sorted_outputs
45 |
46 |
47 |
48 | ds = load_dataset('json' if 'json' in args.test_file else 'parquet', data_files=args.test_file, split='train')
49 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path)
50 | prompt = []
51 | cot = "Let's think step by step and output the final answer within \\boxed{}."
52 | for i in range(len(ds)):
53 | if 'question' in ds.column_names:
54 | prompt.append(tokenizer.apply_chat_template([{
55 | "role": "user",
56 | "content": ds[i]['question']+ cot + (" ".join(ds[i]["constraint_desc"]) if not args.no_constraint else ""),
57 | }],tokenize=False, add_generation_prompt=True))
58 | elif "prompt" in ds.column_names:
59 | prompt.append(tokenizer.apply_chat_template(ds[i]["prompt"], tokenize=False, add_generation_prompt=True))
60 |
61 | print(">>>>>>>>>>>>>>>>>>>>>>>>")
62 | print(prompt[0])
63 | print(">>>>>>>>>>>>>>>>>>>>>>>>")
64 |
65 |
66 |
67 | output_path = "output/{}_{}_t{}p{}max{}seed{}.jsonl".format(args.model_name_or_path.replace("/","_"), args.test_file.split('/')[-1].split('.')[0], args.temperature, args.top_p, args.max_token, args.seed)
68 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
69 | output_path = output_path.replace(".jsonl", "_noconstraint.jsonl") if args.no_constraint else output_path
70 |
71 |
72 | output = vllm_inference(prompt)
73 | with open(output_path, 'w', encoding='utf8') as fout:
74 |     for i in range(len(prompt)):
75 |         fout.write(json.dumps({"output": [output[i].outputs[j].text for j in range(args.n_sample)]}, ensure_ascii=False) + '\n')
76 |
77 |
78 |
--------------------------------------------------------------------------------
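A note on outputs: line 67 of vllm_core.py derives the result filename from the run configuration. As a quick sketch with hypothetical run settings (the model id and sampling values below are placeholders, not from the repo), the scheme produces names like this:

```python
# Reproduce the output-path naming scheme from line 67 of vllm_core.py
# with hypothetical run settings.
model = "Qwen/Qwen2.5-7B-Instruct"       # placeholder model id
test_file = "data/gsm8k_single.jsonl"
temperature, top_p, max_token, seed = 0.7, 0.95, 16384, None

output_path = "output/{}_{}_t{}p{}max{}seed{}.jsonl".format(
    model.replace("/", "_"),                      # slashes are not valid in filenames
    test_file.split("/")[-1].split(".")[0],       # dataset stem, e.g. "gsm8k_single"
    temperature, top_p, max_token, seed,
)
print(output_path)
# → output/Qwen_Qwen2.5-7B-Instruct_gsm8k_single_t0.7p0.95max16384seedNone.jsonl
```

With `--no_constraint`, line 69 further inserts a `_noconstraint` suffix before the extension.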
/data/aime_double.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:46-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/aime_double.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 17630 (17K) [text/plain]
6 | Saving to: ‘aime_double.jsonl’
7 |
8 | 0K .......... ....... 100% 118M=0s
9 |
10 | 2025-05-20 16:10:46 (118 MB/s) - ‘aime_double.jsonl’ saved [17630/17630]
11 |
12 |
--------------------------------------------------------------------------------
/data/aime_single.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:46-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/aime_single.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 12535 (12K) [text/plain]
6 | Saving to: ‘aime_single.jsonl’
7 |
8 | 0K .......... .. 100% 11.7G=0s
9 |
10 | 2025-05-20 16:10:46 (11.7 GB/s) - ‘aime_single.jsonl’ saved [12535/12535]
11 |
12 |
--------------------------------------------------------------------------------
/data/aime_triple.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:46-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/aime_triple.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 19088 (19K) [text/plain]
6 | Saving to: ‘aime_triple.jsonl’
7 |
8 | 0K .......... ........ 100% 150M=0s
9 |
10 | 2025-05-20 16:10:47 (150 MB/s) - ‘aime_triple.jsonl’ saved [19088/19088]
11 |
12 |
--------------------------------------------------------------------------------
/data/gsm8k_double.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:39-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/gsm8k_double.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.103, 54.230.71.28, 54.230.71.56, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.103|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 25342 (25K) [text/plain]
6 | Saving to: ‘gsm8k_double.jsonl’
7 |
8 | 0K .......... .......... .... 100% 47.9M=0.001s
9 |
10 | 2025-05-20 16:10:40 (47.9 MB/s) - ‘gsm8k_double.jsonl’ saved [25342/25342]
11 |
12 |
--------------------------------------------------------------------------------
/data/gsm8k_single.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:37-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/gsm8k_single.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 18960 (19K) [text/plain]
6 | Saving to: ‘gsm8k_single.jsonl’
7 |
8 | 0K .......... ........ 100% 700K=0.03s
9 |
10 | 2025-05-20 16:10:39 (700 KB/s) - ‘gsm8k_single.jsonl’ saved [18960/18960]
11 |
12 |
--------------------------------------------------------------------------------
/data/gsm8k_triple.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:40-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/gsm8k_triple.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.28, 54.230.71.56, 54.230.71.2, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.28|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 29140 (28K) [text/plain]
6 | Saving to: ‘gsm8k_triple.jsonl’
7 |
8 | 0K .......... .......... ........ 100% 90.5M=0s
9 |
10 | 2025-05-20 16:10:41 (90.5 MB/s) - ‘gsm8k_triple.jsonl’ saved [29140/29140]
11 |
12 |
--------------------------------------------------------------------------------
/data/math500_double.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:41-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/math500_double.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 20441 (20K) [text/plain]
6 | Saving to: ‘math500_double.jsonl’
7 |
8 | 0K .......... ......... 100% 105M=0s
9 |
10 | 2025-05-20 16:10:42 (105 MB/s) - ‘math500_double.jsonl’ saved [20441/20441]
11 |
12 |
--------------------------------------------------------------------------------
/data/math500_single.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:41-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/math500_single.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.56, 54.230.71.2, 54.230.71.103, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.56|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 15679 (15K) [text/plain]
6 | Saving to: ‘math500_single.jsonl’
7 |
8 | 0K .......... ..... 100% 67.7M=0s
9 |
10 | 2025-05-20 16:10:41 (67.7 MB/s) - ‘math500_single.jsonl’ saved [15679/15679]
11 |
12 |
--------------------------------------------------------------------------------
/data/math500_triple.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:42-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/math500_triple.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 26727 (26K) [text/plain]
6 | Saving to: ‘math500_triple.jsonl’
7 |
8 | 0K .......... .......... ...... 100% 13.6M=0.002s
9 |
10 | 2025-05-20 16:10:42 (13.6 MB/s) - ‘math500_triple.jsonl’ saved [26727/26727]
11 |
12 |
--------------------------------------------------------------------------------
/data/minerva_double.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:43-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/minerva_double.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 27121 (26K) [text/plain]
6 | Saving to: ‘minerva_double.jsonl’
7 |
8 | 0K .......... .......... ...... 100% 26.6M=0.001s
9 |
10 | 2025-05-20 16:10:44 (26.6 MB/s) - ‘minerva_double.jsonl’ saved [27121/27121]
11 |
12 |
--------------------------------------------------------------------------------
/data/minerva_single.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:42-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/minerva_single.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 23996 (23K) [text/plain]
6 | Saving to: ‘minerva_single.jsonl’
7 |
8 | 0K .......... .......... ... 100% 48.5M=0s
9 |
10 | 2025-05-20 16:10:43 (48.5 MB/s) - ‘minerva_single.jsonl’ saved [23996/23996]
11 |
12 |
--------------------------------------------------------------------------------
/data/minerva_triple.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:44-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/minerva_triple.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 34264 (33K) [text/plain]
6 | Saving to: ‘minerva_triple.jsonl’
7 |
8 | 0K .......... .......... .......... ... 100% 1.93M=0.02s
9 |
10 | 2025-05-20 16:10:44 (1.93 MB/s) - ‘minerva_triple.jsonl’ saved [34264/34264]
11 |
12 |
--------------------------------------------------------------------------------
/data/olympiad_double.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:45-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/olympiad_double.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 27628 (27K) [text/plain]
6 | Saving to: ‘olympiad_double.jsonl’
7 |
8 | 0K .......... .......... ...... 100% 2.63M=0.01s
9 |
10 | 2025-05-20 16:10:45 (2.63 MB/s) - ‘olympiad_double.jsonl’ saved [27628/27628]
11 |
12 |
--------------------------------------------------------------------------------
/data/olympiad_single.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:44-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/olympiad_single.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 19297 (19K) [text/plain]
6 | Saving to: ‘olympiad_single.jsonl’
7 |
8 | 0K .......... ........ 100% 52.3M=0s
9 |
10 | 2025-05-20 16:10:45 (52.3 MB/s) - ‘olympiad_single.jsonl’ saved [19297/19297]
11 |
12 |
--------------------------------------------------------------------------------
/data/olympiad_triple.jsonl:
--------------------------------------------------------------------------------
1 | --2025-05-20 16:10:45-- https://huggingface.co/datasets/TingchenFu/MathIF/resolve/main/olympiad_triple.jsonl
2 | Resolving huggingface.co (huggingface.co)... 54.230.71.2, 54.230.71.103, 54.230.71.28, ...
3 | Connecting to huggingface.co (huggingface.co)|54.230.71.2|:443... connected.
4 | HTTP request sent, awaiting response... 200 OK
5 | Length: 27840 (27K) [text/plain]
6 | Saving to: ‘olympiad_triple.jsonl’
7 |
8 | 0K .......... .......... ....... 100% 1.16M=0.02s
9 |
10 | 2025-05-20 16:10:46 (1.16 MB/s) - ‘olympiad_triple.jsonl’ saved [27840/27840]
11 |
12 |
--------------------------------------------------------------------------------
/data/toy.jsonl:
--------------------------------------------------------------------------------
1 | {"source": "HuggingFaceH4/MATH-500", "id": "level2-single-0", "question": "Simplify $\\tan 100^\\circ + 4 \\sin 100^\\circ.$", "answer": "-\\sqrt{3}", "constraint_desc": ["Include keywords \"['find', 'must']\" in the response."], "constraint_name": ["keywords:existence"], "constraint_args": [{"keywords": ["find", "must"]}]}
2 | {"source": "HuggingFaceH4/MATH-500", "id": "level2-single-2", "question": "In right triangle $ABC$ with $\\angle B = 90^\\circ$, we have $\\sin A = 2\\cos A$. What is $\\tan A$?", "answer": "2", "constraint_desc": ["Do not include keywords \"['respectively', 'strategy']\" in the response."], "constraint_name": ["keywords:forbidden_words"], "constraint_args": [{"forbidden_words": ["respectively", "strategy"]}]}
--------------------------------------------------------------------------------
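The toy.jsonl records above follow the schema consumed by vllm_core.py (`question`, `answer`, `constraint_desc`, `constraint_name`, `constraint_args`). As a minimal sketch of how the user message is assembled from such a record, mirroring the concatenation in vllm_core.py (the record below is a simplified stand-in, not an actual dataset entry):

```python
import json

# Chain-of-thought cue appended to every question, as in vllm_core.py.
COT = "Let's think step by step and output the final answer within \\boxed{}."

def build_prompt(record, no_constraint=False):
    # question + CoT cue + space-joined constraint descriptions,
    # mirroring the prompt-construction loop in vllm_core.py.
    suffix = "" if no_constraint else " ".join(record["constraint_desc"])
    return record["question"] + COT + suffix

record = json.loads('{"question": "What is $1+1$? ", '
                    '"constraint_desc": ["Answer with fewer than 50 words."]}')
print(build_prompt(record))
```

The resulting string is then wrapped into a single-turn chat message via the tokenizer's `apply_chat_template` before generation.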