├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── api_client.py ├── ifeval.py ├── metrics.py ├── multi_turn_instruct_following_eval_api.py ├── multi_turn_instruct_following_eval_vllm.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | Meta-Llama-3.1-8B-Instruct/ 2 | Meta-Llama-3.1-70B-Instruct/ 3 | *.csv 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Multi-IF 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `main`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test suite passes. 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Meta's open source projects. 19 | 20 | Complete your CLA here: 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | Meta has a [bounty program](https://bugbounty.meta.com/) for the safe 27 | disclosure of security bugs. In those cases, please go through the process 28 | outlined on that page and do not file a public issue. 29 | 30 | ## Coding Style 31 | * 2 spaces for indentation rather than tabs 32 | * 80 character line length 33 | 34 | ## License 35 | By contributing to Multi-IF, you agree that your contributions will be licensed 36 | under the LICENSE file in the root directory of this source tree. 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Turn Evaluation Framework 2 | 3 | This repository contains a Python implementation of a multi-turn evaluation benchmark for large language model (LLM). The benchmark is designed to evaluate the performance of LLM models' capabilities in multi-turn instruction following within a multilingual environment. 4 | 5 | 6 | 7 | ## Files 8 | 9 | The repository contains the following files: 10 | 11 | * `api_client.py`: This file contains the implementation of the interface for LLMs interactions via API calls. 12 | * `ifeval.py`: This file contains the implementation of the Inference Evaluation (IFEVAL) metric, which is used to evaluate the capability of LLM following natural language instructions 13 | * `metrics.py`: This file contains the implementation of various metrics that can be used to calculate ifeval, data preprocess and enrichment for multi turn instructions. 14 | * `utils.py`: This file contains utility functions that are used throughout the framework, e.g., GenerationSetting, get_inference_batch(),preprocess_data. 15 | * `multi_turn_instruct_following_eval_api.py`: This file contains the main function that executes the multi-turn evaluation benchmark via API calls. 16 | * `multi_turn_instruct_following_eval_vllm.py`: This file contains the main function that executes the multi-turn evaluation benchmark via running on local GPUs. 17 | 18 | 19 | ## Usage 20 | 21 | To use the multi-turn evaluation benchmar, follow these steps: 22 | 23 | 1. Clone the repository: 24 | ```bash 25 | git clone https://github.com/your-username/multi-turn-evaluation.git 26 | ``` 27 | 2. Install the required dependencies: 28 | ```bash 29 | cd multi_turn_eval 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | 3. Download the data from huggingface: 34 | ``` 35 | git clone https://huggingface.co/datasets/facebook/Multi-IF data/Multi-IF 36 | ``` 37 | 38 | 4. Run the main function in `multi_turn_instruct_following_eval_vllm.py`: 39 | ```bash 40 | python multi_turn_instruct_following_eval_vllm.py \ 41 | --model_path \ 42 | --tokenizer_path \ 43 | --input_data_csv \ 44 | --batch_size \ 45 | --tensor_parallel_size 46 | ``` 47 | This will execute the multi-turn evaluation benchmar and output the results to the console and intermediate generation results saved in csv files. 48 | 49 | For example, for Meta-Llama-3.1-70B-Instruct, 50 | ```bash 51 | python multi_turn_instruct_following_eval_vllm.py \ 52 | --model_path meta-llama/Llama-3.1-70B-Instruct \ 53 | --tokenizer_path meta-llama/Llama-3.1-70B-Instruct \ 54 | --input_data_csv data/Multi-IF/multiIF_20241018.csv \ 55 | --batch_size 4 \ 56 | --tensor_parallel_size 8 57 | ``` 58 | 59 | Or for running evaluation via API please use 60 | 61 | 4. Run the main function in `multi_turn_instruct_following_eval_api.py` with `claude-3.5-sonnet-20240620`: 62 | ```bash 63 | python multi_turn_instruct_following_eval_api.py \ 64 | --max_workers 5 \ 65 | --api_model_name claude-3.5-sonnet-20240620 \ 66 | --input_data_csv data/Multi-IF/multiIF_20241018.csv \ 67 | --max_new_tokens 1024 \ 68 | --steps 1 2 3 69 | ``` 70 | 71 | ## Bibtex 72 | If you use the code or benchmark, please consider citing the following paper: 73 | ``` 74 | @article{he2024multi, 75 | title={Multi-IF: Benchmarking LLMs on Multi-Turn and Multilingual Instructions Following}, 76 | author={He, Yun and Jin, Di and Wang, Chaoqi and Bi, Chloe and Mandyam, Karishma and Zhang, Hejia and Zhu, Chen and Li, Ning and Xu, Tengyu and Lv, Hongjiang and others}, 77 | journal={arXiv preprint arXiv:2410.15553}, 78 | year={2024} 79 | } 80 | ``` 81 | 82 | ## Contributing 83 | 84 | We welcome contributions to this repository! If you have any suggestions or improvements, please open an issue or submit a pull request. 85 | 86 | ## License 87 | 88 | This project is licensed under the Apache License, Version 2.0 (the "License"). See the LICENSE file for details. 89 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /api_client.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | 16 | import anthropic 17 | import google.generativeai as genai 18 | from google.generativeai.types import HarmCategory, HarmBlockThreshold 19 | from mistralai import Mistral 20 | from openai import OpenAI 21 | 22 | from utils import GenerationSetting 23 | 24 | 25 | def get_api_bot(model_name, generation_setting): 26 | if OpenAIBot.check_name(model_name): 27 | return OpenAIBot(model_name, generation_config=generation_setting) 28 | elif AnthropicBot.check_name(model_name): 29 | return AnthropicBot(model_name, generation_setting) 30 | elif GeminiBot.check_name(model_name): 31 | return GeminiBot(model_name, generation_setting) 32 | elif MistralBot.check_name(model_name): 33 | return MistralBot(model_name, generation_setting) 34 | else: 35 | raise NotImplementedError(f"The model {model_name} is not supported yet.") 36 | 37 | class APIBot: 38 | 39 | def __init__(self, model, generation_config): 40 | self.model_name = model 41 | self.generation_config = generation_config 42 | 43 | def generate(self, messages): 44 | ... 45 | 46 | def check_name(self, name): 47 | ... 48 | 49 | 50 | class OpenAIBot(APIBot): 51 | def __init__(self, model, generation_config): 52 | super().__init__(model, generation_config) 53 | self.client = OpenAI() 54 | 55 | def generate(self, messages) -> str: 56 | response = self.client.chat.completions.create( 57 | model=self.model_name, 58 | messages=messages, 59 | max_completion_tokens=self.generation_config.max_new_tokens, 60 | seed=self.generation_config.seed, 61 | top_p=self.generation_config.top_p, 62 | temperature=self.generation_config.temperature 63 | ) 64 | return response.choices[0].message.content 65 | 66 | @staticmethod 67 | def check_name(name): 68 | if name in ['o1-preview', 69 | 'o1-mini', 70 | 'o1-preview-2024-09-12', 71 | 'o1-mini-2024-09-12', 72 | 'gpt-4-turbo', 73 | 'gpt-4-turbo-2024-04-09', 74 | 'gpt-4-turbo-preview', 75 | 'gpt-4-0125-preview', 76 | 'gpt-4-1106-preview', 77 | 'gpt-4', 78 | 'gpt-4-0613', 79 | 'gpt-4o-2024-08-06' 80 | ]: 81 | return True 82 | return False 83 | 84 | 85 | class AnthropicBot(APIBot): 86 | def __init__(self, model, generation_config): 87 | super().__init__(model, generation_config) 88 | # make sure ti set the API key for anthropic, 89 | # which will be accessed via os.environ.get("ANTHROPIC_API_KEY") 90 | self.client = anthropic.Anthropic() 91 | 92 | def generate(self, messages): 93 | # Anthropic models don't support manual seed. 94 | response = self.client.messages.create( 95 | model=self.model_name, 96 | max_tokens=self.generation_config.max_new_tokens, 97 | messages=messages, 98 | temperature=self.generation_config.temperature, 99 | top_p=self.generation_config.top_p 100 | ) 101 | return response.content[0].text 102 | 103 | @staticmethod 104 | def check_name(name): 105 | if name in ['claude-3-haiku-20240307', 'claude-3-sonnet-20240229', 'claude-3-5-sonnet-20240620']: 106 | return True 107 | return False 108 | 109 | 110 | class GeminiBot(APIBot): 111 | def __init__(self, model, generation_config): 112 | super().__init__(model, generation_config) 113 | # Configure the genai client with the API key 114 | 115 | genai.configure(api_key=os.environ['GEMINI_API_KEY']) 116 | self.model = genai.GenerativeModel(model) 117 | self.generation_config = generation_config 118 | 119 | def generate(self, messages): 120 | role_map = { 121 | 'user': 'user', 122 | 'assistant': 'model' 123 | } 124 | history = [] 125 | for m in messages: 126 | history.append({'role': role_map[m['role']], 'parts': m['content']}) 127 | chat = self.model.start_chat(history=history) 128 | response = chat.send_message( 129 | messages[-1]['content'], 130 | generation_config={ 131 | "temperature": self.generation_config.temperature, 132 | "top_p": self.generation_config.top_p, 133 | "max_output_tokens": self.generation_config.max_new_tokens 134 | }, 135 | safety_settings={ 136 | HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, 137 | HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, 138 | HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, 139 | HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, 140 | } 141 | ) 142 | return response.text 143 | 144 | @staticmethod 145 | def check_name(name): 146 | if name in [ 147 | 'gemini-1.5-flash', 148 | 'gemini-1.5-pro', 149 | ]: 150 | return True 151 | return False 152 | 153 | 154 | class MistralBot(APIBot): 155 | def __init__(self, model, generation_config): 156 | super().__init__(model, generation_config) 157 | api_key = os.environ["MISTRAL_API_KEY"] 158 | self.client = Mistral(api_key=api_key) 159 | self.model = model 160 | 161 | def generate(self, messages): 162 | response = self.client.chat.complete( 163 | model=self.model_name, 164 | messages=messages, 165 | temperature=self.generation_config.temperature, 166 | random_seed=self.generation_config.seed, 167 | max_tokens=self.generation_config.max_new_tokens, 168 | top_p=self.generation_config.top_p 169 | ) 170 | return response.choices[0].message.content 171 | 172 | @staticmethod 173 | def check_name(name): 174 | if name in [ 175 | 'mistral-large-latest', 176 | 'mistral-small-latest', 177 | # Add any other supported model names here 178 | ]: 179 | return True 180 | return False 181 | 182 | 183 | 184 | if __name__ == '__main__': 185 | generation_setting = GenerationSetting(max_new_tokens=1024, temperature=0.6, top_p=0.9) 186 | bot = get_api_bot('mistral-small-latest',generation_setting) 187 | history = [ 188 | {'role': 'user', 'content': 'create an equation.'}, 189 | {'role': 'assistant', 'content': 'x^2-4x+4=0'}, 190 | {'role': 'user', 'content': 'solve the equation.'} 191 | ] 192 | print(bot.generate(history)) 193 | -------------------------------------------------------------------------------- /ifeval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import collections 16 | import functools 17 | import json 18 | import logging 19 | import random 20 | import re 21 | import string 22 | from types import MappingProxyType 23 | from typing import Dict, Iterable, Optional, Sequence, Union 24 | 25 | from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai 26 | from pythainlp.tokenize import word_tokenize as word_tokenize_thai 27 | 28 | try: 29 | import langdetect 30 | except ImportError: 31 | langdetect = None 32 | try: 33 | import emoji 34 | except ImportError: 35 | emoji = None 36 | import nltk 37 | 38 | logger = logging.getLogger() 39 | 40 | WORD_LIST = [ 41 | "western", 42 | "sentence", 43 | "signal", 44 | "dump", 45 | "spot", 46 | "opposite", 47 | "bottom", 48 | "potato", 49 | "administration", 50 | "working", 51 | "welcome", 52 | "morning", 53 | "good", 54 | "agency", 55 | "primary", 56 | "wish", 57 | "responsibility", 58 | "press", 59 | "problem", 60 | "president", 61 | "steal", 62 | "brush", 63 | "read", 64 | "type", 65 | "beat", 66 | "trainer", 67 | "growth", 68 | "lock", 69 | "bone", 70 | "case", 71 | "equal", 72 | "comfortable", 73 | "region", 74 | "replacement", 75 | "performance", 76 | "mate", 77 | "walk", 78 | "medicine", 79 | "film", 80 | "thing", 81 | "rock", 82 | "tap", 83 | "total", 84 | "competition", 85 | "ease", 86 | "south", 87 | "establishment", 88 | "gather", 89 | "parking", 90 | "world", 91 | "plenty", 92 | "breath", 93 | "claim", 94 | "alcohol", 95 | "trade", 96 | "dear", 97 | "highlight", 98 | "street", 99 | "matter", 100 | "decision", 101 | "mess", 102 | "agreement", 103 | "studio", 104 | "coach", 105 | "assist", 106 | "brain", 107 | "wing", 108 | "style", 109 | "private", 110 | "top", 111 | "brown", 112 | "leg", 113 | "buy", 114 | "procedure", 115 | "method", 116 | "speed", 117 | "high", 118 | "company", 119 | "valuable", 120 | "pie", 121 | "analyst", 122 | "session", 123 | "pattern", 124 | "district", 125 | "pleasure", 126 | "dinner", 127 | "swimming", 128 | "joke", 129 | "order", 130 | "plate", 131 | "department", 132 | "motor", 133 | "cell", 134 | "spend", 135 | "cabinet", 136 | "difference", 137 | "power", 138 | "examination", 139 | "engine", 140 | "horse", 141 | "dimension", 142 | "pay", 143 | "toe", 144 | "curve", 145 | "literature", 146 | "bother", 147 | "fire", 148 | "possibility", 149 | "debate", 150 | "activity", 151 | "passage", 152 | "hello", 153 | "cycle", 154 | "background", 155 | "quiet", 156 | "author", 157 | "effect", 158 | "actor", 159 | "page", 160 | "bicycle", 161 | "error", 162 | "throat", 163 | "attack", 164 | "character", 165 | "phone", 166 | "tea", 167 | "increase", 168 | "outcome", 169 | "file", 170 | "specific", 171 | "inspector", 172 | "internal", 173 | "potential", 174 | "staff", 175 | "building", 176 | "employer", 177 | "shoe", 178 | "hand", 179 | "direction", 180 | "garden", 181 | "purchase", 182 | "interview", 183 | "study", 184 | "recognition", 185 | "member", 186 | "spiritual", 187 | "oven", 188 | "sandwich", 189 | "weird", 190 | "passenger", 191 | "particular", 192 | "response", 193 | "reaction", 194 | "size", 195 | "variation", 196 | "a", 197 | "cancel", 198 | "candy", 199 | "exit", 200 | "guest", 201 | "condition", 202 | "fly", 203 | "price", 204 | "weakness", 205 | "convert", 206 | "hotel", 207 | "great", 208 | "mouth", 209 | "mind", 210 | "song", 211 | "sugar", 212 | "suspect", 213 | "telephone", 214 | "ear", 215 | "roof", 216 | "paint", 217 | "refrigerator", 218 | "organization", 219 | "jury", 220 | "reward", 221 | "engineering", 222 | "day", 223 | "possession", 224 | "crew", 225 | "bar", 226 | "road", 227 | "description", 228 | "celebration", 229 | "score", 230 | "mark", 231 | "letter", 232 | "shower", 233 | "suggestion", 234 | "sir", 235 | "luck", 236 | "national", 237 | "progress", 238 | "hall", 239 | "stroke", 240 | "theory", 241 | "offer", 242 | "story", 243 | "tax", 244 | "definition", 245 | "history", 246 | "ride", 247 | "medium", 248 | "opening", 249 | "glass", 250 | "elevator", 251 | "stomach", 252 | "question", 253 | "ability", 254 | "leading", 255 | "village", 256 | "computer", 257 | "city", 258 | "grand", 259 | "confidence", 260 | "candle", 261 | "priest", 262 | "recommendation", 263 | "point", 264 | "necessary", 265 | "body", 266 | "desk", 267 | "secret", 268 | "horror", 269 | "noise", 270 | "culture", 271 | "warning", 272 | "water", 273 | "round", 274 | "diet", 275 | "flower", 276 | "bus", 277 | "tough", 278 | "permission", 279 | "week", 280 | "prompt", 281 | "connection", 282 | "abuse", 283 | "height", 284 | "save", 285 | "corner", 286 | "border", 287 | "stress", 288 | "drive", 289 | "stop", 290 | "rip", 291 | "meal", 292 | "listen", 293 | "confusion", 294 | "girlfriend", 295 | "living", 296 | "relation", 297 | "significance", 298 | "plan", 299 | "creative", 300 | "atmosphere", 301 | "blame", 302 | "invite", 303 | "housing", 304 | "paper", 305 | "drink", 306 | "roll", 307 | "silver", 308 | "drunk", 309 | "age", 310 | "damage", 311 | "smoke", 312 | "environment", 313 | "pack", 314 | "savings", 315 | "influence", 316 | "tourist", 317 | "rain", 318 | "post", 319 | "sign", 320 | "grandmother", 321 | "run", 322 | "profit", 323 | "push", 324 | "clerk", 325 | "final", 326 | "wine", 327 | "swim", 328 | "pause", 329 | "stuff", 330 | "singer", 331 | "funeral", 332 | "average", 333 | "source", 334 | "scene", 335 | "tradition", 336 | "personal", 337 | "snow", 338 | "nobody", 339 | "distance", 340 | "sort", 341 | "sensitive", 342 | "animal", 343 | "major", 344 | "negotiation", 345 | "click", 346 | "mood", 347 | "period", 348 | "arrival", 349 | "expression", 350 | "holiday", 351 | "repeat", 352 | "dust", 353 | "closet", 354 | "gold", 355 | "bad", 356 | "sail", 357 | "combination", 358 | "clothes", 359 | "emphasis", 360 | "duty", 361 | "black", 362 | "step", 363 | "school", 364 | "jump", 365 | "document", 366 | "professional", 367 | "lip", 368 | "chemical", 369 | "front", 370 | "wake", 371 | "while", 372 | "inside", 373 | "watch", 374 | "row", 375 | "subject", 376 | "penalty", 377 | "balance", 378 | "possible", 379 | "adult", 380 | "aside", 381 | "sample", 382 | "appeal", 383 | "wedding", 384 | "depth", 385 | "king", 386 | "award", 387 | "wife", 388 | "blow", 389 | "site", 390 | "camp", 391 | "music", 392 | "safe", 393 | "gift", 394 | "fault", 395 | "guess", 396 | "act", 397 | "shame", 398 | "drama", 399 | "capital", 400 | "exam", 401 | "stupid", 402 | "record", 403 | "sound", 404 | "swing", 405 | "novel", 406 | "minimum", 407 | "ratio", 408 | "machine", 409 | "shape", 410 | "lead", 411 | "operation", 412 | "salary", 413 | "cloud", 414 | "affair", 415 | "hit", 416 | "chapter", 417 | "stage", 418 | "quantity", 419 | "access", 420 | "army", 421 | "chain", 422 | "traffic", 423 | "kick", 424 | "analysis", 425 | "airport", 426 | "time", 427 | "vacation", 428 | "philosophy", 429 | "ball", 430 | "chest", 431 | "thanks", 432 | "place", 433 | "mountain", 434 | "advertising", 435 | "red", 436 | "past", 437 | "rent", 438 | "return", 439 | "tour", 440 | "house", 441 | "construction", 442 | "net", 443 | "native", 444 | "war", 445 | "figure", 446 | "fee", 447 | "spray", 448 | "user", 449 | "dirt", 450 | "shot", 451 | "task", 452 | "stick", 453 | "friend", 454 | "software", 455 | "promotion", 456 | "interaction", 457 | "surround", 458 | "block", 459 | "purpose", 460 | "practice", 461 | "conflict", 462 | "routine", 463 | "requirement", 464 | "bonus", 465 | "hole", 466 | "state", 467 | "junior", 468 | "sweet", 469 | "catch", 470 | "tear", 471 | "fold", 472 | "wall", 473 | "editor", 474 | "life", 475 | "position", 476 | "pound", 477 | "respect", 478 | "bathroom", 479 | "coat", 480 | "script", 481 | "job", 482 | "teach", 483 | "birth", 484 | "view", 485 | "resolve", 486 | "theme", 487 | "employee", 488 | "doubt", 489 | "market", 490 | "education", 491 | "serve", 492 | "recover", 493 | "tone", 494 | "harm", 495 | "miss", 496 | "union", 497 | "understanding", 498 | "cow", 499 | "river", 500 | "association", 501 | "concept", 502 | "training", 503 | "recipe", 504 | "relationship", 505 | "reserve", 506 | "depression", 507 | "proof", 508 | "hair", 509 | "revenue", 510 | "independent", 511 | "lift", 512 | "assignment", 513 | "temporary", 514 | "amount", 515 | "loss", 516 | "edge", 517 | "track", 518 | "check", 519 | "rope", 520 | "estimate", 521 | "pollution", 522 | "stable", 523 | "message", 524 | "delivery", 525 | "perspective", 526 | "mirror", 527 | "assistant", 528 | "representative", 529 | "witness", 530 | "nature", 531 | "judge", 532 | "fruit", 533 | "tip", 534 | "devil", 535 | "town", 536 | "emergency", 537 | "upper", 538 | "drop", 539 | "stay", 540 | "human", 541 | "neck", 542 | "speaker", 543 | "network", 544 | "sing", 545 | "resist", 546 | "league", 547 | "trip", 548 | "signature", 549 | "lawyer", 550 | "importance", 551 | "gas", 552 | "choice", 553 | "engineer", 554 | "success", 555 | "part", 556 | "external", 557 | "worker", 558 | "simple", 559 | "quarter", 560 | "student", 561 | "heart", 562 | "pass", 563 | "spite", 564 | "shift", 565 | "rough", 566 | "lady", 567 | "grass", 568 | "community", 569 | "garage", 570 | "youth", 571 | "standard", 572 | "skirt", 573 | "promise", 574 | "blind", 575 | "television", 576 | "disease", 577 | "commission", 578 | "positive", 579 | "energy", 580 | "calm", 581 | "presence", 582 | "tune", 583 | "basis", 584 | "preference", 585 | "head", 586 | "common", 587 | "cut", 588 | "somewhere", 589 | "presentation", 590 | "current", 591 | "thought", 592 | "revolution", 593 | "effort", 594 | "master", 595 | "implement", 596 | "republic", 597 | "floor", 598 | "principle", 599 | "stranger", 600 | "shoulder", 601 | "grade", 602 | "button", 603 | "tennis", 604 | "police", 605 | "collection", 606 | "account", 607 | "register", 608 | "glove", 609 | "divide", 610 | "professor", 611 | "chair", 612 | "priority", 613 | "combine", 614 | "peace", 615 | "extension", 616 | "maybe", 617 | "evening", 618 | "frame", 619 | "sister", 620 | "wave", 621 | "code", 622 | "application", 623 | "mouse", 624 | "match", 625 | "counter", 626 | "bottle", 627 | "half", 628 | "cheek", 629 | "resolution", 630 | "back", 631 | "knowledge", 632 | "make", 633 | "discussion", 634 | "screw", 635 | "length", 636 | "accident", 637 | "battle", 638 | "dress", 639 | "knee", 640 | "log", 641 | "package", 642 | "it", 643 | "turn", 644 | "hearing", 645 | "newspaper", 646 | "layer", 647 | "wealth", 648 | "profile", 649 | "imagination", 650 | "answer", 651 | "weekend", 652 | "teacher", 653 | "appearance", 654 | "meet", 655 | "bike", 656 | "rise", 657 | "belt", 658 | "crash", 659 | "bowl", 660 | "equivalent", 661 | "support", 662 | "image", 663 | "poem", 664 | "risk", 665 | "excitement", 666 | "remote", 667 | "secretary", 668 | "public", 669 | "produce", 670 | "plane", 671 | "display", 672 | "money", 673 | "sand", 674 | "situation", 675 | "punch", 676 | "customer", 677 | "title", 678 | "shake", 679 | "mortgage", 680 | "option", 681 | "number", 682 | "pop", 683 | "window", 684 | "extent", 685 | "nothing", 686 | "experience", 687 | "opinion", 688 | "departure", 689 | "dance", 690 | "indication", 691 | "boy", 692 | "material", 693 | "band", 694 | "leader", 695 | "sun", 696 | "beautiful", 697 | "muscle", 698 | "farmer", 699 | "variety", 700 | "fat", 701 | "handle", 702 | "director", 703 | "opportunity", 704 | "calendar", 705 | "outside", 706 | "pace", 707 | "bath", 708 | "fish", 709 | "consequence", 710 | "put", 711 | "owner", 712 | "go", 713 | "doctor", 714 | "information", 715 | "share", 716 | "hurt", 717 | "protection", 718 | "career", 719 | "finance", 720 | "force", 721 | "golf", 722 | "garbage", 723 | "aspect", 724 | "kid", 725 | "food", 726 | "boot", 727 | "milk", 728 | "respond", 729 | "objective", 730 | "reality", 731 | "raw", 732 | "ring", 733 | "mall", 734 | "one", 735 | "impact", 736 | "area", 737 | "news", 738 | "international", 739 | "series", 740 | "impress", 741 | "mother", 742 | "shelter", 743 | "strike", 744 | "loan", 745 | "month", 746 | "seat", 747 | "anything", 748 | "entertainment", 749 | "familiar", 750 | "clue", 751 | "year", 752 | "glad", 753 | "supermarket", 754 | "natural", 755 | "god", 756 | "cost", 757 | "conversation", 758 | "tie", 759 | "ruin", 760 | "comfort", 761 | "earth", 762 | "storm", 763 | "percentage", 764 | "assistance", 765 | "budget", 766 | "strength", 767 | "beginning", 768 | "sleep", 769 | "other", 770 | "young", 771 | "unit", 772 | "fill", 773 | "store", 774 | "desire", 775 | "hide", 776 | "value", 777 | "cup", 778 | "maintenance", 779 | "nurse", 780 | "function", 781 | "tower", 782 | "role", 783 | "class", 784 | "camera", 785 | "database", 786 | "panic", 787 | "nation", 788 | "basket", 789 | "ice", 790 | "art", 791 | "spirit", 792 | "chart", 793 | "exchange", 794 | "feedback", 795 | "statement", 796 | "reputation", 797 | "search", 798 | "hunt", 799 | "exercise", 800 | "nasty", 801 | "notice", 802 | "male", 803 | "yard", 804 | "annual", 805 | "collar", 806 | "date", 807 | "platform", 808 | "plant", 809 | "fortune", 810 | "passion", 811 | "friendship", 812 | "spread", 813 | "cancer", 814 | "ticket", 815 | "attitude", 816 | "island", 817 | "active", 818 | "object", 819 | "service", 820 | "buyer", 821 | "bite", 822 | "card", 823 | "face", 824 | "steak", 825 | "proposal", 826 | "patient", 827 | "heat", 828 | "rule", 829 | "resident", 830 | "broad", 831 | "politics", 832 | "west", 833 | "knife", 834 | "expert", 835 | "girl", 836 | "design", 837 | "salt", 838 | "baseball", 839 | "grab", 840 | "inspection", 841 | "cousin", 842 | "couple", 843 | "magazine", 844 | "cook", 845 | "dependent", 846 | "security", 847 | "chicken", 848 | "version", 849 | "currency", 850 | "ladder", 851 | "scheme", 852 | "kitchen", 853 | "employment", 854 | "local", 855 | "attention", 856 | "manager", 857 | "fact", 858 | "cover", 859 | "sad", 860 | "guard", 861 | "relative", 862 | "county", 863 | "rate", 864 | "lunch", 865 | "program", 866 | "initiative", 867 | "gear", 868 | "bridge", 869 | "breast", 870 | "talk", 871 | "dish", 872 | "guarantee", 873 | "beer", 874 | "vehicle", 875 | "reception", 876 | "woman", 877 | "substance", 878 | "copy", 879 | "lecture", 880 | "advantage", 881 | "park", 882 | "cold", 883 | "death", 884 | "mix", 885 | "hold", 886 | "scale", 887 | "tomorrow", 888 | "blood", 889 | "request", 890 | "green", 891 | "cookie", 892 | "church", 893 | "strip", 894 | "forever", 895 | "beyond", 896 | "debt", 897 | "tackle", 898 | "wash", 899 | "following", 900 | "feel", 901 | "maximum", 902 | "sector", 903 | "sea", 904 | "property", 905 | "economics", 906 | "menu", 907 | "bench", 908 | "try", 909 | "language", 910 | "start", 911 | "call", 912 | "solid", 913 | "address", 914 | "income", 915 | "foot", 916 | "senior", 917 | "honey", 918 | "few", 919 | "mixture", 920 | "cash", 921 | "grocery", 922 | "link", 923 | "map", 924 | "form", 925 | "factor", 926 | "pot", 927 | "model", 928 | "writer", 929 | "farm", 930 | "winter", 931 | "skill", 932 | "anywhere", 933 | "birthday", 934 | "policy", 935 | "release", 936 | "husband", 937 | "lab", 938 | "hurry", 939 | "mail", 940 | "equipment", 941 | "sink", 942 | "pair", 943 | "driver", 944 | "consideration", 945 | "leather", 946 | "skin", 947 | "blue", 948 | "boat", 949 | "sale", 950 | "brick", 951 | "two", 952 | "feed", 953 | "square", 954 | "dot", 955 | "rush", 956 | "dream", 957 | "location", 958 | "afternoon", 959 | "manufacturer", 960 | "control", 961 | "occasion", 962 | "trouble", 963 | "introduction", 964 | "advice", 965 | "bet", 966 | "eat", 967 | "kill", 968 | "category", 969 | "manner", 970 | "office", 971 | "estate", 972 | "pride", 973 | "awareness", 974 | "slip", 975 | "crack", 976 | "client", 977 | "nail", 978 | "shoot", 979 | "membership", 980 | "soft", 981 | "anybody", 982 | "web", 983 | "official", 984 | "individual", 985 | "pizza", 986 | "interest", 987 | "bag", 988 | "spell", 989 | "profession", 990 | "queen", 991 | "deal", 992 | "resource", 993 | "ship", 994 | "guy", 995 | "chocolate", 996 | "joint", 997 | "formal", 998 | "upstairs", 999 | "car", 1000 | "resort", 1001 | "abroad", 1002 | "dealer", 1003 | "associate", 1004 | "finger", 1005 | "surgery", 1006 | "comment", 1007 | "team", 1008 | "detail", 1009 | "crazy", 1010 | "path", 1011 | "tale", 1012 | "initial", 1013 | "arm", 1014 | "radio", 1015 | "demand", 1016 | "single", 1017 | "draw", 1018 | "yellow", 1019 | "contest", 1020 | "piece", 1021 | "quote", 1022 | "pull", 1023 | "commercial", 1024 | "shirt", 1025 | "contribution", 1026 | "cream", 1027 | "channel", 1028 | "suit", 1029 | "discipline", 1030 | "instruction", 1031 | "concert", 1032 | "speech", 1033 | "low", 1034 | "effective", 1035 | "hang", 1036 | "scratch", 1037 | "industry", 1038 | "breakfast", 1039 | "lay", 1040 | "join", 1041 | "metal", 1042 | "bedroom", 1043 | "minute", 1044 | "product", 1045 | "rest", 1046 | "temperature", 1047 | "many", 1048 | "give", 1049 | "argument", 1050 | "print", 1051 | "purple", 1052 | "laugh", 1053 | "health", 1054 | "credit", 1055 | "investment", 1056 | "sell", 1057 | "setting", 1058 | "lesson", 1059 | "egg", 1060 | "middle", 1061 | "marriage", 1062 | "level", 1063 | "evidence", 1064 | "phrase", 1065 | "love", 1066 | "self", 1067 | "benefit", 1068 | "guidance", 1069 | "affect", 1070 | "you", 1071 | "dad", 1072 | "anxiety", 1073 | "special", 1074 | "boyfriend", 1075 | "test", 1076 | "blank", 1077 | "payment", 1078 | "soup", 1079 | "obligation", 1080 | "reply", 1081 | "smile", 1082 | "deep", 1083 | "complaint", 1084 | "addition", 1085 | "review", 1086 | "box", 1087 | "towel", 1088 | "minor", 1089 | "fun", 1090 | "soil", 1091 | "issue", 1092 | "cigarette", 1093 | "internet", 1094 | "gain", 1095 | "tell", 1096 | "entry", 1097 | "spare", 1098 | "incident", 1099 | "family", 1100 | "refuse", 1101 | "branch", 1102 | "can", 1103 | "pen", 1104 | "grandfather", 1105 | "constant", 1106 | "tank", 1107 | "uncle", 1108 | "climate", 1109 | "ground", 1110 | "volume", 1111 | "communication", 1112 | "kind", 1113 | "poet", 1114 | "child", 1115 | "screen", 1116 | "mine", 1117 | "quit", 1118 | "gene", 1119 | "lack", 1120 | "charity", 1121 | "memory", 1122 | "tooth", 1123 | "fear", 1124 | "mention", 1125 | "marketing", 1126 | "reveal", 1127 | "reason", 1128 | "court", 1129 | "season", 1130 | "freedom", 1131 | "land", 1132 | "sport", 1133 | "audience", 1134 | "classroom", 1135 | "law", 1136 | "hook", 1137 | "win", 1138 | "carry", 1139 | "eye", 1140 | "smell", 1141 | "distribution", 1142 | "research", 1143 | "country", 1144 | "dare", 1145 | "hope", 1146 | "whereas", 1147 | "stretch", 1148 | "library", 1149 | "if", 1150 | "delay", 1151 | "college", 1152 | "plastic", 1153 | "book", 1154 | "present", 1155 | "use", 1156 | "worry", 1157 | "champion", 1158 | "goal", 1159 | "economy", 1160 | "march", 1161 | "election", 1162 | "reflection", 1163 | "midnight", 1164 | "slide", 1165 | "inflation", 1166 | "action", 1167 | "challenge", 1168 | "guitar", 1169 | "coast", 1170 | "apple", 1171 | "campaign", 1172 | "field", 1173 | "jacket", 1174 | "sense", 1175 | "way", 1176 | "visual", 1177 | "remove", 1178 | "weather", 1179 | "trash", 1180 | "cable", 1181 | "regret", 1182 | "buddy", 1183 | "beach", 1184 | "historian", 1185 | "courage", 1186 | "sympathy", 1187 | "truck", 1188 | "tension", 1189 | "permit", 1190 | "nose", 1191 | "bed", 1192 | "son", 1193 | "person", 1194 | "base", 1195 | "meat", 1196 | "usual", 1197 | "air", 1198 | "meeting", 1199 | "worth", 1200 | "game", 1201 | "independence", 1202 | "physical", 1203 | "brief", 1204 | "play", 1205 | "raise", 1206 | "board", 1207 | "she", 1208 | "key", 1209 | "writing", 1210 | "pick", 1211 | "command", 1212 | "party", 1213 | "yesterday", 1214 | "spring", 1215 | "candidate", 1216 | "physics", 1217 | "university", 1218 | "concern", 1219 | "development", 1220 | "change", 1221 | "string", 1222 | "target", 1223 | "instance", 1224 | "room", 1225 | "bitter", 1226 | "bird", 1227 | "football", 1228 | "normal", 1229 | "split", 1230 | "impression", 1231 | "wood", 1232 | "long", 1233 | "meaning", 1234 | "stock", 1235 | "cap", 1236 | "leadership", 1237 | "media", 1238 | "ambition", 1239 | "fishing", 1240 | "essay", 1241 | "salad", 1242 | "repair", 1243 | "today", 1244 | "designer", 1245 | "night", 1246 | "bank", 1247 | "drawing", 1248 | "inevitable", 1249 | "phase", 1250 | "vast", 1251 | "chip", 1252 | "anger", 1253 | "switch", 1254 | "cry", 1255 | "twist", 1256 | "personality", 1257 | "attempt", 1258 | "storage", 1259 | "being", 1260 | "preparation", 1261 | "bat", 1262 | "selection", 1263 | "white", 1264 | "technology", 1265 | "contract", 1266 | "side", 1267 | "section", 1268 | "station", 1269 | "till", 1270 | "structure", 1271 | "tongue", 1272 | "taste", 1273 | "truth", 1274 | "difficulty", 1275 | "group", 1276 | "limit", 1277 | "main", 1278 | "move", 1279 | "feeling", 1280 | "light", 1281 | "example", 1282 | "mission", 1283 | "might", 1284 | "wait", 1285 | "wheel", 1286 | "shop", 1287 | "host", 1288 | "classic", 1289 | "alternative", 1290 | "cause", 1291 | "agent", 1292 | "consist", 1293 | "table", 1294 | "airline", 1295 | "text", 1296 | "pool", 1297 | "craft", 1298 | "range", 1299 | "fuel", 1300 | "tool", 1301 | "partner", 1302 | "load", 1303 | "entrance", 1304 | "deposit", 1305 | "hate", 1306 | "article", 1307 | "video", 1308 | "summer", 1309 | "feature", 1310 | "extreme", 1311 | "mobile", 1312 | "hospital", 1313 | "flight", 1314 | "fall", 1315 | "pension", 1316 | "piano", 1317 | "fail", 1318 | "result", 1319 | "rub", 1320 | "gap", 1321 | "system", 1322 | "report", 1323 | "suck", 1324 | "ordinary", 1325 | "wind", 1326 | "nerve", 1327 | "ask", 1328 | "shine", 1329 | "note", 1330 | "line", 1331 | "mom", 1332 | "perception", 1333 | "brother", 1334 | "reference", 1335 | "bend", 1336 | "charge", 1337 | "treat", 1338 | "trick", 1339 | "term", 1340 | "homework", 1341 | "bake", 1342 | "bid", 1343 | "status", 1344 | "project", 1345 | "strategy", 1346 | "orange", 1347 | "let", 1348 | "enthusiasm", 1349 | "parent", 1350 | "concentrate", 1351 | "device", 1352 | "travel", 1353 | "poetry", 1354 | "business", 1355 | "society", 1356 | "kiss", 1357 | "end", 1358 | "vegetable", 1359 | "employ", 1360 | "schedule", 1361 | "hour", 1362 | "brave", 1363 | "focus", 1364 | "process", 1365 | "movie", 1366 | "illegal", 1367 | "general", 1368 | "coffee", 1369 | "ad", 1370 | "highway", 1371 | "chemistry", 1372 | "psychology", 1373 | "hire", 1374 | "bell", 1375 | "conference", 1376 | "relief", 1377 | "show", 1378 | "neat", 1379 | "funny", 1380 | "weight", 1381 | "quality", 1382 | "club", 1383 | "daughter", 1384 | "zone", 1385 | "touch", 1386 | "tonight", 1387 | "shock", 1388 | "burn", 1389 | "excuse", 1390 | "name", 1391 | "survey", 1392 | "landscape", 1393 | "advance", 1394 | "satisfaction", 1395 | "bread", 1396 | "disaster", 1397 | "item", 1398 | "hat", 1399 | "prior", 1400 | "shopping", 1401 | "visit", 1402 | "east", 1403 | "photo", 1404 | "home", 1405 | "idea", 1406 | "father", 1407 | "comparison", 1408 | "cat", 1409 | "pipe", 1410 | "winner", 1411 | "count", 1412 | "lake", 1413 | "fight", 1414 | "prize", 1415 | "foundation", 1416 | "dog", 1417 | "keep", 1418 | "ideal", 1419 | "fan", 1420 | "struggle", 1421 | "peak", 1422 | "safety", 1423 | "solution", 1424 | "hell", 1425 | "conclusion", 1426 | "population", 1427 | "strain", 1428 | "alarm", 1429 | "measurement", 1430 | "second", 1431 | "train", 1432 | "race", 1433 | "due", 1434 | "insurance", 1435 | "boss", 1436 | "tree", 1437 | "monitor", 1438 | "sick", 1439 | "course", 1440 | "drag", 1441 | "appointment", 1442 | "slice", 1443 | "still", 1444 | "care", 1445 | "patience", 1446 | "rich", 1447 | "escape", 1448 | "emotion", 1449 | "royal", 1450 | "female", 1451 | "childhood", 1452 | "government", 1453 | "picture", 1454 | "will", 1455 | "sock", 1456 | "big", 1457 | "gate", 1458 | "oil", 1459 | "cross", 1460 | "pin", 1461 | "improvement", 1462 | "championship", 1463 | "silly", 1464 | "help", 1465 | "sky", 1466 | "pitch", 1467 | "man", 1468 | "diamond", 1469 | "most", 1470 | "transition", 1471 | "work", 1472 | "science", 1473 | "committee", 1474 | "moment", 1475 | "fix", 1476 | "teaching", 1477 | "dig", 1478 | "specialist", 1479 | "complex", 1480 | "guide", 1481 | "people", 1482 | "dead", 1483 | "voice", 1484 | "original", 1485 | "break", 1486 | "topic", 1487 | "data", 1488 | "degree", 1489 | "reading", 1490 | "recording", 1491 | "bunch", 1492 | "reach", 1493 | "judgment", 1494 | "lie", 1495 | "regular", 1496 | "set", 1497 | "painting", 1498 | "mode", 1499 | "list", 1500 | "player", 1501 | "bear", 1502 | "north", 1503 | "wonder", 1504 | "carpet", 1505 | "heavy", 1506 | "officer", 1507 | "negative", 1508 | "clock", 1509 | "unique", 1510 | "baby", 1511 | "pain", 1512 | "assumption", 1513 | "disk", 1514 | "iron", 1515 | "bill", 1516 | "drawer", 1517 | "look", 1518 | "double", 1519 | "mistake", 1520 | "finish", 1521 | "future", 1522 | "brilliant", 1523 | "contact", 1524 | "math", 1525 | "rice", 1526 | "leave", 1527 | "restaurant", 1528 | "discount", 1529 | "sex", 1530 | "virus", 1531 | "bit", 1532 | "trust", 1533 | "event", 1534 | "wear", 1535 | "juice", 1536 | "failure", 1537 | "bug", 1538 | "context", 1539 | "mud", 1540 | "whole", 1541 | "wrap", 1542 | "intention", 1543 | "draft", 1544 | "pressure", 1545 | "cake", 1546 | "dark", 1547 | "explanation", 1548 | "space", 1549 | "angle", 1550 | "word", 1551 | "efficiency", 1552 | "management", 1553 | "habit", 1554 | "star", 1555 | "chance", 1556 | "finding", 1557 | "transportation", 1558 | "stand", 1559 | "criticism", 1560 | "flow", 1561 | "door", 1562 | "injury", 1563 | "insect", 1564 | "surprise", 1565 | "apartment", 1566 | ] # pylint: disable=line-too-long 1567 | 1568 | # ISO 639-1 codes to language names. 1569 | LANGUAGE_CODES = MappingProxyType( 1570 | { 1571 | "en": "English", 1572 | "es": "Spanish", 1573 | "pt": "Portuguese", 1574 | "ar": "Arabic", 1575 | "hi": "Hindi", 1576 | "fr": "French", 1577 | "ru": "Russian", 1578 | "de": "German", 1579 | "ja": "Japanese", 1580 | "it": "Italian", 1581 | "bn": "Bengali", 1582 | "uk": "Ukrainian", 1583 | "th": "Thai", 1584 | "ur": "Urdu", 1585 | "ta": "Tamil", 1586 | "te": "Telugu", 1587 | "bg": "Bulgarian", 1588 | "ko": "Korean", 1589 | "pl": "Polish", 1590 | "he": "Hebrew", 1591 | "fa": "Persian", 1592 | "vi": "Vietnamese", 1593 | "ne": "Nepali", 1594 | "sw": "Swahili", 1595 | "kn": "Kannada", 1596 | "mr": "Marathi", 1597 | "gu": "Gujarati", 1598 | "pa": "Punjabi", 1599 | "ml": "Malayalam", 1600 | "fi": "Finnish", 1601 | } 1602 | ) 1603 | 1604 | # Chinese characters 1605 | _CHINESE_CHARS_PATTERN = r"[\u4E00-\u9FFF\u3400-\u4DBF]" 1606 | # Japanese Hiragana & Katakana 1607 | _JAPANESE_CHARS_PATTERN = r"[\u3040-\u309f\u30a0-\u30ff]" 1608 | # Korean (Hangul Syllables) 1609 | _KOREAN_CHARS_PATTERN = r"[\uAC00-\uD7AF]" 1610 | _ALPHABETS = "([A-Za-z])" 1611 | _PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]" 1612 | _SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)" 1613 | _STARTERS = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)" 1614 | _ACRONYMS = "([A-Z][.][A-Z][.](?:[A-Z][.])?)" 1615 | _WEBSITES = "[.](com|net|org|io|gov|edu|me)" 1616 | _DIGITS = "([0-9])" 1617 | _MULTIPLE_DOTS = r"\.{2,}" 1618 | 1619 | 1620 | # Util functions 1621 | def split_into_sentences(text): 1622 | """Split the text into sentences. 1623 | 1624 | Args: 1625 | text: A string that consists of more than or equal to one sentences. 1626 | 1627 | Returns: 1628 | A list of strings where each string is a sentence. 1629 | """ 1630 | text = " " + text + " " 1631 | text = text.replace("\n", " ") 1632 | text = re.sub(_PREFIXES, "\\1", text) 1633 | text = re.sub(_WEBSITES, "\\1", text) 1634 | text = re.sub(_DIGITS + "[.]" + _DIGITS, "\\1\\2", text) 1635 | text = re.sub( 1636 | _MULTIPLE_DOTS, 1637 | lambda match: "" * len(match.group(0)) + "", 1638 | text, 1639 | ) 1640 | if "Ph.D" in text: 1641 | text = text.replace("Ph.D.", "PhD") 1642 | text = re.sub(r"\s" + _ALPHABETS + "[.] ", " \\1 ", text) 1643 | text = re.sub(_ACRONYMS + " " + _STARTERS, "\\1 \\2", text) 1644 | text = re.sub( 1645 | _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]", 1646 | "\\1\\2\\3", 1647 | text, 1648 | ) 1649 | text = re.sub(_ALPHABETS + "[.]" + _ALPHABETS + "[.]", "\\1\\2", text) 1650 | text = re.sub(" " + _SUFFIXES + "[.] " + _STARTERS, " \\1 \\2", text) 1651 | text = re.sub(" " + _SUFFIXES + "[.]", " \\1", text) 1652 | text = re.sub(" " + _ALPHABETS + "[.]", " \\1", text) 1653 | if "”" in text: 1654 | text = text.replace(".”", "”.") 1655 | if '"' in text: 1656 | text = text.replace('."', '".') 1657 | if "!" in text: 1658 | text = text.replace('!"', '"!') 1659 | if "?" in text: 1660 | text = text.replace('?"', '"?') 1661 | text = text.replace(".", ".") 1662 | text = text.replace("?", "?") 1663 | text = text.replace("!", "!") 1664 | text = text.replace("", ".") 1665 | sentences = text.split("") 1666 | sentences = [s.strip() for s in sentences] 1667 | if sentences and not sentences[-1]: 1668 | sentences = sentences[:-1] 1669 | return sentences 1670 | 1671 | 1672 | def count_words(text): 1673 | """Counts the number of words.""" 1674 | try: 1675 | tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+") 1676 | tokens = tokenizer.tokenize(text) 1677 | num_words = len(tokens) 1678 | except: 1679 | print('Failed to count for', text) 1680 | return 0 1681 | return num_words 1682 | 1683 | 1684 | def split_chinese_japanese(lines: str) -> Iterable[str]: 1685 | """ 1686 | Split Chinese and Japanese text into sentences. 1687 | From https://stackoverflow.com/questions/27441191/splitting-chinese-document-into-sentences 1688 | Special question/exclamation marks were added upon inspection of our raw data 1689 | Also supports multiple lines. 1690 | """ 1691 | for line in lines.splitlines(): 1692 | for sent in re.findall( 1693 | r"[^!?。\.\!\?\!\?\.\n]+[!?。\.\!\?\!\?\.\n]?", line.strip(), flags=re.U 1694 | ): 1695 | yield sent 1696 | 1697 | 1698 | def count_words_chinese_japanese(text: str) -> int: 1699 | """Counts the number of words for Chinese and Japanese and Korean. 1700 | Can be extended to additional languages. 1701 | Source: https://stackoverflow.com/questions/49164507/how-to-count-the-number-of-chinese-korean-and-english-words withadditional modifications 1702 | Example: 1703 | >In: count_words_chinese_japanese('こんにちは、ジェイソンさん、Jason? Nice to meet you☺ ❤') 1704 | >Out: 19 1705 | """ 1706 | # Non alpha numeric patterns in latin and asian languages. 1707 | non_alphanumeric_patterns = ( 1708 | r"[\\.\!\?\.\/_,\{\}<>:;$%^&*(+\"\'+——!,。?、`~@#¥……():;《)《》“”()\[\]»〔〕\-「」]+" 1709 | ) 1710 | text = re.sub(non_alphanumeric_patterns, "", text) 1711 | if emoji: 1712 | emoji_cnt = emoji.emoji_count(text) # count emojis 1713 | text = emoji.replace_emoji(text, "") # remove emojis 1714 | else: 1715 | emoji_cnt = 0 1716 | foreign_chars_patterns = "|".join( 1717 | [_CHINESE_CHARS_PATTERN, _JAPANESE_CHARS_PATTERN, _KOREAN_CHARS_PATTERN] 1718 | ) 1719 | asian_chars = re.findall(foreign_chars_patterns, text) 1720 | asian_chars_cnt = len(asian_chars) 1721 | non_asian_chars = re.sub(foreign_chars_patterns, " ", text) 1722 | non_asian_words_cnt = len(non_asian_chars.split()) 1723 | return non_asian_words_cnt + asian_chars_cnt + emoji_cnt 1724 | 1725 | 1726 | @functools.lru_cache(maxsize=None) 1727 | def _get_sentence_tokenizer(): 1728 | return nltk.data.load("nltk:tokenizers/punkt/english.pickle") 1729 | 1730 | 1731 | def count_sentences(text): 1732 | """Count the number of sentences.""" 1733 | tokenizer = _get_sentence_tokenizer() 1734 | tokenized_sentences = tokenizer.tokenize(text) 1735 | return len(tokenized_sentences) 1736 | 1737 | def count_hindi_num_sentences(text): 1738 | sentences = re.split(r'(?<=[।!?])\s*', text) 1739 | return len([s for s in sentences if s.strip()]) 1740 | 1741 | def generate_keywords(num_keywords): 1742 | """Randomly generates a few keywords.""" 1743 | return random.sample(WORD_LIST, k=num_keywords) 1744 | 1745 | 1746 | """Library of instructions""" 1747 | _InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]] 1748 | 1749 | _LANGUAGES = LANGUAGE_CODES 1750 | 1751 | # The relational operation for comparison. 1752 | _COMPARISON_RELATION = ("less than", "at least") 1753 | 1754 | # The maximum number of sentences. 1755 | _MAX_NUM_SENTENCES = 20 1756 | 1757 | # The number of placeholders. 1758 | _NUM_PLACEHOLDERS = 4 1759 | 1760 | # The number of bullet lists. 1761 | _NUM_BULLETS = 5 1762 | 1763 | # The options of constrained response. 1764 | _CONSTRAINED_RESPONSE_OPTIONS = ( 1765 | "My answer is yes.", 1766 | "My answer is no.", 1767 | "My answer is maybe.", 1768 | ) 1769 | 1770 | # The options of starter keywords. 1771 | _STARTER_OPTIONS = ( 1772 | "I would say", 1773 | "My answer is", 1774 | "I believe", 1775 | "In my opinion", 1776 | "I think", 1777 | "I reckon", 1778 | "I feel", 1779 | "From my perspective", 1780 | "As I see it", 1781 | "According to me", 1782 | "As far as I'm concerned", 1783 | "To my understanding", 1784 | "In my view", 1785 | "My take on it is", 1786 | "As per my perception", 1787 | ) 1788 | 1789 | # The options of ending keywords. 1790 | # TODO(jeffreyzhou) add more ending options 1791 | _ENDING_OPTIONS = ("Any other questions?", "Is there anything else I can help with?") 1792 | 1793 | # The number of highlighted sections. 1794 | _NUM_HIGHLIGHTED_SECTIONS = 4 1795 | 1796 | # The section spliter. 1797 | _SECTION_SPLITER = ("Section", "SECTION") 1798 | 1799 | # The number of sections. 1800 | _NUM_SECTIONS = 5 1801 | 1802 | # The number of paragraphs. 1803 | _NUM_PARAGRAPHS = 5 1804 | 1805 | # The postscript marker. 1806 | _POSTSCRIPT_MARKER = ("P.S.", "P.P.S") 1807 | 1808 | # The number of keywords. 1809 | _NUM_KEYWORDS = 2 1810 | 1811 | # The occurrences of a single keyword. 1812 | _KEYWORD_FREQUENCY = 3 1813 | 1814 | # The occurrences of a single letter. 1815 | _LETTER_FREQUENCY = 10 1816 | 1817 | # The occurrences of words with all capital letters. 1818 | _ALL_CAPITAL_WORD_FREQUENCY = 20 1819 | 1820 | # The number of words in the response. 1821 | _NUM_WORDS_LOWER_LIMIT = 100 1822 | _NUM_WORDS_UPPER_LIMIT = 500 1823 | 1824 | 1825 | class Instruction: 1826 | """An instruction template.""" 1827 | 1828 | def __init__(self, instruction_id): 1829 | self.id = instruction_id 1830 | 1831 | def build_description(self, **kwargs): 1832 | raise NotImplementedError("`build_description` not implemented.") 1833 | 1834 | def get_instruction_args(self): 1835 | raise NotImplementedError("`get_instruction_args` not implemented.") 1836 | 1837 | def get_instruction_args_keys(self): 1838 | raise NotImplementedError("`get_instruction_args_keys` not implemented.") 1839 | 1840 | def check_following(self, value): 1841 | raise NotImplementedError("`check_following` not implemented.") 1842 | 1843 | 1844 | class ResponseLanguageChecker(Instruction): 1845 | """Check the language of the entire response.""" 1846 | 1847 | def build_description(self, *, language=None): 1848 | """Build the instruction description. 1849 | 1850 | Args: 1851 | language: A string representing the expected language of the response. The 1852 | language has to comply to the 97 types defined in 1853 | `langid.py` (https://pypi.org/project/langid/1.1.5/), which follows 1854 | ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes); 1855 | for example, `en` for English, `zh` for Chinese, `fr` for French. 1856 | 1857 | Returns: 1858 | A string representing the instruction description. 1859 | """ 1860 | self._language = language 1861 | if self._language is None: 1862 | self._language = random.choice(list(_LANGUAGES.keys())) 1863 | # TODO(tianjianlu): opens the description generation to more choices. 1864 | self._description_pattern = ( 1865 | "Your ENTIRE response should be in {language} language, no other " 1866 | + "language is allowed." 1867 | ) 1868 | return self._description_pattern.format(language=_LANGUAGES[self._language]) 1869 | 1870 | def get_instruction_args(self): 1871 | """Returns the keyward args of `build_description`.""" 1872 | return {"language": self._language} 1873 | 1874 | def get_instruction_args_keys(self): 1875 | """Returns the args keys of `build_description`.""" 1876 | return ["language"] 1877 | 1878 | def check_following(self, value): 1879 | """Check if the language of the entire response follows the instruction. 1880 | 1881 | Args: 1882 | value: A string representing the response. 1883 | 1884 | Returns: 1885 | True if the language of `value` follows instruction; otherwise False. 1886 | """ 1887 | try: 1888 | assert isinstance(value, str) 1889 | except: 1890 | print('Failed for assertion, got non str type input,', value) 1891 | return False 1892 | 1893 | try: 1894 | return langdetect.detect(value) == self._language 1895 | except langdetect.LangDetectException as e: 1896 | # Count as instruction is followed. 1897 | logger.info( 1898 | "Unable to detect language for text %s due to %s", value, e 1899 | ) # refex: disable=pytotw.037 1900 | return True 1901 | 1902 | 1903 | class NumberOfSentences(Instruction): 1904 | """Check the number of sentences.""" 1905 | 1906 | def build_description(self, *, num_sentences=None, relation=None): 1907 | """Build the instruction description. 1908 | 1909 | Args: 1910 | num_sentences: An integer specifying the number of sentences as a 1911 | threshold. 1912 | relation: A string in (`less than`, `at least`), defining the relational 1913 | operator for comparison. 1914 | Two relational comparisons are supported for now: 1915 | if 'less than', the actual number of sentences < the threshold; 1916 | if 'at least', the actual number of sentences >= the threshold. 1917 | 1918 | Returns: 1919 | A string representing the instruction description. 1920 | """ 1921 | # The number of sentences as a threshold for comparison. 1922 | self._num_sentences_threshold = num_sentences 1923 | if self._num_sentences_threshold is None or self._num_sentences_threshold < 0: 1924 | self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES) 1925 | 1926 | if relation is None: 1927 | self._comparison_relation = random.choice(_COMPARISON_RELATION) 1928 | elif relation not in _COMPARISON_RELATION: 1929 | raise ValueError( 1930 | "The supported relation for comparison must be in " 1931 | f"{_COMPARISON_RELATION}, but {relation} is given." 1932 | ) 1933 | else: 1934 | self._comparison_relation = relation 1935 | 1936 | self._description_pattern = ( 1937 | "Your response should contain {relation} {num_sentences} sentences." 1938 | ) 1939 | return self._description_pattern.format( 1940 | relation=self._comparison_relation, 1941 | num_sentences=self._num_sentences_threshold, 1942 | ) 1943 | 1944 | def get_instruction_args(self): 1945 | """Returns the keyward args of `build_description`.""" 1946 | return { 1947 | "num_sentences": self._num_sentences_threshold, 1948 | "relation": self._comparison_relation, 1949 | } 1950 | 1951 | def get_instruction_args_keys(self): 1952 | """Returns the args keys of `build_description`.""" 1953 | return ["num_sentences", "relation"] 1954 | 1955 | def check_following(self, value): 1956 | """Check if the number of sentences follows the instruction. 1957 | 1958 | Args: 1959 | value: A string representing the response. 1960 | 1961 | Returns: 1962 | True if the response follows the instruction. 1963 | 1964 | Raise: 1965 | ValueError if the string in `instruction_args` is not in 1966 | [`less_than`, `at_least`]. 1967 | """ 1968 | try: 1969 | lang = langdetect.detect(value) 1970 | except: 1971 | print("Failed to detect language, got value:", value) 1972 | lang = 'en' 1973 | if lang == "th": 1974 | # print(f"shervin1. lang is {lang}") 1975 | # print(value) 1976 | # Newline also counts as a new sentence: 1977 | num_sentences = sum( 1978 | [len(sent_tokenize_thai(line)) for line in value.splitlines()] 1979 | ) 1980 | # print(f"num sentences: {num_sentences}") 1981 | elif lang == 'hi': 1982 | num_sentences = count_hindi_num_sentences(value) 1983 | elif lang in ["zh", "zh-cn", "zh-tw", "ja"]: 1984 | # print(f"shervin2. lang is {lang}") 1985 | # print(value) 1986 | num_sentences = len(list(split_chinese_japanese(value))) 1987 | # print(f"num sentences: {num_sentences}") 1988 | else: 1989 | # print(f"shervin3: lang is {lang}") 1990 | num_sentences = count_sentences(value) 1991 | # print(f"num sentences: {num_sentences}") 1992 | if self._comparison_relation == _COMPARISON_RELATION[0]: 1993 | return num_sentences < self._num_sentences_threshold 1994 | elif self._comparison_relation == _COMPARISON_RELATION[1]: 1995 | return num_sentences >= self._num_sentences_threshold 1996 | 1997 | 1998 | class PlaceholderChecker(Instruction): 1999 | """Check the placeholders in template writing.""" 2000 | 2001 | def build_description(self, *, num_placeholders=None): 2002 | """Build the instruction description. 2003 | 2004 | Args: 2005 | num_placeholders: An integer denoting the minimum number of 2006 | placeholders required in the response. 2007 | 2008 | Returns: 2009 | A string representing the instruction description. 2010 | """ 2011 | self._num_placeholders = num_placeholders 2012 | if self._num_placeholders is None or self._num_placeholders < 0: 2013 | self._num_placeholders = random.randint(1, _NUM_PLACEHOLDERS) 2014 | self._description_pattern = ( 2015 | "The response must contain at least {num_placeholders} placeholders " 2016 | + "represented by square brackets, such as [address]." 2017 | ) 2018 | return self._description_pattern.format(num_placeholders=self._num_placeholders) 2019 | 2020 | def get_instruction_args(self): 2021 | """Returns the keyward args of `build_description`.""" 2022 | return {"num_placeholders": self._num_placeholders} 2023 | 2024 | def get_instruction_args_keys(self): 2025 | """Returns the args keys of `build_description`.""" 2026 | return ["num_placeholders"] 2027 | 2028 | def check_following(self, value): 2029 | """Check if the number of placeholders follows the instruction. 2030 | 2031 | Args: 2032 | value: A string representing the response. 2033 | 2034 | Returns: 2035 | True if the actual number of placeholders in the response is greater than 2036 | or equal to `num_placeholders`; otherwise, False. 2037 | """ 2038 | placeholders = re.findall(r"\[.*?\]", value) 2039 | num_placeholders = len(placeholders) 2040 | return num_placeholders >= self._num_placeholders 2041 | 2042 | 2043 | class BulletListChecker(Instruction): 2044 | """Checks the bullet list in the prompt.""" 2045 | 2046 | def build_description(self, *, num_bullets=None): 2047 | """Build the instruction description. 2048 | 2049 | Args: 2050 | num_bullets: An integer specifying the exact number of bullet lists 2051 | that is required to appear in the response. 2052 | 2053 | Returns: 2054 | A string representing the instruction description. 2055 | """ 2056 | self._num_bullets = num_bullets 2057 | if self._num_bullets is None or self._num_bullets < 0: 2058 | self._num_bullets = random.randint(1, _NUM_BULLETS) 2059 | self._description_pattern = ( 2060 | "Your answer must contain exactly {num_bullets} bullet points. " 2061 | + "Use the markdown bullet points such as:\n" 2062 | + "* This is point 1. \n" 2063 | + "* This is point 2" 2064 | ) 2065 | return self._description_pattern.format(num_bullets=self._num_bullets) 2066 | 2067 | def get_instruction_args(self): 2068 | """Returns the keyward args of `build_description`.""" 2069 | return {"num_bullets": self._num_bullets} 2070 | 2071 | def get_instruction_args_keys(self): 2072 | """Returns the args keys of `build_description`.""" 2073 | return ["num_bullets"] 2074 | 2075 | def check_following(self, value): 2076 | r"""Check if the number of bullet lists meets the requirement. 2077 | 2078 | Args: 2079 | value: A string representing the response. The response is expected to 2080 | contain some bullet lists that start with `\*`. 2081 | 2082 | Returns: 2083 | True if the actual number of bullet lists in the response meets the 2084 | requirement. 2085 | """ 2086 | bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE) 2087 | bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE) 2088 | num_bullet_lists = len(bullet_lists) + len(bullet_lists_2) 2089 | return num_bullet_lists == self._num_bullets 2090 | 2091 | 2092 | class ConstrainedResponseChecker(Instruction): 2093 | """Checks the constrained response.""" 2094 | 2095 | def build_description(self): 2096 | """Build the instruction description.""" 2097 | # A sequence of string(s) representing the options of the expected response. 2098 | self._constrained_responses = _CONSTRAINED_RESPONSE_OPTIONS 2099 | self._description_pattern = ( 2100 | "Answer with one of the following options: {response_options}" 2101 | ) 2102 | return self._description_pattern.format( 2103 | response_options=self._constrained_responses 2104 | ) 2105 | 2106 | def get_instruction_args(self): 2107 | """Returns the keyward args of `build_description`.""" 2108 | return None 2109 | 2110 | def get_instruction_args_keys(self): 2111 | """Returns the args keys of `build_description`.""" 2112 | return [] 2113 | 2114 | def check_following(self, value): 2115 | """Checks if the response matches the constrained options. 2116 | 2117 | Args: 2118 | value: A string representing the response. 2119 | 2120 | Returns: 2121 | True if the actual response contains one of the options in the constrained 2122 | responses; otherwise False. 2123 | """ 2124 | value = value.strip() 2125 | for constrained_response in self._constrained_responses: 2126 | if constrained_response in value: 2127 | return True 2128 | return False 2129 | 2130 | 2131 | class ConstrainedStartChecker(Instruction): 2132 | """Checks the response start.""" 2133 | 2134 | def build_description(self, *, starter=None): 2135 | """Build the instruction description. 2136 | 2137 | Args: 2138 | starter: A string representing the keyward that the response should start 2139 | with. 2140 | 2141 | Returns: 2142 | A string representing the instruction description. 2143 | """ 2144 | self._starter = starter.strip() if isinstance(starter, str) else starter 2145 | if self._starter is None: 2146 | self._starter = random.choice(_STARTER_OPTIONS) 2147 | self._description_pattern = ( 2148 | "During the conversation, when it is your turn, " 2149 | + "please always start with {starter}" 2150 | ) 2151 | return self._description_pattern.format(starter=self._starter) 2152 | 2153 | def get_instruction_args(self): 2154 | """Returns the keyward args of `build_description`.""" 2155 | return {"starter": self._starter} 2156 | 2157 | def get_instruction_args_keys(self): 2158 | """Returns the args keys of `build_description`.""" 2159 | return ["starter"] 2160 | 2161 | def check_following(self, value): 2162 | """Checks if the response starts with the constrained keyword or phrase. 2163 | 2164 | Args: 2165 | value: A string representing the response. 2166 | 2167 | Returns: 2168 | True if the response starts with the given phrase or keyword that is 2169 | contained in `instruction_args`; otherwise, False. 2170 | """ 2171 | response_pattern = r"^\s*" + self._starter + r".*$" 2172 | response_with_constrained_start = re.search( 2173 | response_pattern, value, flags=re.MULTILINE 2174 | ) 2175 | return True if response_with_constrained_start else False 2176 | 2177 | 2178 | class HighlightSectionChecker(Instruction): 2179 | """Checks the highlighted section.""" 2180 | 2181 | def build_description(self, *, num_highlights=None): 2182 | """Build the instruction description. 2183 | 2184 | Args: 2185 | num_highlights: An integer specifying the minimum number of highlighted 2186 | sections. 2187 | 2188 | Returns: 2189 | A string representing the instruction description. 2190 | """ 2191 | self._num_highlights = num_highlights 2192 | if self._num_highlights is None or self._num_highlights < 0: 2193 | self._num_highlights = random.randint(1, _NUM_HIGHLIGHTED_SECTIONS) 2194 | 2195 | self._description_pattern = ( 2196 | "Highlight at least {num_highlights} sections in your answer with " 2197 | + "markdown, i.e. *highlighted section*." 2198 | ) 2199 | 2200 | return self._description_pattern.format(num_highlights=self._num_highlights) 2201 | 2202 | def get_instruction_args(self): 2203 | """Returns the keyward args of `build_description`.""" 2204 | return {"num_highlights": self._num_highlights} 2205 | 2206 | def get_instruction_args_keys(self): 2207 | """Returns the args keys of `build_description`.""" 2208 | return ["num_highlights"] 2209 | 2210 | def check_following(self, value): 2211 | """Checks if the number of highlighted sections meets the requirement. 2212 | 2213 | Args: 2214 | value: a string repesenting the response. The response is expected to 2215 | contain highlighted sections in the format of *highlighted*. 2216 | 2217 | Returns: 2218 | True if the actual number of highlighted sections in the format of 2219 | *highlighed sections* meets the minimum requirement; otherwise False. 2220 | """ 2221 | num_highlights = 0 2222 | try: 2223 | highlights = re.findall(r"\*[^\n\*]*\*", value) 2224 | except: 2225 | print('Failed for highlights, got value: ', value) 2226 | return False 2227 | double_highlights = re.findall(r"\*\*[^\n\*]*\*\*", value) 2228 | for highlight in highlights: 2229 | if highlight.strip("*").strip(): 2230 | num_highlights += 1 2231 | for highlight in double_highlights: 2232 | if highlight.removeprefix("**").removesuffix("**").strip(): 2233 | num_highlights += 1 2234 | 2235 | return num_highlights >= self._num_highlights 2236 | 2237 | 2238 | class SectionChecker(Instruction): 2239 | """Checks the sections.""" 2240 | 2241 | def build_description(self, *, section_spliter=None, num_sections=None): 2242 | """Build the instruction description. 2243 | 2244 | Args: 2245 | section_spliter: A string represents the section spliter keyword that 2246 | marks a new section, i.e., `Section` or `SECTION`. 2247 | num_sections: An integer specifying the number of sections. 2248 | 2249 | Returns: 2250 | A string representing the instruction description. 2251 | """ 2252 | self._section_spliter = ( 2253 | section_spliter.strip() 2254 | if isinstance(section_spliter, str) 2255 | else section_spliter 2256 | ) 2257 | if self._section_spliter is None: 2258 | self._section_spliter = random.choice(_SECTION_SPLITER) 2259 | 2260 | self._num_sections = num_sections 2261 | if self._num_sections is None or self._num_sections < 0: 2262 | self._num_sections = random.randint(1, _NUM_SECTIONS) 2263 | 2264 | self._description_pattern = ( 2265 | "Your response must have {num_sections} sections. Mark the beginning " 2266 | + "of each section with {section_spliter} X, such as:\n" 2267 | + "{section_spliter} 1\n" 2268 | + "[content of section 1]\n" 2269 | + "{section_spliter} 2\n" 2270 | + "[content of section 2]" 2271 | ) 2272 | 2273 | return self._description_pattern.format( 2274 | num_sections=self._num_sections, section_spliter=self._section_spliter 2275 | ) 2276 | 2277 | def get_instruction_args(self): 2278 | """Returns the keyward args of `build_description`.""" 2279 | return { 2280 | "section_spliter": self._section_spliter, 2281 | "num_sections": self._num_sections, 2282 | } 2283 | 2284 | def get_instruction_args_keys(self): 2285 | """Returns the args keys of `build_description`.""" 2286 | return ["section_spliter", "num_sections"] 2287 | 2288 | def check_following(self, value): 2289 | """Checks the response contains multiple sections. 2290 | 2291 | Args: 2292 | value: A string representing the response. The response is expected 2293 | to contain multiple sections (number of sections is greater than 1). 2294 | A new section starts with `Section 1`, where the number denotes the 2295 | section index. 2296 | 2297 | Returns: 2298 | True if the number of sections in the response is greater than or equal to 2299 | the minimum number of sections; otherwise, False. 2300 | """ 2301 | section_splitter_patten = r"\s?" + self._section_spliter + r"\s?\d+\s?" 2302 | sections = re.split(section_splitter_patten, value) 2303 | num_sections = len(sections) - 1 2304 | return num_sections >= self._num_sections 2305 | 2306 | 2307 | class ParagraphChecker(Instruction): 2308 | """Checks the paragraphs.""" 2309 | 2310 | def build_description(self, *, num_paragraphs=None): 2311 | """Build the instruction description. 2312 | 2313 | Args: 2314 | num_paragraphs: An integer specifying the number of paragraphs. 2315 | 2316 | Returns: 2317 | A string representing the instruction description. 2318 | """ 2319 | self._num_paragraphs = num_paragraphs 2320 | if self._num_paragraphs is None or self._num_paragraphs < 0: 2321 | self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) 2322 | 2323 | self._description_pattern = ( 2324 | "There should be {num_paragraphs} paragraphs. " 2325 | + "Paragraphs are separated with the markdown divider: ***" 2326 | ) 2327 | 2328 | return self._description_pattern.format(num_paragraphs=self._num_paragraphs) 2329 | 2330 | def get_instruction_args(self): 2331 | """Returns the keyward args of `build_description`.""" 2332 | return {"num_paragraphs": self._num_paragraphs} 2333 | 2334 | def get_instruction_args_keys(self): 2335 | """Returns the args keys of `build_description`.""" 2336 | return ["num_paragraphs"] 2337 | 2338 | def check_following(self, value): 2339 | """Checks the response contains required number of paragraphs. 2340 | 2341 | Args: 2342 | value: A string representing the response. The response may contain 2343 | paragraphs that are separated by the markdown divider: `***`. 2344 | 2345 | Returns: 2346 | True if the actual number of paragraphs is the same as required; 2347 | otherwise, False. 2348 | """ 2349 | paragraphs = re.split(r"\s?\*\*\*\s?", value) 2350 | num_paragraphs = len(paragraphs) 2351 | 2352 | for index, paragraph in enumerate(paragraphs): 2353 | if not paragraph.strip(): 2354 | if index == 0 or index == len(paragraphs) - 1: 2355 | num_paragraphs -= 1 2356 | else: 2357 | return False 2358 | 2359 | return num_paragraphs == self._num_paragraphs 2360 | 2361 | 2362 | class PostscriptChecker(Instruction): 2363 | """Checks the postscript.""" 2364 | 2365 | def build_description(self, *, postscript_marker=None): 2366 | """Build the instruction description. 2367 | 2368 | Args: 2369 | postscript_marker: A string containing the keyword that marks the start 2370 | of the postscript section. 2371 | 2372 | Returns: 2373 | A string representing the instruction description. 2374 | """ 2375 | self._postscript_marker = ( 2376 | postscript_marker.strip() 2377 | if isinstance(postscript_marker, str) 2378 | else postscript_marker 2379 | ) 2380 | if self._postscript_marker is None: 2381 | self._postscript_marker = random.choice(_POSTSCRIPT_MARKER) 2382 | 2383 | self._description_pattern = ( 2384 | "At the end of your response, please explicitly add a postscript " 2385 | + "starting with {postscript}" 2386 | ) 2387 | 2388 | return self._description_pattern.format(postscript=self._postscript_marker) 2389 | 2390 | def get_instruction_args(self): 2391 | """Returns the keyward args of `build_description`.""" 2392 | return {"postscript_marker": self._postscript_marker} 2393 | 2394 | def get_instruction_args_keys(self): 2395 | """Returns the args keys of `build_description`.""" 2396 | return ["postscript_marker"] 2397 | 2398 | def check_following(self, value): 2399 | """Checks if the response follows the postscript format. 2400 | 2401 | Args: 2402 | value: a string representing the response. The response is expected to 2403 | contain a postscript section. 2404 | 2405 | Returns: 2406 | True if the response contains a postscript section starting with 2407 | the keyword containing in the `instruction_args`; otherwise False. 2408 | """ 2409 | value = value.lower() 2410 | if self._postscript_marker == "P.P.S": 2411 | postscript_pattern = r"\s*p\.\s?p\.\s?s.*$" 2412 | elif self._postscript_marker == "P.S.": 2413 | postscript_pattern = r"\s*p\.\s?s\..*$" 2414 | else: 2415 | postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$" 2416 | postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE) 2417 | return True if postscript else False 2418 | 2419 | 2420 | class RephraseChecker(Instruction): 2421 | """Checks the repharse.""" 2422 | 2423 | def build_description(self, *, original_message): 2424 | """Build the instruction description. 2425 | 2426 | Args: 2427 | original_message: A string representing the original message. The 2428 | rephrased response should only change its words/sentences in between 2429 | its two asterisks, for example, *change me*. Both original and rephrased 2430 | messages should contain the changes in the form of *change me*. 2431 | 2432 | Returns: 2433 | A string representing the instruction description. 2434 | """ 2435 | if not self.is_change(original_message): 2436 | raise ValueError( 2437 | f"Message {original_message} does not contain changes " 2438 | "in the form of *change me*." 2439 | ) 2440 | 2441 | self._reference_without_change = original_message 2442 | self._description = ( 2443 | "Rephrasing: Your rephrased response should only" 2444 | + "change the words/sentences in between two asterisks" 2445 | + "such as *change me*." 2446 | ) 2447 | return self._description 2448 | 2449 | def get_instruction_args(self): 2450 | """Returns the keyward args of `build_description`.""" 2451 | return {"original_message": self._reference_without_change} 2452 | 2453 | def get_instruction_args_keys(self): 2454 | """Returns the args keys of `build_description`.""" 2455 | return ["original_message"] 2456 | 2457 | def check_following(self, value): 2458 | r"""Checks if the rephrasing follows the instruction. 2459 | 2460 | Args: 2461 | value: A string representing the response, which is expected to rephras 2462 | the string of `instruction_args`. 2463 | 2464 | Returns: 2465 | True if `value` and `instruction_args` only differ by the words/sentences 2466 | in between two asterisks such as *change me*; otherwise, False. 2467 | """ 2468 | 2469 | if not self.is_change(value): 2470 | raise ValueError( 2471 | f"value {value} does not contain " "changes in the form of *change me*." 2472 | ) 2473 | 2474 | response_without_changes = self.strip_changes(value) 2475 | reference_without_changes = self.strip_changes(self._reference_without_change) 2476 | 2477 | return response_without_changes == reference_without_changes 2478 | 2479 | def is_change(self, response): 2480 | """Check if there is change in the response in the form of *change me*.""" 2481 | return re.search(r"\*.*\*", response) 2482 | 2483 | def strip_changes(self, response): 2484 | """Strips off the changes.""" 2485 | return re.sub(r"\*.*\*", "", response) 2486 | 2487 | 2488 | class KeywordChecker(Instruction): 2489 | """Check the exisitence of certain keywords.""" 2490 | 2491 | def build_description(self, *, keywords=None): 2492 | """Build the instruction description. 2493 | 2494 | Args: 2495 | keywords: A sequence of strings representing the keywords that are 2496 | expected in the response. 2497 | 2498 | Returns: 2499 | A string representing the instruction description. 2500 | """ 2501 | 2502 | if not keywords: 2503 | self._keywords = generate_keywords(num_keywords=_NUM_KEYWORDS) 2504 | else: 2505 | self._keywords = keywords 2506 | self._keywords = sorted(self._keywords) 2507 | 2508 | self._description_pattern = "Include keywords {keywords} in the response." 2509 | 2510 | return self._description_pattern.format(keywords=self._keywords) 2511 | 2512 | def get_instruction_args(self): 2513 | """Returns the keyward args of `build_description`.""" 2514 | return {"keywords": self._keywords} 2515 | 2516 | def get_instruction_args_keys(self): 2517 | """Returns the args keys of `build_description`.""" 2518 | return ["keywords"] 2519 | 2520 | def check_following(self, value): 2521 | """Check if the response contain the expected keywords.""" 2522 | for keyword in self._keywords: 2523 | if not re.search(keyword, value, flags=re.IGNORECASE): 2524 | return False 2525 | return True 2526 | 2527 | 2528 | class KeywordFrequencyChecker(Instruction): 2529 | """Check the keyword frequency.""" 2530 | 2531 | def build_description(self, *, keyword=None, frequency=None, relation=None): 2532 | """Build the instruction description. 2533 | 2534 | Args: 2535 | keyword: A string representing a keyword that is expected in the response. 2536 | frequency: An integer specifying the number of times `keyword` is expected 2537 | to appear in the response. 2538 | relation: A string in (`less than`, `at least`), defining the relational 2539 | operator for comparison. 2540 | Two relational comparisons are supported for now: 2541 | if 'less than', the actual number of occurrences < frequency; 2542 | if 'at least', the actual number of occurrences >= frequency. 2543 | 2544 | Returns: 2545 | A string representing the instruction description. 2546 | """ 2547 | if not keyword: 2548 | self._keyword = generate_keywords(num_keywords=1)[0] 2549 | else: 2550 | self._keyword = keyword.strip() 2551 | 2552 | self._frequency = frequency 2553 | if self._frequency is None or self._frequency < 0: 2554 | self._frequency = random.randint(1, _KEYWORD_FREQUENCY) 2555 | 2556 | if relation is None: 2557 | self._comparison_relation = random.choice(_COMPARISON_RELATION) 2558 | elif relation not in _COMPARISON_RELATION: 2559 | raise ValueError( 2560 | "The supported relation for comparison must be in " 2561 | f"{_COMPARISON_RELATION}, but {relation} is given." 2562 | ) 2563 | else: 2564 | self._comparison_relation = relation 2565 | 2566 | self._description_pattern = ( 2567 | "In your response, the word {keyword} should appear {relation} " 2568 | + "{frequency} times." 2569 | ) 2570 | 2571 | return self._description_pattern.format( 2572 | keyword=self._keyword, 2573 | relation=self._comparison_relation, 2574 | frequency=self._frequency, 2575 | ) 2576 | 2577 | def get_instruction_args(self): 2578 | """Returns the keyward args of `build_description`.""" 2579 | return { 2580 | "keyword": self._keyword, 2581 | "frequency": self._frequency, 2582 | "relation": self._comparison_relation, 2583 | } 2584 | 2585 | def get_instruction_args_keys(self): 2586 | """Returns the args keys of `build_description`.""" 2587 | return ["keyword", "frequency", "relation"] 2588 | 2589 | def check_following(self, value): 2590 | """Checks if the response contain the keyword with required frequency.""" 2591 | try: 2592 | actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) 2593 | except: 2594 | print('Failed to parse for', value) 2595 | return False 2596 | 2597 | if self._comparison_relation == _COMPARISON_RELATION[0]: 2598 | return actual_occurrences < self._frequency 2599 | elif self._comparison_relation == _COMPARISON_RELATION[1]: 2600 | return actual_occurrences >= self._frequency 2601 | 2602 | 2603 | class NumberOfWords(Instruction): 2604 | """Checks the number of words.""" 2605 | 2606 | def build_description(self, *, num_words=None, relation=None): 2607 | """Build the instruction description. 2608 | 2609 | Args: 2610 | num_words: An integer specifying the number of words contained in the 2611 | response. 2612 | relation: A string in (`less than`, `at least`), defining the relational 2613 | operator for comparison. 2614 | Two relational comparisons are supported for now: 2615 | if 'less than', the actual number of words < num_words; 2616 | if 'at least', the actual number of words >= num_words. 2617 | 2618 | Returns: 2619 | A string representing the instruction description. 2620 | """ 2621 | 2622 | self._num_words = num_words 2623 | if self._num_words is None or self._num_words < 0: 2624 | self._num_words = random.randint( 2625 | _NUM_WORDS_LOWER_LIMIT, _NUM_WORDS_UPPER_LIMIT 2626 | ) 2627 | 2628 | if relation is None: 2629 | self._comparison_relation = random.choice(_COMPARISON_RELATION) 2630 | elif relation not in _COMPARISON_RELATION: 2631 | raise ValueError( 2632 | "The supported relation for comparison must be in " 2633 | f"{_COMPARISON_RELATION}, but {relation} is given." 2634 | ) 2635 | else: 2636 | self._comparison_relation = relation 2637 | 2638 | self._description_pattern = "Answer with {relation} {num_words} words." 2639 | 2640 | return self._description_pattern.format( 2641 | relation=self._comparison_relation, num_words=self._num_words 2642 | ) 2643 | 2644 | def get_instruction_args(self): 2645 | """Returns the keyward args of `build_description`.""" 2646 | return {"num_words": self._num_words, "relation": self._comparison_relation} 2647 | 2648 | def get_instruction_args_keys(self): 2649 | """Returns the args keys of `build_description`.""" 2650 | return ["num_words", "relation"] 2651 | 2652 | def check_following(self, value): 2653 | """Checks if the response contains the expected number of words.""" 2654 | try: 2655 | lang = langdetect.detect(value) 2656 | except: 2657 | print("Failed to detect language, got value:", value) 2658 | lang = 'en' 2659 | if lang == "th": 2660 | # print(f"shervin4. lang is {lang}") 2661 | # print(value) 2662 | num_words = len(word_tokenize_thai(value)) 2663 | # print(f"num words: {num_words}") 2664 | elif lang in ["zh", "zh-cn", "zh-tw", "ja"]: 2665 | # print(f"shervin5. lang is {lang}") 2666 | # print(value) 2667 | num_words = count_words_chinese_japanese(value) 2668 | # print(f"num words: {num_words}") 2669 | else: 2670 | # print(f"shervin6. lang is {lang}") 2671 | # print(value) 2672 | num_words = count_words(value) 2673 | # print(f"num words: {num_words}") 2674 | 2675 | if self._comparison_relation == _COMPARISON_RELATION[0]: 2676 | return num_words < self._num_words 2677 | elif self._comparison_relation == _COMPARISON_RELATION[1]: 2678 | return num_words >= self._num_words 2679 | 2680 | 2681 | class JsonFormat(Instruction): 2682 | """Check the Json format.""" 2683 | 2684 | def build_description(self): 2685 | self._description_pattern = ( 2686 | "Entire output should be wrapped in JSON format. You can use markdown" 2687 | " ticks such as ```." 2688 | ) 2689 | return self._description_pattern 2690 | 2691 | def get_instruction_args(self): 2692 | """Returns the keyward args of `build_description`.""" 2693 | return None 2694 | 2695 | def get_instruction_args_keys(self): 2696 | """Returns the args keys of `build_description`.""" 2697 | return [] 2698 | 2699 | def check_following(self, value): 2700 | value = ( 2701 | value.strip() 2702 | .removeprefix("```json") 2703 | .removeprefix("```Json") 2704 | .removeprefix("```JSON") 2705 | .removeprefix("```") 2706 | .removesuffix("```") 2707 | .strip() 2708 | ) 2709 | try: 2710 | json.loads(value) 2711 | except ValueError as _: 2712 | return False 2713 | return True 2714 | 2715 | 2716 | class ParagraphFirstWordCheck(Instruction): 2717 | """Check the paragraph and the first word of the nth paragraph.""" 2718 | 2719 | def build_description( 2720 | self, num_paragraphs=None, nth_paragraph=None, first_word=None 2721 | ): 2722 | r"""Build the instruction description. 2723 | 2724 | Args: 2725 | num_paragraphs: An integer indicating the number of paragraphs expected 2726 | in the response. A paragraph is a subset of the string that is 2727 | expected to be separated by '\n\n'. 2728 | nth_paragraph: An integer indicating the paragraph number that we look at. 2729 | Note that n starts from 1. 2730 | first_word: A string that represent the first word of the bth paragraph. 2731 | 2732 | Returns: 2733 | A string representing the instruction description. 2734 | """ 2735 | self._num_paragraphs = num_paragraphs 2736 | if self._num_paragraphs is None or self._num_paragraphs < 0: 2737 | self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) 2738 | 2739 | self._nth_paragraph = nth_paragraph 2740 | if ( 2741 | self._nth_paragraph is None 2742 | or self._nth_paragraph <= 0 2743 | or self._nth_paragraph > self._num_paragraphs 2744 | ): 2745 | self._nth_paragraph = random.randint(1, self._num_paragraphs + 1) 2746 | 2747 | self._first_word = first_word 2748 | if self._first_word is None: 2749 | self._first_word = generate_keywords(num_keywords=1)[0] 2750 | self._first_word = self._first_word.lower() 2751 | 2752 | self._description_pattern = ( 2753 | "There should be {num_paragraphs} paragraphs. " 2754 | + "Paragraphs and only paragraphs are separated with each other by two " 2755 | + "new lines as if it was '\\n\\n' in python. " 2756 | + "Paragraph {nth_paragraph} must start with word {first_word}." 2757 | ) 2758 | 2759 | return self._description_pattern.format( 2760 | num_paragraphs=self._num_paragraphs, 2761 | nth_paragraph=self._nth_paragraph, 2762 | first_word=self._first_word, 2763 | ) 2764 | 2765 | def get_instruction_args(self): 2766 | """Returns the keyward args of `build_description`.""" 2767 | return { 2768 | "num_paragraphs": self._num_paragraphs, 2769 | "nth_paragraph": self._nth_paragraph, 2770 | "first_word": self._first_word, 2771 | } 2772 | 2773 | def get_instruction_args_keys(self): 2774 | """Returns the args keys of `build_description`.""" 2775 | return ["num_paragraphs", "nth_paragraph", "first_word"] 2776 | 2777 | def check_following(self, value): 2778 | """Checks for required number of paragraphs and correct first word. 2779 | 2780 | Args: 2781 | value: a string representing the response. The response may contain 2782 | paragraphs that are separated by two new lines and the first word of 2783 | the nth paragraph will have to match a specified word. 2784 | 2785 | Returns: 2786 | True if the number of paragraphs is the same as required and the first 2787 | word of the specified paragraph is the same as required. Otherwise, false. 2788 | """ 2789 | 2790 | paragraphs = re.split(r"\n\n", value) 2791 | num_paragraphs = len(paragraphs) 2792 | 2793 | for paragraph in paragraphs: 2794 | if not paragraph.strip(): 2795 | num_paragraphs -= 1 2796 | 2797 | # check that index doesn't go out of bounds 2798 | if self._nth_paragraph <= num_paragraphs: 2799 | paragraph = paragraphs[self._nth_paragraph - 1].strip() 2800 | if not paragraph: 2801 | return False 2802 | else: 2803 | return False 2804 | 2805 | first_word = "" 2806 | punctuation = {".", ",", "?", "!", "'", '"'} 2807 | 2808 | # get first word and remove punctuation 2809 | word = paragraph.split()[0].strip() 2810 | # TODO(jeffrey): make more complex? 2811 | word = word.lstrip("'") 2812 | word = word.lstrip('"') 2813 | 2814 | for letter in word: 2815 | if letter in punctuation: 2816 | break 2817 | first_word += letter.lower() 2818 | 2819 | return num_paragraphs == self._num_paragraphs and first_word == self._first_word 2820 | 2821 | 2822 | # TODO(jeffrey) add relation - at least/at most? 2823 | class KeySentenceChecker(Instruction): 2824 | """Check the existence of certain key sentences.""" 2825 | 2826 | def build_description(self, key_sentences=None, num_sentences=None): 2827 | """Build the instruction description. 2828 | 2829 | Args: 2830 | key_sentences: A sequences of strings representing the key sentences that 2831 | are expected in the response. 2832 | num_sentences: The number of key sentences that are expected to be seen in 2833 | the response. 2834 | 2835 | Returns: 2836 | A string representing the instruction description. 2837 | """ 2838 | 2839 | if not key_sentences: 2840 | # TODO(jeffrey) make a generate sentences function? wonderwords package 2841 | self._key_sentences = set(["For now, this is fine."]) 2842 | else: 2843 | self._key_sentences = key_sentences 2844 | 2845 | if not num_sentences: 2846 | self._num_sentences = random.randint(1, len(self._key_sentences)) 2847 | else: 2848 | self._num_sentences = num_sentences 2849 | 2850 | self._description_pattern = ( 2851 | "Include {num_sentences} of the following sentences {key_sentences}" 2852 | ) 2853 | 2854 | return self._description_pattern.format( 2855 | num_sentences=self._num_sentences, key_sentences=self._key_sentences 2856 | ) 2857 | 2858 | def get_instruction_args(self): 2859 | """Returns the keyward args of `build_description`.""" 2860 | return { 2861 | "num_sentences": self._num_sentences, 2862 | "key_sentences": list(self._key_sentences), 2863 | } 2864 | 2865 | def get_instruction_args_keys(self): 2866 | """Returns the args keys of `build_description`.""" 2867 | return ["num_sentences", "key_sentences"] 2868 | 2869 | def check_following(self, value): 2870 | """Checks if the response contains the expected key sentences.""" 2871 | count = 0 2872 | sentences = split_into_sentences(value) 2873 | for sentence in self._key_sentences: 2874 | if sentence in sentences: 2875 | count += 1 2876 | 2877 | return count == self._num_sentences 2878 | 2879 | 2880 | class ForbiddenWords(Instruction): 2881 | """Checks that specified words are not used in response.""" 2882 | 2883 | def build_description(self, forbidden_words=None): 2884 | """Build the instruction description. 2885 | 2886 | Args: 2887 | forbidden_words: A sequences of strings respresenting words that are not 2888 | allowed in the response. 2889 | 2890 | Returns: 2891 | A string representing the instruction description. 2892 | """ 2893 | 2894 | if not forbidden_words: 2895 | self._forbidden_words = generate_keywords(num_keywords=_NUM_KEYWORDS) 2896 | else: 2897 | self._forbidden_words = list(set(forbidden_words)) 2898 | self._forbidden_words = sorted(self._forbidden_words) 2899 | self._description_pattern = ( 2900 | "Do not include keywords {forbidden_words} in the response." 2901 | ) 2902 | 2903 | return self._description_pattern.format(forbidden_words=self._forbidden_words) 2904 | 2905 | def get_instruction_args(self): 2906 | """Returns the keyward args of `build_description`.""" 2907 | return {"forbidden_words": self._forbidden_words} 2908 | 2909 | def get_instruction_args_keys(self): 2910 | """Returns the args keys of `build_description`.""" 2911 | return ["forbidden_words"] 2912 | 2913 | def check_following(self, value): 2914 | """Check if the response does not contain the expected keywords.""" 2915 | for word in self._forbidden_words: 2916 | if re.search(r"\b" + word + r"\b", value, flags=re.IGNORECASE): 2917 | return False 2918 | return True 2919 | 2920 | 2921 | class RephraseParagraph(Instruction): 2922 | """Checks that the paragraph is rephrased.""" 2923 | 2924 | def build_description(self, *, original_paragraph, low, high): 2925 | """Builds the instruction description. 2926 | 2927 | Args: 2928 | original_paragraph: A string presenting the original paragraph. The 2929 | rephrases response should have betweeb low-high words in common. 2930 | low: An integer presenting the lower bound of similar words. 2931 | high: An integer representing the upper bound of similar words. 2932 | 2933 | Returns: 2934 | A string representing the instruction description. 2935 | """ 2936 | # TODO(jeffrey) make more encompassing 2937 | self._original_paragraph = original_paragraph 2938 | self._low = low 2939 | self._high = high 2940 | 2941 | self._description = ( 2942 | "Rephrase the following paragraph: " 2943 | + "{original_paragraph}\nYour response should have " 2944 | + "between {low} and {high} of the same words. " 2945 | + "Words are the same if and only if all of the " 2946 | + "letters, ignoring cases, are the same. For " 2947 | + "example, 'run' is the same as 'Run' but different " 2948 | + "to 'ran'." 2949 | ) 2950 | 2951 | return self._description.format( 2952 | original_paragraph=original_paragraph, low=self._low, high=self._high 2953 | ) 2954 | 2955 | def get_instruction_args(self): 2956 | """Returns the keyward args of `build_description`.""" 2957 | return { 2958 | "original_paragraph": self._original_paragraph, 2959 | "low": self._low, 2960 | "high": self._high, 2961 | } 2962 | 2963 | def get_instruction_args_keys(self): 2964 | """Returns the args keys of `build_description`.""" 2965 | return ["original_paragraph", "low", "high"] 2966 | 2967 | def check_following(self, value): 2968 | val_words = re.findall(r"\w+", value.lower()) 2969 | original_words = re.findall(r"\w+", self._original_paragraph.lower()) 2970 | similar_words = 0 2971 | 2972 | dict_val = collections.Counter(val_words) 2973 | dict_original = collections.Counter(original_words) 2974 | 2975 | for word in dict_original: 2976 | similar_words += min(dict_original[word], dict_val[word]) 2977 | 2978 | return similar_words >= self._low and similar_words <= self._high 2979 | 2980 | 2981 | class TwoResponsesChecker(Instruction): 2982 | """Check that two responses were given.""" 2983 | 2984 | def build_description(self): 2985 | """Build the instruction description.""" 2986 | self._description_pattern = ( 2987 | "Give two different responses. Responses and only responses should" 2988 | " be separated by 6 asterisk symbols: ******." 2989 | ) 2990 | return self._description_pattern 2991 | 2992 | def get_instruction_args(self): 2993 | """Returns the keyward args of `build_description`.""" 2994 | return None 2995 | 2996 | def get_instruction_args_keys(self): 2997 | """Returns the args keys of `build_description`.""" 2998 | return [] 2999 | 3000 | def check_following(self, value): 3001 | """Checks if the response has two different answers. 3002 | 3003 | Args: 3004 | value: A string representing the response. 3005 | 3006 | Returns: 3007 | True if two responses are detected and false otherwise. 3008 | """ 3009 | valid_responses = list() 3010 | responses = value.split("******") 3011 | for index, response in enumerate(responses): 3012 | if not response.strip(): 3013 | if index != 0 and index != len(responses) - 1: 3014 | return False 3015 | else: 3016 | valid_responses.append(response) 3017 | return ( 3018 | len(valid_responses) == 2 3019 | and valid_responses[0].strip() != valid_responses[1].strip() 3020 | ) 3021 | 3022 | 3023 | class RepeatPromptThenAnswer(Instruction): 3024 | """Checks that Prompt is first repeated then answered.""" 3025 | 3026 | def build_description(self, *, prompt_to_repeat=None): 3027 | """Build the instruction description. 3028 | 3029 | Args: 3030 | prompt_to_repeat: The prompt that is meant to be repeated. 3031 | 3032 | Returns: 3033 | A string representing the instruction description. 3034 | """ 3035 | if not prompt_to_repeat: 3036 | raise ValueError("prompt_to_repeat must be set.") 3037 | else: 3038 | self._prompt_to_repeat = prompt_to_repeat 3039 | self._description_pattern = ( 3040 | "First repeat the request word for word without change," 3041 | " then give your answer (1. do not say any words or characters" 3042 | " before repeating the request; 2. the request you need to repeat" 3043 | " does not include this sentence)" 3044 | ) 3045 | return self._description_pattern 3046 | 3047 | def get_instruction_args(self): 3048 | return {"prompt_to_repeat": self._prompt_to_repeat} 3049 | 3050 | def get_instruction_args_keys(self): 3051 | """Returns the args keys of `build_description`.""" 3052 | return ["prompt_to_repeat"] 3053 | 3054 | def check_following(self, value): 3055 | if value.strip().lower().startswith(self._prompt_to_repeat.strip().lower()): 3056 | return True 3057 | return False 3058 | 3059 | 3060 | class EndChecker(Instruction): 3061 | """Checks that the prompt ends with a given phrase.""" 3062 | 3063 | def build_description(self, *, end_phrase=None): 3064 | """Build the instruction description. 3065 | 3066 | Args: 3067 | end_phrase: A string representing the phrase the response should end with. 3068 | 3069 | Returns: 3070 | A string representing the instruction description. 3071 | """ 3072 | self._end_phrase = ( 3073 | end_phrase.strip() if isinstance(end_phrase, str) else end_phrase 3074 | ) 3075 | if self._end_phrase is None: 3076 | self._end_phrase = random.choice(_ENDING_OPTIONS) 3077 | self._description_pattern = ( 3078 | "Finish your response with this exact phrase {ender}. " 3079 | "No other words should follow this phrase." 3080 | ) 3081 | return self._description_pattern.format(ender=self._end_phrase) 3082 | 3083 | def get_instruction_args(self): 3084 | return {"end_phrase": self._end_phrase} 3085 | 3086 | def get_instruction_args_keys(self): 3087 | """Returns the args keys of `build_description`.""" 3088 | return ["end_phrase"] 3089 | 3090 | def check_following(self, value): 3091 | """Checks if the response ends with the expected phrase.""" 3092 | value = value.strip().strip('"').lower() 3093 | self._end_phrase = self._end_phrase.strip().lower() 3094 | return value.endswith(self._end_phrase) 3095 | 3096 | 3097 | class TitleChecker(Instruction): 3098 | """Checks the response for a title.""" 3099 | 3100 | def build_description(self): 3101 | """Build the instruction description.""" 3102 | self._description_pattern = ( 3103 | "Your answer must contain a title, wrapped in double angular brackets," 3104 | " such as <>." 3105 | ) 3106 | return self._description_pattern 3107 | 3108 | def get_instruction_args(self): 3109 | return None 3110 | 3111 | def get_instruction_args_keys(self): 3112 | """Returns the args keys of `build_description`.""" 3113 | return [] 3114 | 3115 | def check_following(self, value): 3116 | """Checks if the response contains a title.""" 3117 | pattern = r"<<[^\n]+>>" 3118 | re_pattern = re.compile(pattern) 3119 | titles = re.findall(re_pattern, value) 3120 | 3121 | for title in titles: 3122 | if title.lstrip("<").rstrip(">").strip(): 3123 | return True 3124 | return False 3125 | 3126 | 3127 | class LetterFrequencyChecker(Instruction): 3128 | """Checks letter frequency.""" 3129 | 3130 | def build_description(self, *, letter=None, let_frequency=None, let_relation=None): 3131 | """Build the instruction description. 3132 | 3133 | Args: 3134 | letter: A string representing a letter that is expected in the response. 3135 | let_frequency: An integer specifying the number of times `keyword` is 3136 | expected to appear in the response. 3137 | let_relation: A string in (`less than`, `at least`), defining the 3138 | relational operator for comparison. Two relational comparisons are 3139 | supported for now; if 'less than', the actual number of 3140 | occurrences < frequency; if 'at least', the actual number of 3141 | occurrences >= frequency. 3142 | 3143 | Returns: 3144 | A string representing the instruction description. 3145 | """ 3146 | if ( 3147 | not letter 3148 | or len(letter) > 1 3149 | or ord(letter.lower()) < 97 3150 | or ord(letter.lower()) > 122 3151 | ): 3152 | self._letter = random.choice(list(string.ascii_letters)) 3153 | else: 3154 | self._letter = letter.strip() 3155 | self._letter = self._letter.lower() 3156 | 3157 | self._frequency = let_frequency 3158 | if self._frequency is None or self._frequency < 0: 3159 | self._frequency = random.randint(1, _LETTER_FREQUENCY) 3160 | 3161 | if let_relation is None: 3162 | self._comparison_relation = random.choice(_COMPARISON_RELATION) 3163 | elif let_relation not in _COMPARISON_RELATION: 3164 | raise ValueError( 3165 | "The supported relation for comparison must be in " 3166 | f"{_COMPARISON_RELATION}, but {let_relation} is given." 3167 | ) 3168 | else: 3169 | self._comparison_relation = let_relation 3170 | 3171 | self._description_pattern = ( 3172 | "In your response, the letter {letter} should appear {let_relation}" 3173 | " {let_frequency} times." 3174 | ) 3175 | 3176 | return self._description_pattern.format( 3177 | letter=self._letter, 3178 | let_frequency=self._frequency, 3179 | let_relation=self._comparison_relation, 3180 | ) 3181 | 3182 | def get_instruction_args(self): 3183 | """Returns the keyword args of build description.""" 3184 | return { 3185 | "letter": self._letter, 3186 | "let_frequency": self._frequency, 3187 | "let_relation": self._comparison_relation, 3188 | } 3189 | 3190 | def get_instruction_args_keys(self): 3191 | """Returns the args keys of `build_description`.""" 3192 | return ["letter", "let_frequency", "let_relation"] 3193 | 3194 | def check_following(self, value): 3195 | """Checks that the response contains the letter at the right frequency.""" 3196 | value = value.lower() 3197 | letters = collections.Counter(value) 3198 | 3199 | if self._comparison_relation == _COMPARISON_RELATION[0]: 3200 | return letters[self._letter] < self._frequency 3201 | else: 3202 | return letters[self._letter] >= self._frequency 3203 | 3204 | 3205 | class CapitalLettersEnglishChecker(Instruction): 3206 | """Checks that the response is in english and is in all capital letters.""" 3207 | 3208 | def build_description(self): 3209 | """Build the instruction description.""" 3210 | self._description_pattern = ( 3211 | "Your entire response should be in English, and in all capital letters." 3212 | ) 3213 | return self._description_pattern 3214 | 3215 | def get_instruction_args(self): 3216 | return None 3217 | 3218 | def get_instruction_args_keys(self): 3219 | """Returns the args keys of `build_description`.""" 3220 | return [] 3221 | 3222 | def check_following(self, value): 3223 | """Checks that the response is in English and in all capital letters.""" 3224 | assert isinstance(value, str) 3225 | 3226 | try: 3227 | return value.isupper() and langdetect.detect(value) == "en" 3228 | except langdetect.LangDetectException as e: 3229 | # Count as instruction is followed. 3230 | logger.info( 3231 | "Unable to detect language for text %s due to %s", value, e 3232 | ) # refex: disable=pytotw.037 3233 | return True 3234 | 3235 | 3236 | class LowercaseLettersEnglishChecker(Instruction): 3237 | """Checks that the response is in english and is in all lowercase letters.""" 3238 | 3239 | def build_description(self): 3240 | """Build the instruction description.""" 3241 | self._description_pattern = ( 3242 | "Your entire response should be in English, and in all lowercase" 3243 | " letters. No capital letters are allowed." 3244 | ) 3245 | return self._description_pattern 3246 | 3247 | def get_instruction_args(self): 3248 | return None 3249 | 3250 | def get_instruction_args_keys(self): 3251 | """Returns the args keys of `build_description`.""" 3252 | return [] 3253 | 3254 | def check_following(self, value): 3255 | """Checks that the response is in English and in all lowercase letters.""" 3256 | assert isinstance(value, str) 3257 | 3258 | try: 3259 | return value.islower() and langdetect.detect(value) == "en" 3260 | except langdetect.LangDetectException as e: 3261 | # Count as instruction is followed. 3262 | logger.info( 3263 | "Unable to detect language for text %s due to %s", value, e 3264 | ) # refex: disable=pytotw.037 3265 | return True 3266 | 3267 | 3268 | class CommaChecker(Instruction): 3269 | """Checks the response for no commas.""" 3270 | 3271 | def build_description(self): 3272 | """Build the instruction description.""" 3273 | self._description_pattern = ( 3274 | "In your entire response, refrain from the use of any commas." 3275 | ) 3276 | return self._description_pattern 3277 | 3278 | def get_instruction_args(self): 3279 | return None 3280 | 3281 | def get_instruction_args_keys(self): 3282 | """Returns the args keys of `build_description`.""" 3283 | return [] 3284 | 3285 | def check_following(self, value): 3286 | """Checks that the response does not contain commas.""" 3287 | return not re.search(r"\,", value) 3288 | 3289 | 3290 | class CapitalWordFrequencyChecker(Instruction): 3291 | """Checks frequency of words with all capital letters.""" 3292 | 3293 | def build_description( 3294 | self, 3295 | capital_frequency=None, 3296 | capital_relation=None, 3297 | ): 3298 | """Build the instruction description. 3299 | 3300 | Args: 3301 | capital_frequency: An integer that represents the number of words that 3302 | should be in all capital letters. 3303 | capital_relation: A string that is 'at least' or 'at most' that refers to 3304 | the frequency. 3305 | 3306 | Returns: 3307 | A string representing the instruction description. 3308 | """ 3309 | self._frequency = capital_frequency 3310 | if self._frequency is None: 3311 | self._frequency = random.randint(1, _ALL_CAPITAL_WORD_FREQUENCY) 3312 | 3313 | self._comparison_relation = capital_relation 3314 | if capital_relation is None: 3315 | self._comparison_relation = random.choice(_COMPARISON_RELATION) 3316 | elif capital_relation not in _COMPARISON_RELATION: 3317 | raise ValueError( 3318 | "The supported relation for comparison must be in " 3319 | f"{_COMPARISON_RELATION}, but {capital_relation} is given." 3320 | ) 3321 | 3322 | self._description_pattern = ( 3323 | "In your response, words with all capital letters should appear" 3324 | " {relation} {frequency} times." 3325 | ) 3326 | 3327 | return self._description_pattern.format( 3328 | frequency=self._frequency, relation=self._comparison_relation 3329 | ) 3330 | 3331 | def get_instruction_args(self): 3332 | """Returns the keyword args of build description.""" 3333 | return { 3334 | "capital_frequency": self._frequency, 3335 | "capital_relation": self._comparison_relation, 3336 | } 3337 | 3338 | def get_instruction_args_keys(self): 3339 | """Returns the args keys of `build_description`.""" 3340 | return ["capital_frequency", "capital_relation"] 3341 | 3342 | def check_following(self, value): 3343 | """Checks the frequency of words with all capital letters.""" 3344 | # Hyphenated words will count as one word 3345 | words = nltk.word_tokenize(value) 3346 | capital_words = [word for word in words if word.isupper()] 3347 | 3348 | capital_words = len(capital_words) 3349 | 3350 | if self._comparison_relation == _COMPARISON_RELATION[0]: 3351 | return capital_words < self._frequency 3352 | else: 3353 | return capital_words >= self._frequency 3354 | 3355 | 3356 | class QuotationChecker(Instruction): 3357 | """Checks response is wrapped with double quotation marks.""" 3358 | 3359 | def build_description(self): 3360 | """Build the instruction description.""" 3361 | self._description_pattern = ( 3362 | "Wrap your entire response with double quotation marks." 3363 | ) 3364 | return self._description_pattern 3365 | 3366 | def get_instruction_args(self): 3367 | """Returns the keyword args of build description.""" 3368 | return None 3369 | 3370 | def get_instruction_args_keys(self): 3371 | """Returns the args keys of `build_description`.""" 3372 | return [] 3373 | 3374 | def check_following(self, value): 3375 | """Checks if the response is wrapped with double quotation marks.""" 3376 | value = value.strip() 3377 | return len(value) > 1 and (value[0] == '"' and value[-1] == '"' or # e.g., English 3378 | value[0] == '“' and value[-1] == '”' or # e.g., Chinese 3379 | value[0] == '「' and value[-1] == '」') # e.g., Japanese 3380 | 3381 | 3382 | # Define instruction dicts 3383 | _KEYWORD = "keywords:" 3384 | _LANGUAGE = "language:" 3385 | _LENGTH = "length_constraints:" 3386 | _CONTENT = "detectable_content:" 3387 | _FORMAT = "detectable_format:" 3388 | _MULTITURN = "multi-turn:" 3389 | _COMBINATION = "combination:" 3390 | _STARTEND = "startend:" 3391 | _CHANGE_CASES = "change_case:" 3392 | _PUNCTUATION = "punctuation:" 3393 | 3394 | INSTRUCTION_DICT = { 3395 | _KEYWORD + "existence": KeywordChecker, 3396 | _KEYWORD + "frequency": KeywordFrequencyChecker, 3397 | # _KEYWORD + "key_sentences": KeySentenceChecker, 3398 | _KEYWORD + "forbidden_words": ForbiddenWords, 3399 | _KEYWORD + "letter_frequency": LetterFrequencyChecker, 3400 | _LANGUAGE + "response_language": ResponseLanguageChecker, 3401 | _LENGTH + "number_sentences": NumberOfSentences, 3402 | _LENGTH + "number_paragraphs": ParagraphChecker, 3403 | _LENGTH + "number_words": NumberOfWords, 3404 | _LENGTH + "nth_paragraph_first_word": ParagraphFirstWordCheck, 3405 | _CONTENT + "number_placeholders": PlaceholderChecker, 3406 | _CONTENT + "postscript": PostscriptChecker, 3407 | _FORMAT + "number_bullet_lists": BulletListChecker, 3408 | # _CONTENT + "rephrase_paragraph": RephraseParagraph, 3409 | _FORMAT + "constrained_response": ConstrainedResponseChecker, 3410 | _FORMAT + "number_highlighted_sections": (HighlightSectionChecker), 3411 | _FORMAT + "multiple_sections": SectionChecker, 3412 | # _FORMAT + "rephrase": RephraseChecker, 3413 | _FORMAT + "json_format": JsonFormat, 3414 | _FORMAT + "title": TitleChecker, 3415 | # _MULTITURN + "constrained_start": ConstrainedStartChecker, 3416 | _COMBINATION + "two_responses": TwoResponsesChecker, 3417 | _COMBINATION + "repeat_prompt": RepeatPromptThenAnswer, 3418 | _STARTEND + "end_checker": EndChecker, 3419 | _CHANGE_CASES + "capital_word_frequency": CapitalWordFrequencyChecker, 3420 | _CHANGE_CASES + "english_capital": CapitalLettersEnglishChecker, 3421 | _CHANGE_CASES + "english_lowercase": LowercaseLettersEnglishChecker, 3422 | _PUNCTUATION + "no_comma": CommaChecker, 3423 | _STARTEND + "quotation": QuotationChecker, 3424 | } 3425 | 3426 | INSTRUCTION_LIST = list(INSTRUCTION_DICT.keys()) + [ 3427 | _KEYWORD[:-1], 3428 | _LANGUAGE[:-1], 3429 | _LENGTH[:-1], 3430 | _CONTENT[:-1], 3431 | _FORMAT[:-1], 3432 | _MULTITURN[:-1], 3433 | _COMBINATION[:-1], 3434 | _STARTEND[:-1], 3435 | _CHANGE_CASES[:-1], 3436 | _PUNCTUATION[:-1], 3437 | ] 3438 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import json 15 | import logging 16 | from typing import Any, Dict, List, Optional, Tuple 17 | 18 | import ifeval 19 | 20 | import langdetect 21 | 22 | import nltk 23 | import numpy as np 24 | import pandas as pd 25 | from scipy.stats import bootstrap 26 | 27 | 28 | logger: logging.Logger = logging.getLogger(__name__) 29 | 30 | 31 | 32 | def gen_acc_strict(x: Dict[str, Any]) -> Dict[str, float]: 33 | # reference: fbcode/gen_ai/github/fair_evals/evals/tasks/finetune/ifeval.py 34 | response = str(x["response"]) 35 | instruction_list = x["instruction_id_list"] 36 | is_following_list = [] 37 | for index, instruction_id in enumerate(instruction_list): 38 | instruction_cls = ifeval.INSTRUCTION_DICT[instruction_id] 39 | instruction = instruction_cls(instruction_id) 40 | 41 | instruction.build_description(**x["kwargs"][index]) 42 | if response and instruction.check_following(response): 43 | is_following_list.append(True) 44 | else: 45 | is_following_list.append(False) 46 | 47 | return { 48 | "follow_instruction_list": is_following_list, 49 | "instruction_id_list": instruction_list, 50 | } 51 | 52 | 53 | def gen_acc_loose(x: Dict[str, Any]) -> Dict[str, float]: 54 | response = str(x["response"]) 55 | r = response.split("\n") 56 | response_remove_first = "\n".join(r[1:]).strip() 57 | response_remove_last = "\n".join(r[:-1]).strip() 58 | response_remove_both = "\n".join(r[1:-1]).strip() 59 | revised_response = response.replace("*", "") 60 | revised_response_remove_first = response_remove_first.replace("*", "") 61 | revised_response_remove_last = response_remove_last.replace("*", "") 62 | revised_response_remove_both = response_remove_both.replace("*", "") 63 | all_responses = [ 64 | response, 65 | revised_response, 66 | response_remove_first, 67 | response_remove_last, 68 | response_remove_both, 69 | revised_response_remove_first, 70 | revised_response_remove_last, 71 | revised_response_remove_both, 72 | ] 73 | instruction_list = x["instruction_id_list"] 74 | is_following_list = [] 75 | for index, instruction_id in enumerate(instruction_list): 76 | instruction_cls = ifeval.INSTRUCTION_DICT[instruction_id] 77 | instruction = instruction_cls(instruction_id) 78 | 79 | instruction.build_description(**x["kwargs"][index]) 80 | 81 | is_following = False 82 | for r in all_responses: # type: ignore 83 | if r.strip() and instruction.check_following(r): # type: ignore 84 | is_following = True 85 | break 86 | 87 | is_following_list.append(is_following) 88 | return { 89 | "follow_instruction_list": is_following_list, 90 | "instruction_id_list": instruction_list, 91 | } 92 | 93 | 94 | def parse_result(outputs: List[Dict[str, Any]]) -> Tuple[float, float]: 95 | 96 | prompt_total = 0 97 | prompt_correct = 0 98 | instruction_total = 0 99 | instruction_correct = 0 100 | 101 | for example in outputs: 102 | follow_instruction_list = example["follow_instruction_list"] 103 | instruction_id_list = example["instruction_id_list"] 104 | 105 | prompt_total += 1 106 | if all(follow_instruction_list): 107 | prompt_correct += 1 108 | 109 | instruction_total += len(instruction_id_list) 110 | instruction_correct += sum(follow_instruction_list) 111 | 112 | return prompt_correct / prompt_total, instruction_correct / instruction_total 113 | 114 | 115 | def parse_result_no_reduce(outputs: List[Dict[str, Any]]) -> Tuple[List, List]: 116 | 117 | prompt_res = [] 118 | inst_res = [] 119 | 120 | for example in outputs: 121 | follow_instruction_list = example["follow_instruction_list"] 122 | instruction_id_list = example["instruction_id_list"] 123 | if all(follow_instruction_list): 124 | prompt_res.append(1) 125 | else: 126 | prompt_res.append(0) 127 | inst_res.append(sum(follow_instruction_list)/len(instruction_id_list)) 128 | 129 | return prompt_res, inst_res 130 | 131 | 132 | class MultiTurnInstructionFollowingPromptSolution: 133 | PROMPT_COLUMN_NAME = "multi_turn_prompt_column" 134 | 135 | def reformat_prompt( 136 | df_row: pd.core.series.Series, 137 | prompt_columns: List[str], 138 | ) -> str: 139 | """ 140 | Reformat a DataFrame row into the format that can be used with the prompt generation template. 141 | """ 142 | # TODO: revisit the prompt reformatting logic for oss 143 | prompt_col = None 144 | response_col = None 145 | if len(prompt_columns) >= 2: 146 | prompt_col = prompt_columns[0] # turns 147 | response_col = prompt_columns[1] # responses 148 | 149 | if prompt_col != "turns" or response_col != "responses": 150 | raise ValueError( 151 | f"Expecting prompt_columns to be [turns, responses], got {prompt_columns}" 152 | ) 153 | 154 | if "turn_index" in df_row: 155 | turn_index = int(df_row["turn_index"]) 156 | else: 157 | turn_index = 1 158 | if turn_index > 1: 159 | try: 160 | old_prompt = json.loads(df_row[prompt_col]) 161 | old_response = [{"role": "assistant", "content": df_row[response_col]}] 162 | except Exception as e: 163 | raise ValueError( 164 | f"Failed to parse old prompt and response due to error {e}" 165 | ) 166 | new_turn_index = f"turn_{turn_index}_prompt" 167 | if new_turn_index in df_row.index: 168 | if df_row[new_turn_index] != "None" and df_row[new_turn_index]: 169 | new_prompt = [json.loads(df_row[new_turn_index])] 170 | output_prompt = old_prompt + old_response + new_prompt 171 | else: 172 | output_prompt = [{"role": "user", "content": "None"}] 173 | else: 174 | logger.warning(f"Column {new_turn_index} does not exist!") 175 | output_prompt = [{"role": "user", "content": "None"}] 176 | else: 177 | # original input soruce table 178 | output_prompt = [json.loads(df_row[f"turn_{turn_index}_prompt"])] 179 | return output_prompt 180 | 181 | def compute_ci_via_bootstrap( 182 | result_list, 183 | n_resamples=10000, 184 | method='percentile', 185 | confidence_level=0.95 186 | ) -> float: 187 | prompt_lst, inst_lst = parse_result_no_reduce(result_list) 188 | prompt_pct_low, prompt_pct_high = bootstrap( 189 | (np.array(prompt_lst),), 190 | np.mean, 191 | n_resamples=n_resamples, 192 | method=method, 193 | confidence_level=confidence_level, 194 | ).confidence_interval 195 | inst_pct_low, inst_pct_high = bootstrap( 196 | (np.array(inst_lst),), 197 | np.mean, 198 | n_resamples=n_resamples, 199 | method=method, 200 | confidence_level=confidence_level, 201 | ).confidence_interval 202 | return prompt_pct_low, prompt_pct_high, inst_pct_low, inst_pct_high 203 | 204 | def metrics_gen( 205 | output_df: pd.DataFrame, 206 | return_outputs: bool = False 207 | ) -> Dict[str, Any]: 208 | """ 209 | Generate metrics from the given table 210 | """ 211 | language_list = [ 212 | "all_languages", 213 | "German", 214 | "Italian", 215 | "Vietnamese", 216 | "Spanish", 217 | "Hindi", 218 | "Portuguese", 219 | "English", 220 | "French", 221 | "Thai", 222 | "Chinese", 223 | "Russian", 224 | ] 225 | outputs_strict = {language: [] for language in language_list} 226 | outputs_loose = {language: [] for language in language_list} 227 | 228 | row = output_df.iloc[0] 229 | turn_index = int(row["turn_index"]) 230 | turn_index_prompt = f"turn_{turn_index}_prompt" 231 | 232 | index_counter = [] 233 | for _, row in output_df.iterrows(): 234 | if row[turn_index_prompt] == "None" or len(str(row[turn_index_prompt])) == 0: 235 | continue 236 | try: 237 | instruction_id_list = json.loads( 238 | row[f"turn_{turn_index}_instruction_id_list"] 239 | ) 240 | except: 241 | continue 242 | kwargs_list = json.loads(row[f"turn_{turn_index}_kwargs"]) 243 | kwargs = [json.loads(kwarg) for kwarg in kwargs_list] 244 | try: 245 | response = json.loads(row['responses'])[0]['response'] 246 | except: 247 | response = row["responses"] 248 | 249 | input_dict = { 250 | "response": response, 251 | "instruction_id_list": instruction_id_list, 252 | "kwargs": kwargs, 253 | } 254 | 255 | outputs_strict["all_languages"].append(gen_acc_strict(input_dict)) 256 | outputs_loose["all_languages"].append(gen_acc_loose(input_dict)) 257 | 258 | language = row["language"] 259 | outputs_strict[language].append(gen_acc_strict(input_dict)) 260 | outputs_loose[language].append(gen_acc_loose(input_dict)) 261 | index_counter.append(row['key']) 262 | 263 | result_dict = {} 264 | 265 | result_dict.update( 266 | {f"turn_{turn_index}_prompts_number": len(outputs_strict["all_languages"])} 267 | ) 268 | for language in language_list: 269 | if outputs_strict[language] == []: 270 | continue 271 | res_strict = parse_result(outputs=outputs_strict[language]) 272 | res_strict_cis = MultiTurnInstructionFollowingPromptSolution.compute_ci_via_bootstrap(outputs_strict[language]) 273 | res_loose = parse_result(outputs=outputs_loose[language]) 274 | res_loose_cis = MultiTurnInstructionFollowingPromptSolution.compute_ci_via_bootstrap(outputs_loose[language]) 275 | result_list = [res_strict[0], res_strict[1], res_loose[0], res_loose[1]] 276 | average = sum(result_list) / len(result_list) 277 | 278 | result_dict.update({f"turn_{turn_index}_{language}_overall": average}) 279 | result_dict[f'{language}_cis_strict'] = res_strict_cis 280 | result_dict[f'{language}_cis_loose'] = res_loose_cis 281 | if not return_outputs: 282 | return result_dict 283 | else: 284 | outputs_strict['counter'] = index_counter 285 | outputs_loose['counter'] = index_counter 286 | return result_dict, outputs_strict, outputs_loose 287 | 288 | def get_text_column_name() -> str: 289 | """ 290 | Get the column name that stores text in the dataframe 291 | """ 292 | return MultiTurnInstructionFollowingPromptSolution.PROMPT_COLUMN_NAME 293 | 294 | def expand_df( 295 | df: pd.DataFrame, 296 | ) -> pd.DataFrame: 297 | logger.info("Expanding/updating dataframe...") 298 | new_index = 1 299 | if "turn_index" in df.columns: 300 | # TODO: need update astype thing. 301 | new_index = df["turn_index"].astype(int) + 1 302 | df["turn_index"] = new_index.astype(str) 303 | else: 304 | df["turn_index"] = str(new_index) 305 | return df 306 | 307 | 308 | if __name__ == '__main__': 309 | evaluator = MultiTurnInstructionFollowingPromptSolution() 310 | df = pd.read_csv('data/o1-preview/eval_result_step_1.csv', keep_default_na=False) 311 | res = MultiTurnInstructionFollowingPromptSolution.metrics_gen(df) 312 | print(res) 313 | -------------------------------------------------------------------------------- /multi_turn_instruct_following_eval_api.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import json 17 | import logging 18 | import os 19 | import time 20 | from concurrent.futures import ThreadPoolExecutor, as_completed 21 | from threading import Lock 22 | 23 | import pandas as pd 24 | from tqdm import tqdm 25 | 26 | from api_client import get_api_bot 27 | from metrics import MultiTurnInstructionFollowingPromptSolution 28 | from utils import GenerationSetting 29 | 30 | logger = logging.getLogger(__name__) 31 | logging.basicConfig(level=logging.DEBUG) 32 | 33 | lock = Lock() 34 | 35 | 36 | def max_retry_wrapper(api_bot, messages, max_retry=3): 37 | for attempt in range(max_retry, 0, -1): 38 | try: 39 | response = api_bot.generate(messages) 40 | return response 41 | except Exception as e: 42 | print(messages) 43 | logger.error(f"API call failed with error: {e}. Retries left: {attempt - 1}") 44 | time.sleep(1) # Brief pause before retrying 45 | return f'[MAX_RETRY=0] Failed.' 46 | 47 | 48 | def process_row(api_bot, row, step, max_retry): 49 | try: 50 | if step == 1: 51 | messages = [json.loads(row['turn_1_prompt'])] 52 | else: 53 | messages = json.loads(row['turns']) + [ 54 | {'role': 'assistant', 'content': row['responses']}, 55 | json.loads(row[f'turn_{step}_prompt']), 56 | ] 57 | response = max_retry_wrapper(api_bot, messages, max_retry) 58 | updated_turns = json.dumps(messages) 59 | status = 'success' if not response.startswith('[MAX_RETRY') else 'failed' 60 | return updated_turns, response, status 61 | except Exception as e: 62 | logger.exception(f"Error processing row: {e}") 63 | print(row) 64 | return row.get('turns', '[]'), f'Exception: {e}', 'exception' 65 | 66 | 67 | def step_fn_api( 68 | api_bot, 69 | input_df, 70 | step, 71 | need_write2file=True, 72 | output_filepath=None, 73 | max_retry=3, 74 | max_workers=5, # Limit the number of threads 75 | ): 76 | total_loc = len(input_df) 77 | output_df = input_df.copy() 78 | with lock: 79 | if "turns" not in output_df.columns: 80 | output_df["turns"] = pd.array(["[]"] * len(output_df), dtype="string") 81 | if "responses" not in output_df.columns: 82 | output_df["responses"] = pd.array(["None"] * len(output_df), dtype="string") 83 | if "status" not in output_df.columns: 84 | output_df["status"] = pd.array(["pending"] * len(output_df), dtype="string") 85 | output_df['turn_index'] = step # Update to current step 86 | 87 | rows_to_process = [] 88 | for idx, row in input_df.iterrows(): 89 | current_turn_index = row.get('turn_index', 0) 90 | response = row.get('responses', 'None') 91 | if current_turn_index > step or (current_turn_index == step and not response.startswith('[MAX_RETRY')): 92 | print(f"Skipped idx: {idx}") 93 | continue # Skip already processed rows 94 | rows_to_process.append((idx, row)) 95 | logger.info(f"Processing {len(rows_to_process)} out of {total_loc} rows for step {step}") 96 | 97 | results = {} 98 | 99 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 100 | future_to_idx = { 101 | executor.submit(process_row, api_bot, row, step, max_retry): idx 102 | for idx, row in rows_to_process 103 | } 104 | for future in tqdm(as_completed(future_to_idx), total=len(future_to_idx)): 105 | idx = future_to_idx[future] 106 | try: 107 | updated_turns, response, status = future.result() 108 | results[idx] = (updated_turns, response, status) 109 | except Exception as exc: 110 | logger.error(f"Row {idx} generated an exception: {exc}") 111 | results[idx] = (output_df.at[idx, "turns"], f'Exception: {exc}', 'exception') 112 | 113 | with lock: 114 | for idx, (turns, response, status) in results.items(): 115 | output_df.at[idx, "turns"] = turns 116 | output_df.at[idx, "responses"] = response 117 | output_df.at[idx, "status"] = status 118 | 119 | if need_write2file and output_filepath: 120 | output_df.to_csv(output_filepath, index=False) 121 | logger.info(f"Step {step} results written to {output_filepath}") 122 | 123 | return output_df 124 | 125 | 126 | def consolidate_results(api_model_name, output_filepath_prefix, steps): 127 | consolidated_df = None 128 | for step in steps: 129 | step_csv = f"results/{api_model_name}/{output_filepath_prefix}_step_{step}.csv" 130 | if os.path.exists(step_csv): 131 | temp_df = pd.read_csv(step_csv, keep_default_na=False) 132 | if consolidated_df is None: 133 | consolidated_df = temp_df.copy() 134 | else: 135 | # Merge on a unique identifier; assuming the index serves as a unique identifier 136 | consolidated_df = consolidated_df.combine_first(temp_df) 137 | else: 138 | logger.warning(f"Step {step} file {step_csv} does not exist and will be skipped.") 139 | 140 | if consolidated_df is not None: 141 | consolidated_csv = f"results/{api_model_name}/{output_filepath_prefix}_consolidated.csv" 142 | consolidated_df.to_csv(consolidated_csv, index=False) 143 | logger.info(f"All steps consolidated into {consolidated_csv}") 144 | else: 145 | logger.warning("No data available to consolidate.") 146 | 147 | 148 | def main( 149 | api_model_name, 150 | input_data_csv: str = "dataset/multi_turn_sample.csv", 151 | generation_setting=GenerationSetting( 152 | max_new_tokens=25000, temperature=0.6, top_p=0.9, seed=42 153 | ), 154 | need_write2file: bool = True, 155 | output_filepath_prefix: str = "eval_result", 156 | max_workers: int = 5, # Number of threads 157 | steps: list = [1, 2, 3], # New parameter for steps 158 | ): 159 | benchmark_df = pd.read_csv(input_data_csv, keep_default_na=False) 160 | num_rows = len(benchmark_df.axes[0]) 161 | logger.info(f"Number of rows in input data: {num_rows}") 162 | final_metric_result = {} 163 | 164 | api_bot = get_api_bot(api_model_name, generation_setting) 165 | step_input_df = benchmark_df.copy() 166 | for step in steps: # Use the user-provided steps 167 | output_filepath = ( 168 | f"results/{api_model_name}/{output_filepath_prefix}_step_{step}.csv" 169 | ) 170 | os.makedirs(f'results/{api_model_name}', exist_ok=True) 171 | step_output_df = step_fn_api( 172 | api_bot=api_bot, 173 | input_df=step_input_df, 174 | step=step, 175 | need_write2file=need_write2file, 176 | output_filepath=output_filepath, 177 | max_workers=max_workers, 178 | ) 179 | 180 | step_input_df = step_output_df.copy() 181 | step_metric_result = run_metric( 182 | api_model_name, 183 | output_filepath_prefix=output_filepath_prefix, 184 | step=step, 185 | ) 186 | final_metric_result[step] = step_metric_result 187 | 188 | consolidate_results(api_model_name, output_filepath_prefix, steps=steps) 189 | 190 | 191 | def run_metric( 192 | api_model_name, 193 | output_filepath_prefix: str = "eval_result", 194 | step: int = 1 195 | ): 196 | step_output_df = None 197 | step_csv = f"results/{api_model_name}/{output_filepath_prefix}_step_{step}.csv" 198 | if not os.path.exists(step_csv): 199 | logger.warning(f"CSV file {step_csv} does not exist and will be skipped.") 200 | return {} 201 | 202 | logger.info(f"Calculating metrics for step_{step}") 203 | step_output_df = pd.read_csv(step_csv, keep_default_na=False) 204 | 205 | if step_output_df is not None and not step_output_df.empty: 206 | metric_result = MultiTurnInstructionFollowingPromptSolution.metrics_gen( 207 | step_output_df 208 | ) 209 | metric_result_df = pd.DataFrame.from_dict(metric_result, orient="index") 210 | metric_result_df.to_csv(f"results/{api_model_name}/{output_filepath_prefix}_step_{step}_metric.csv") 211 | logger.info(f"Step {step} metrics:\n{metric_result}") 212 | return metric_result 213 | else: 214 | logger.warning(f"No data available for step {step} to compute metrics.") 215 | return {} 216 | 217 | 218 | if __name__ == '__main__': 219 | parser = argparse.ArgumentParser() 220 | parser.add_argument("--api_model_name", type=str, default="o1-mini") 221 | parser.add_argument( 222 | "--input_data_csv", type=str, default="dataset/multi_turn_sample.csv" 223 | ) 224 | parser.add_argument("--need_write2file", type=bool, default=True) 225 | parser.add_argument("--output_filepath_prefix", type=str, default="eval_result") 226 | 227 | parser.add_argument('--max_new_tokens', type=int, default=1024, help='o1 recommends 25000') 228 | parser.add_argument('--temperature', type=float, default=0.6) 229 | parser.add_argument('--top_p', type=float, default=0.9) 230 | parser.add_argument('--seed', type=int, default=42) 231 | parser.add_argument('--max_workers', type=int, default=5, help='Number of threads for concurrency') 232 | 233 | # New steps argument 234 | parser.add_argument( 235 | '--steps', 236 | type=int, 237 | nargs='+', 238 | default=[1, 2, 3], 239 | help='List of steps to process (e.g., --steps 1 2 3)' 240 | ) 241 | 242 | args = parser.parse_args() 243 | logger.info(f'Args: \n max_new_tokens: {args.max_new_tokens}, \n temperature: {args.temperature}, \n top_p: {args.top_p}, \n seed: {args.seed}, \n max_workers: {args.max_workers}, \n steps: {args.steps}') 244 | 245 | if 'o1' in args.api_model_name: 246 | # o1 doesn't allow for customized top_p and temperature. 247 | args.top_p = 1 248 | args.temperature = 1 249 | generation_setting = GenerationSetting( 250 | max_new_tokens=args.max_new_tokens, 251 | temperature=args.temperature, 252 | top_p=args.top_p, 253 | seed=args.seed 254 | ) 255 | 256 | main( 257 | api_model_name=args.api_model_name, 258 | input_data_csv=args.input_data_csv, 259 | generation_setting=generation_setting, 260 | need_write2file=args.need_write2file, 261 | output_filepath_prefix=args.output_filepath_prefix, 262 | max_workers=args.max_workers, 263 | steps=args.steps, 264 | ) -------------------------------------------------------------------------------- /multi_turn_instruct_following_eval_vllm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import logging 17 | import os 18 | import time 19 | from typing import Any, Dict, List, Optional 20 | 21 | import pandas as pd 22 | from metrics import MultiTurnInstructionFollowingPromptSolution 23 | from transformers import ( 24 | AutoTokenizer, 25 | PreTrainedTokenizerBase, 26 | ) 27 | from vllm import LLM 28 | from utils import GenerationSetting, get_inference_batch_vllm, preprocess_data 29 | 30 | logger = logging.getLogger(__name__) 31 | logging.basicConfig(level=logging.DEBUG) 32 | 33 | 34 | def main( 35 | model_path: str, 36 | tokenizer_path: str, 37 | input_data_csv: str = "dataset/multi_turn_sample.csv", 38 | batch_size=24, 39 | generation_setting=GenerationSetting( 40 | max_new_tokens=1024, temperature=0.6, top_p=0.9 41 | ), 42 | need_write2file: bool = True, 43 | output_filepath_prefix: str = "eval_result", 44 | tensor_parallel_size: int = 8, 45 | steps: List[int] = [1, 2, 3], 46 | tag: str = '_test_long_sys' 47 | ) -> None: 48 | benchmark_df = pd.read_csv(input_data_csv, keep_default_na=False) 49 | num_rows = len(benchmark_df) 50 | logger.info(f"Number of rows: {num_rows}") 51 | 52 | final_metric_result = {} 53 | 54 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side="left") 55 | tokenizer.pad_token = tokenizer.eos_token 56 | 57 | model = LLM( 58 | model_path, 59 | tensor_parallel_size=tensor_parallel_size, 60 | dtype="bfloat16", 61 | distributed_executor_backend="mp", 62 | seed=generation_setting.seed, 63 | ) 64 | 65 | start = time.time() 66 | model_name = model_path.split('/')[-1].strip() + tag 67 | step_input_df = benchmark_df.copy() 68 | 69 | for step in steps: 70 | output_filepath = ( 71 | f"results/{model_name}/{output_filepath_prefix}_step_{step}.csv" 72 | ) 73 | os.makedirs(os.path.dirname(output_filepath), exist_ok=True) 74 | step_output_df = run_step( 75 | model=model, 76 | tokenizer=tokenizer, 77 | input_df=step_input_df, 78 | row_limit=-1, 79 | need_write2file=need_write2file, 80 | output_filepath=output_filepath, 81 | device=0, 82 | generation_setting=generation_setting, 83 | batch_size=batch_size, 84 | ) 85 | step_input_df = step_output_df.copy() 86 | step_metric_result = run_metric( 87 | output_filepath=output_filepath, 88 | step=step, 89 | ) 90 | final_metric_result[step] = step_metric_result 91 | 92 | logger.info( 93 | f"Total time: {time.time() - start}\n Number of rows: {num_rows}, \n Number of processes: 1" 94 | ) 95 | logger.info(f"Final metrics: {final_metric_result}") 96 | 97 | 98 | def run_step( 99 | model: LLM, # Changed to vllm's LLM 100 | tokenizer: PreTrainedTokenizerBase, 101 | input_df: pd.DataFrame, 102 | prompt_columns: List[str] = ["turns", "responses"], 103 | step: int = 0, 104 | row_limit: int = -1, 105 | need_write2file: bool = True, 106 | output_filepath: str = "eval_result.csv", 107 | device: Optional[str] = None, 108 | generation_setting: GenerationSetting = GenerationSetting(), 109 | batch_size: int = 24 110 | ) -> pd.DataFrame: 111 | output_df = preprocess_data( 112 | input_df, prompt_columns=prompt_columns, row_limit=row_limit 113 | ) 114 | step_output_df = get_inference_batch_vllm( 115 | model=model, 116 | tokenizer=tokenizer, 117 | input_df=output_df, 118 | batch_size=batch_size, 119 | generation_setting=generation_setting, 120 | need_write2file=need_write2file, 121 | output_filepath=output_filepath, 122 | device=device, 123 | ) 124 | return step_output_df 125 | 126 | 127 | def run_metric( 128 | output_filepath: str = "eval_result", step: int = 0 129 | ) -> Dict[str, Any]: 130 | step_output_df = None 131 | csv = output_filepath 132 | logger.info(f"calculating metrics for step_{step}") 133 | temp_df = pd.read_csv(csv, keep_default_na=False) 134 | step_output_df = temp_df.copy() 135 | metric_result = MultiTurnInstructionFollowingPromptSolution.metrics_gen( 136 | step_output_df 137 | ) 138 | metric_result_df = pd.DataFrame.from_dict(metric_result, orient="index") 139 | metric_result_df.to_csv(output_filepath.replace('.csv', '_metric.csv'), index=False) 140 | logger.info(f"step_{step} metrics \n: {metric_result}") 141 | return metric_result 142 | 143 | 144 | if __name__ == "__main__": 145 | """ 146 | !!NOTE!!: make sure the number of available GPUs == tensor_parallel_size 147 | Usage: 148 | python multi_turn_instruct_following_eval_vllm_final.py \ 149 | --model_path \ 150 | --tokenizer_path \ 151 | --input_data_csv \ 152 | --batch_size \ 153 | --need_write2file \ 154 | --output_filepath_prefix \ 155 | --tensor_parallel_size 156 | 157 | Example: 158 | python multi_turn_instruct_following_eval_vllm_final.py \ 159 | --model_path meta-llama/Llama-3.1-70B-Instruct \ 160 | --tokenizer_path meta-llama/Llama-3.1-70B-Instruct \ 161 | --input_data_csv dataset/multi_turn_sample_v6.csv \ 162 | --batch_size 64 \ 163 | --tensor_parallel_size 8 164 | 165 | """ 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument("--model_path", type=str, default="Meta-Llama-3.1-8B-Instruct") 168 | parser.add_argument( 169 | "--tokenizer_path", type=str, default="Meta-Llama-3.1-8B-Instruct" 170 | ) 171 | parser.add_argument( 172 | "--input_data_csv", type=str, default="dataset/multi_turn_sample.csv" 173 | ) 174 | parser.add_argument("--batch_size", type=int, default=24) 175 | parser.add_argument("--need_write2file", type=bool, default=True) 176 | parser.add_argument("--output_filepath_prefix", type=str, default="eval_result") 177 | parser.add_argument("--tensor_parallel_size", type=int, default=8) 178 | parser.add_argument( 179 | '--steps', 180 | type=int, 181 | nargs='+', 182 | default=[1, 2, 3], 183 | help='List of steps to process (e.g., --steps 1 2 3)' 184 | ) 185 | args = parser.parse_args() 186 | main( 187 | model_path=args.model_path, 188 | tokenizer_path=args.tokenizer_path, 189 | input_data_csv=args.input_data_csv, 190 | batch_size=args.batch_size, 191 | need_write2file=args.need_write2file, 192 | output_filepath_prefix=args.output_filepath_prefix, 193 | tensor_parallel_size=args.tensor_parallel_size, 194 | steps=args.steps 195 | ) 196 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | transformers 3 | langdetect 4 | six 5 | emoji 6 | nltk 7 | pythainlp 8 | pandas 9 | scipy 10 | anthropic 11 | mistralai 12 | google.generativeai 13 | openai 14 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase 17 | import torch 18 | import pandas as pd 19 | import logging 20 | from typing import List, Optional 21 | from metrics import MultiTurnInstructionFollowingPromptSolution 22 | import json 23 | import numpy as np 24 | from dataclasses import dataclass 25 | from vllm import LLM 26 | from vllm.sampling_params import SamplingParams 27 | 28 | logger: logging.Logger = logging.getLogger(__name__) 29 | logging.basicConfig(level=logging.DEBUG) 30 | 31 | @dataclass 32 | class GenerationSetting: 33 | max_new_tokens: int = 4096 34 | temperature: float = 1.0 35 | top_p: float = 0.9 36 | seed: int = 42 37 | 38 | def get_inference_batch(model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, input_df: pd.DataFrame, batch_size: int = 24, generation_setting: GenerationSetting = GenerationSetting(), need_write2file: bool = True, output_filepath: str = "generation_output.csv", device: Optional[str] = None)-> pd.DataFrame: 39 | """ 40 | generate inference result given pre-trained model, tokenizer, and input dataframe 41 | the result will be written to the output_filepath if defined. by default, it will return 42 | the generation output dataframe as the output 43 | Args: 44 | model: pre-trained model 45 | tokenizer: pre-trained tokenizer 46 | input_df: input dataframe 47 | batch_size: batch size for inference 48 | generation_setting: generation setting 49 | need_write2file: whether to write the result to the output_filepath 50 | output_filepath: output filepath 51 | device: device to run the inference on 52 | """ 53 | output_df = input_df.copy() 54 | if "turns" not in output_df.columns: 55 | output_df.insert(1, "turns", value=pd.array(["None"] * len(output_df), dtype="string")) 56 | if "responses" not in output_df.columns: 57 | output_df.insert(1, "responses", value=pd.array(["None"] * len(output_df), dtype="string")) 58 | 59 | num_split = len(output_df)//batch_size + 1 60 | print(f"num_split = {num_split}") 61 | # 2: Apply the chat template 62 | for batch in np.array_split(output_df, num_split): 63 | logger.info(f"processing {len(batch)} input row. ") 64 | chat = batch["multi_turn_prompt_column"].tolist() 65 | # 2: Apply the chat template 66 | formatted_chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 67 | 68 | # 3: Tokenize the chat (This can be combined with the previous step using tokenize=True) 69 | tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default 70 | inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False, padding=True).to(model.device) 71 | 72 | # 4: Generate text from the model 73 | # TODO: make generation setting configurable 74 | prefix = inputs['input_ids'].size(1) 75 | if tokenizer.convert_tokens_to_ids("<|eot_id|>"): 76 | terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")] 77 | else: 78 | terminators = [tokenizer.eos_token_id] 79 | outputs = model.generate(**inputs, max_new_tokens=generation_setting.max_new_tokens, eos_token_id=terminators, temperature=generation_setting.temperature, top_p=generation_setting.top_p) 80 | 81 | # 5: Decode the output back to a string 82 | decoded_output = tokenizer.batch_decode([output[prefix:] for output in outputs], skip_special_tokens=True) 83 | 84 | if len(batch) != len(decoded_output): 85 | raise ValueError(f"batch size {len(batch)} != decoded_output size {len(decoded_output)}") 86 | for index, row in batch.iterrows(): 87 | output_df.loc[index, "turns"] = json.dumps(row["multi_turn_prompt_column"]) 88 | output_df.loc[index, "responses"] = decoded_output.pop(0) 89 | 90 | 91 | if need_write2file: 92 | output_df.drop(columns=["multi_turn_prompt_column"], inplace=True) 93 | output_df.to_csv(output_filepath) 94 | 95 | return output_df 96 | 97 | def get_inference_batch_vllm( 98 | model: LLM, # Changed to use vllm's LLM 99 | tokenizer: PreTrainedTokenizerBase, 100 | input_df: pd.DataFrame, 101 | batch_size: int = 24, 102 | generation_setting: GenerationSetting = GenerationSetting(), 103 | need_write2file: bool = True, 104 | output_filepath: str = "generation_output.csv", 105 | device: Optional[str] = None, 106 | ) -> pd.DataFrame: 107 | """ 108 | Generate inference results using vllm for faster generation. 109 | """ 110 | output_df = input_df.copy() 111 | if "turns" not in output_df.columns: 112 | output_df.insert(1, "turns", value=pd.array(["None"] * len(output_df), dtype="string")) 113 | if "responses" not in output_df.columns: 114 | output_df.insert(1, "responses", value=pd.array(["None"] * len(output_df), dtype="string")) 115 | 116 | num_split = len(output_df) // batch_size + 1 117 | logger.info(f"Number of splits: {num_split}") 118 | 119 | # Prepare sampling parameters for vllm 120 | if tokenizer.convert_tokens_to_ids("<|eot_id|>"): 121 | terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")] 122 | else: 123 | terminators = [tokenizer.eos_token_id] 124 | 125 | gen_params = SamplingParams( 126 | temperature=generation_setting.temperature, 127 | top_p=generation_setting.top_p, 128 | max_tokens=generation_setting.max_new_tokens, 129 | stop_token_ids=terminators, 130 | ) 131 | 132 | # Use vllm for generation 133 | print_once = True 134 | for batch in np.array_split(output_df, num_split): 135 | logger.info(f"Processing {len(batch)} input rows.") 136 | chat = batch["multi_turn_prompt_column"].tolist() 137 | 138 | try: 139 | formatted_chat = tokenizer.apply_chat_template( 140 | chat, tokenize=False, add_generation_prompt=True 141 | ) 142 | except: 143 | import ipdb; ipdb.set_trace() 144 | if print_once: 145 | print("Chat:", formatted_chat[0]) 146 | print(gen_params) 147 | print_once=False 148 | # Generate outputs using vllm 149 | generation_outputs = model.generate(formatted_chat, sampling_params=gen_params) 150 | 151 | if len(batch) != len(generation_outputs): 152 | raise ValueError( 153 | f"Batch size {len(batch)} != number of generation outputs {len(generation_outputs)}" 154 | ) 155 | 156 | for index, gen_output in zip(batch.index, generation_outputs): 157 | decoded_output = gen_output.outputs[0].text 158 | output_df.loc[index, "turns"] = json.dumps(batch.loc[index, "multi_turn_prompt_column"]) 159 | output_df.loc[index, "responses"] = decoded_output.strip() 160 | 161 | if need_write2file: 162 | output_df.drop(columns=["multi_turn_prompt_column"], inplace=True) 163 | output_df.to_csv(output_filepath, index=False) 164 | 165 | return output_df 166 | 167 | def preprocess_data(input_df: pd.DataFrame, prompt_columns: List[str], row_limit: int = -1)-> pd.DataFrame: 168 | """ 169 | Preprocess the data by applying prompt reformatting 170 | input_df: the input dataframe 171 | row_limit: number of rows to process, -1 means all rows 172 | """ 173 | new_prompt_column = MultiTurnInstructionFollowingPromptSolution.get_text_column_name() 174 | if new_prompt_column not in input_df.columns: 175 | input_df.insert(1, new_prompt_column, value=pd.array([None] * len(input_df), dtype="string")) 176 | else: 177 | raise ValueError(f"Column {new_prompt_column} already exists in the input dataframe!") 178 | 179 | # step 1: expanding/updating existing df 180 | input_df = MultiTurnInstructionFollowingPromptSolution.expand_df(input_df) 181 | if row_limit > 0: 182 | input_df = input_df.head(row_limit) 183 | # step 2: apply prompt reformatting 184 | def reformat_prompt(df_row: pd.core.series.Series) -> str: 185 | return MultiTurnInstructionFollowingPromptSolution.reformat_prompt( 186 | df_row, prompt_columns 187 | ) 188 | input_df[new_prompt_column] = input_df.apply(reformat_prompt, axis=1) 189 | return input_df 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | --------------------------------------------------------------------------------