├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── eval_plus
│   ├── convert_data.py
│   ├── data
│   │   ├── HumanEval.jsonl
│   │   ├── HumanEvalPlus-v0.1.10.jsonl
│   │   ├── HumanEvalPlus-v0.1.9.jsonl
│   │   └── MbppPlus-v0.1.0.jsonl
│   ├── exclude_patterns.txt
│   ├── generate.py
│   ├── model.py
│   ├── readme.md
│   ├── requirements.txt
│   └── test.sh
├── grpo_code
│   ├── __init__.py
│   ├── executor.py
│   ├── parallel_executor.py
│   ├── rewards.py
│   ├── transforms.py
│   └── wasm.py
├── pyproject.toml
├── r1_acecode.yaml
└── wasm
    ├── python-3.12.0.wasm
    └── python-3.12.0.wasm.sha256sum

/.gitignore: --------------------------------------------------------------------------------
1 | **/axolotl.egg-info
2 | configs
3 | last_run_prepared/
4 | outputs
5 | .vscode
6 | _site/
7 | 
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 | 
13 | # C extensions
14 | *.so
15 | 
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 | 
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 | 
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 | 
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 | 
61 | # Translations
62 | *.mo
63 | *.pot
64 | 
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 | 
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 | 
75 | # Scrapy stuff:
76 | .scrapy
77 | 
78 | # Sphinx documentation
79 | docs/_build/
80 | 
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 | 
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 | 
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 | 
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 | 
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 | 
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 | 
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/#use-with-ide
117 | .pdm.toml
118 | 
119 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | venv3.10/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | .idea/ 169 | 170 | # WandB 171 | # wandb creates a folder to store logs for training runs 172 | wandb 173 | 174 | # Runs 175 | lora-out/* 176 | qlora-out/* 177 | mlruns/* 178 | 179 | /.quarto/ 180 | prepared-datasets/ 181 | submit.sh 182 | *.out* 183 | 184 | typings/ 185 | out/ 186 | 187 | # vim 188 | *.swp 189 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axolotl-ai-cloud/grpo_code/148ea79321f34bbed79b3b55f04c0a7de002665d/CONTRIBUTING.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | > [!NOTE] 3 | > Check out our [blog-post](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) for more detail and benchmarks! 4 | 5 | ## Installation 6 | 7 | ```bash 8 | git clone https://github.com/axolotl-ai-cloud/grpo_code.git 9 | cd grpo_code 10 | pip install -e . 
11 | pip install "axolotl[vllm,flash-attn]==0.8.0"
12 | ```
13 | 
14 | ## Training
15 | 
16 | The following environment variables can be used to modify the behaviour of the reward functions:
17 | - `FUEL` - Controls the amount of fuel (computation resources) allocated to the WASM environment (default: 1000000000)
18 | - `WASM_PATH` - Path to the Python WASM runtime file (default: "./wasm/python-3.12.0.wasm")
19 | - `TASK_TIMEOUT` - Maximum execution time in seconds for code evaluation (default: 1)
20 | - `MAX_WORKERS` - Number of parallel workers for multiprocessing reward functions (default: 1); this value is divided evenly across training processes
21 | 
22 | First, spin up a `vLLM` instance:
23 | 
24 | ```bash
25 | CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve r1_acecode.yaml
26 | ```
27 | 
28 | Then, in another terminal, kick off the training process:
29 | 
30 | ```bash
31 | CUDA_VISIBLE_DEVICES=0,1 MAX_WORKERS=64 axolotl train r1_acecode.yaml --num-processes 2
32 | ```
33 | 
34 | This example uses 4 A100 GPUs - adjust `CUDA_VISIBLE_DEVICES`, `MAX_WORKERS`, `cfg.micro_batch_size` and `cfg.gradient_accumulation_steps` as necessary to match your hardware.
35 | 
36 | ## Python WASM Runtime
37 | 
38 | This project uses Python 3.12.0 compiled to WebAssembly from VMware Labs.
39 | 
40 | ### Verify an Existing Download
41 | If you already have the WASM file and want to verify its integrity:
42 | 
43 | 1. Ensure you have both `python-3.12.0.wasm` and `python-3.12.0.wasm.sha256sum` in the `wasm` directory.
44 | 2. Run the verification command (from inside the `wasm` directory, since the checksum file references the bare filename):
45 | 
46 | **Linux/macOS:**
47 | ```bash
48 | (cd wasm && sha256sum -c python-3.12.0.wasm.sha256sum)
49 | ```
50 | 
51 | ### Manual Download
52 | To download the runtime files yourself:
53 | 
54 | 1. Download the Python WASM runtime:
55 | ```bash
56 | curl -L https://github.com/vmware-labs/webassembly-language-runtimes/releases/download/python%2F3.12.0%2B20231211-040d5a6/python-3.12.0.wasm -o ./wasm/python-3.12.0.wasm
57 | ```
58 | 
59 | 2. Download the SHA256 checksum file:
60 | ```bash
61 | curl -L https://github.com/vmware-labs/webassembly-language-runtimes/releases/download/python%2F3.12.0%2B20231211-040d5a6/python-3.12.0.wasm.sha256sum -o ./wasm/python-3.12.0.wasm.sha256sum
62 | ```
63 | 
64 | 3. Verify the download:
65 | ```bash
66 | (cd wasm && sha256sum -c python-3.12.0.wasm.sha256sum)
67 | ```
68 | 
69 | 4. Place both files in your project directory or specify the path in your configuration.
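
If `sha256sum` is unavailable (for example on Windows), the same integrity check can be done from Python. A minimal sketch using only the standard library, assuming the two files sit in `./wasm/` as described above:

```python
import hashlib
from pathlib import Path

wasm_path = Path("wasm/python-3.12.0.wasm")
checksum_path = Path("wasm/python-3.12.0.wasm.sha256sum")

# A .sha256sum file holds "<hex digest>  <filename>"; only the digest is needed.
expected = checksum_path.read_text().split()[0]
actual = hashlib.sha256(wasm_path.read_bytes()).hexdigest()

if actual != expected:
    raise SystemExit(f"Checksum mismatch: {actual} != {expected}")
print("python-3.12.0.wasm matches its published SHA-256 checksum")
```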
70 | -------------------------------------------------------------------------------- /eval_plus/convert_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jsonlines 3 | import os 4 | import sys 5 | 6 | eval_plus_path = os.path.dirname(os.path.abspath(__file__)) + "/evalplus/" 7 | sys.path = [eval_plus_path] + sys.path 8 | from evalplus.data import get_human_eval_plus, get_mbpp_plus 9 | 10 | MBPP_OUTPUT_SET_EQ_TASKS = [ 11 | "similar_elements", # Mbpp/2 12 | "find_char_long", # Mbpp/7 13 | "common_in_nested_lists", # Mbpp/111 14 | "extract_singly", # Mbpp/140 15 | "larg_nnum", # Mbpp/232 16 | "intersection_array", # Mbpp/249 17 | "find_dissimilar", # Mbpp/579 18 | "Diff", # Mbpp/769 19 | ] 20 | MBPP_OUTPUT_NOT_NONE_TASKS = ["check_str", "text_match_three", "text_starta_endb"] 21 | 22 | 23 | def convert_file(root_dir): 24 | import jsonlines 25 | from copy import deepcopy 26 | import sys 27 | import tqdm 28 | import re 29 | sys.set_int_max_str_digits(10000000) 30 | 31 | def write_jsonl_file(objs, target_path): 32 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 33 | with jsonlines.open(target_path, "w") as w: 34 | for obj in objs: 35 | w.write(obj) 36 | print(f"Successfully saving to {target_path}") 37 | 38 | # def get_humaneval_prompt(doc, language): 39 | # language = language.lower() 40 | # question = doc["prompt"].strip() 41 | # return """ 42 | # Please continue to complete the function and return all completed code in a codeblock. Here is the given code to do completion: 43 | # ```{} 44 | # {} 45 | # ``` 46 | # """.strip().format( 47 | # language.lower(), question.strip() 48 | # ) 49 | 50 | def get_prompt(doc, language): 51 | language = language.lower() 52 | question = doc["prompt"].strip() 53 | return """ 54 | Can you complete the following Python function? 
55 | ```{} 56 | {} 57 | ``` 58 | """.strip().format(language.lower(), question.strip()) 59 | 60 | def create_high_accuracy_function(code, entry_point): 61 | high_accuracy = """ 62 | from decimal import Decimal, getcontext 63 | from functools import wraps 64 | getcontext().prec = 100 65 | 66 | def convert_to_decimal(value): 67 | if isinstance(value, float): 68 | return Decimal(str(value)) 69 | elif isinstance(value, list): 70 | return [convert_to_decimal(item) for item in value] 71 | elif isinstance(value, dict): 72 | return {k: convert_to_decimal(v) for k, v in value.items()} 73 | return value 74 | 75 | def float_to_decimal(func): 76 | @wraps(func) 77 | def wrapper(*args, **kwargs): 78 | new_args = [convert_to_decimal(arg) for arg in args] 79 | new_kwargs = {k: convert_to_decimal(v) for k, v in kwargs.items()} 80 | result = func(*new_args, **new_kwargs) 81 | return result 82 | return wrapper 83 | 84 | def convert_to_float(value): 85 | if isinstance(value, Decimal): 86 | return float(value) 87 | elif isinstance(value, list): 88 | return [convert_to_float(item) for item in value] 89 | elif isinstance(value, dict): 90 | return {k: convert_to_float(v) for k, v in value.items()} 91 | return value 92 | 93 | def decimal_to_float(func): 94 | @wraps(func) 95 | def wrapper(*args, **kwargs): 96 | # Execute the wrapped function 97 | result = func(*args, **kwargs) 98 | 99 | # Convert the result back to float, if necessary 100 | result = convert_to_float(result) 101 | return result 102 | return wrapper 103 | """ 104 | """Execute trusted code in place.""" 105 | code = high_accuracy + code 106 | code = code.split("\n") 107 | new_code = [] 108 | cnt = 0 109 | for c in code: 110 | if re.search(rf"def {entry_point}\(.*?\)", c) is not None: 111 | cnt += 1 112 | new_code.append("@float_to_decimal") 113 | new_code.append("@decimal_to_float") 114 | new_code.append(c) 115 | code = "\n".join(new_code) 116 | return code 117 | 118 | def trusted_exec(code, inputs, entry_point, record_time=False, output_not_none=False): 119 | exec_globals = {} 120 | # if entry_point not in ["triangle_area", "angle_complex", "volume_sphere"]: # avoid special case (a ** b) 121 | # code = create_high_accuracy_function(code, entry_point) 122 | if "**" not in code and entry_point not in ["triangle_area", "angle_complex", "volume_sphere"]: 123 | code = create_high_accuracy_function(code, entry_point) 124 | #print(code) 125 | exec(code, exec_globals) 126 | fn = exec_globals[entry_point] 127 | 128 | rtime = [] 129 | ret = [] 130 | for inp in inputs: 131 | inp = deepcopy(inp) 132 | if record_time: 133 | start = time.time() 134 | ret.append(fn(*inp)) 135 | rtime.append(time.time() - start) 136 | else: 137 | ret.append(fn(*inp)) 138 | 139 | if output_not_none: 140 | ret = [i is not None for i in ret] 141 | 142 | if record_time: 143 | return ret, rtime 144 | else: 145 | return ret 146 | 147 | def convert(objs, test_set="base_input", task_name=f"evalplus/humaneval"): 148 | type 149 | data = [] 150 | for obj in tqdm.tqdm(objs): 151 | prompt = get_prompt(obj, language="python") 152 | if test_set == "base_input": 153 | inputs = obj["base_input"] 154 | else: 155 | inputs = obj["base_input"] + obj["plus_input"] if not isinstance(obj["plus_input"], dict) else obj["base_input"] 156 | #outputs = trusted_exec(code = obj["prompt"] + obj["canonical_solution"], inputs = obj["base_input"], entry_point = obj["entry_point"]) 157 | #tests = create_check_function(test_cases = inputs, entry_point=obj["entry_point"], outputs = outputs) 158 | outputs = 
trusted_exec(code=obj["prompt"] + obj["canonical_solution"], inputs=[obj["base_input"][0]], entry_point=obj["entry_point"]) 159 | atol = obj["atol"] 160 | if atol == 0: 161 | atol = 1e-6 # enforce atol for float comparison 162 | #```python 163 | if obj["entry_point"] == "find_zero": #humaneval 164 | tests = create_dynamic_check_function_find_zero(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], atol=atol) 165 | elif obj["entry_point"] in MBPP_OUTPUT_NOT_NONE_TASKS: 166 | tests = create_dynamic_check_function(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], check_style="not_none", atol=atol) 167 | elif obj["entry_point"] in MBPP_OUTPUT_SET_EQ_TASKS: # mbpp 168 | tests = create_dynamic_check_function(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], check_style="set", atol=atol) 169 | elif obj["entry_point"] == "are_equivalent": # mbpp 170 | tests = create_dynamic_check_function_are_equivalent(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], atol=atol) 171 | elif obj["entry_point"] == "sum_div": # mbpp 172 | tests = create_dynamic_check_function_sum_div(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], atol=atol) 173 | elif isinstance(outputs[0], float) or (isinstance(outputs[0], list) and len(outputs[0]) > 0 and isinstance(outputs[0][0], float)): 174 | tests = create_dynamic_check_function(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], check_style="np.allcose", atol=atol) 175 | else: 176 | tests = create_dynamic_check_function(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], check_style="==", atol=atol) 177 | data.append({ 178 | "prompt": prompt, 179 | "test": tests, 180 | "entry_point": obj["entry_point"], 181 | "tags": f"coding,en,python,core", 182 | "task": task_name, 183 | "source": f"evalplus", 184 | "eval_args": { 185 | "greedy": True, 186 | #"seed": 1234, 187 | "out_seq_length": 1024, 188 | "repetition_penalty": 1.0, 189 | "temperature": 1.0, 190 | "top_k": -1, 191 | "top_p": 0.95, 192 | "presence_penalty": 0, 193 | "system_str": "You are an intelligent programming assistant to produce Python algorithmic solutions", 194 | }, 195 | "extra_response_prefix": "```python\n" 196 | }) 197 | return data 198 | 199 | def create_check_function(test_cases, entry_point, outputs): 200 | test_cases_str = "def check():\n" 201 | for case, output in zip(test_cases, outputs): 202 | for i in range(len(case)): 203 | if isinstance(case[i], str) and "\n" in case[i]: 204 | case[i] = case[i].replace("\n", "\\n") 205 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case]) 206 | output = str(output) if not isinstance(output, str) else f"'{output}'" 207 | single_test_case_str = f"\tassert {entry_point}({input_params}) == {output}\n" 208 | test_cases_str += single_test_case_str 209 | test_cases_str += "check()" 210 | return test_cases_str 211 | 212 | def create_dynamic_check_function_are_equivalent(test_cases, entry_point, prompt, correct_solution, check_style="np.allclose", atol=0): 213 | test_cases_str = "import numpy as np\n" + prompt + correct_solution 214 | test_cases_str = test_cases_str.replace(f"def 
{entry_point}(", f"def {entry_point}_ground_truth(") 215 | test_cases_str += "def check():\n" 216 | for case in test_cases: 217 | for i in range(len(case)): 218 | if isinstance(case[i], str) and "\n" in case[i]: 219 | case[i] = case[i].replace("\n", "\\n") 220 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case]) 221 | single_test_case_str = f"\tassert {entry_point}({input_params}) == {entry_point}_ground_truth({input_params}) or {entry_point}({input_params}) == 0\n" 222 | test_cases_str += single_test_case_str 223 | test_cases_str += "check()" 224 | return test_cases_str 225 | 226 | def create_dynamic_check_function_sum_div(test_cases, entry_point, prompt, correct_solution, check_style="np.allclose", atol=0): 227 | test_cases_str = "import numpy as np\n" + prompt + correct_solution 228 | test_cases_str = test_cases_str.replace(f"def {entry_point}(", f"def {entry_point}_ground_truth(") 229 | test_cases_str += "def check():\n" 230 | for case in test_cases: 231 | for i in range(len(case)): 232 | if isinstance(case[i], str) and "\n" in case[i]: 233 | case[i] = case[i].replace("\n", "\\n") 234 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case]) 235 | single_test_case_str = f"\tassert {entry_point}({input_params}) == {entry_point}_ground_truth({input_params}) or {entry_point}({input_params}) == 0\n" 236 | test_cases_str += single_test_case_str 237 | test_cases_str += "check()" 238 | return test_cases_str 239 | 240 | def create_dynamic_check_function(test_cases, entry_point, prompt, correct_solution, check_style="np.allclose", atol=0): 241 | test_cases_str = "import numpy as np\n" + prompt + correct_solution 242 | test_cases_str = test_cases_str.replace(f"def {entry_point}(", f"def {entry_point}_ground_truth(") 243 | test_cases_str += "def check():\n" 244 | for case in test_cases: 245 | for i in range(len(case)): 246 | if isinstance(case[i], str) and "\n" in case[i]: 247 | case[i] = case[i].replace("\n", "\\n") 248 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case]) 249 | if check_style == "np.allcose": 250 | single_test_case_str = f"\tassert np.allclose({entry_point}({input_params}), {entry_point}_ground_truth({input_params}), rtol=1e-07, atol={atol})\n" 251 | elif check_style == "==": 252 | single_test_case_str = f"\tassert {entry_point}({input_params}) == {entry_point}_ground_truth({input_params})\n" 253 | elif check_style == "set": 254 | single_test_case_str = f"\tassert set({entry_point}({input_params})) == set({entry_point}_ground_truth({input_params}))\n" 255 | elif check_style == "not_none": 256 | single_test_case_str = f"\tif isinstance({entry_point}({input_params}), bool):\n" 257 | single_test_case_str += f"\t\tassert {entry_point}({input_params}) == ({entry_point}_ground_truth({input_params}) is not None)\n" 258 | single_test_case_str += f"\telse:\n" 259 | single_test_case_str += f"\t\tassert ({entry_point}({input_params}) is not None) == ({entry_point}_ground_truth({input_params}) is not None)\n" 260 | test_cases_str += single_test_case_str 261 | test_cases_str += "check()" 262 | return test_cases_str 263 | 264 | def create_dynamic_check_function_find_zero(test_cases, entry_point, prompt, correct_solution, atol=0): 265 | test_cases_str = "import numpy as np\n" + prompt + correct_solution 266 | test_cases_str = test_cases_str.replace(f"def {entry_point}(", f"def {entry_point}_ground_truth(") 267 | test_cases_str += "def check():\n" 268 | for case in test_cases: 269 | for 
i in range(len(case)): 270 | if isinstance(case[i], str) and "\n" in case[i]: 271 | case[i] = case[i].replace("\n", "\\n") 272 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case]) 273 | single_test_case_str = f"\tassert abs(poly({input_params}, {entry_point}({input_params}))) <= {atol}\n" 274 | test_cases_str += single_test_case_str 275 | test_cases_str += "check()" 276 | return test_cases_str 277 | 278 | humaneval_data = get_human_eval_plus() 279 | data1 = convert(humaneval_data.values(), test_set="base_input", task_name="evalplus/humaneval") 280 | write_jsonl_file(data1, f"{root_dir}/evalplus_v2/humaneval.jsonl") 281 | data2 = convert(humaneval_data.values(), test_set="plus_input", task_name="evalplus/humaneval_plus") 282 | write_jsonl_file(data2, f"{root_dir}/evalplus_v2/humaneval_plus.jsonl") 283 | 284 | mbpp_data = get_mbpp_plus() 285 | data3 = convert(mbpp_data.values(), test_set="base_input", task_name="evalplus/mbpp") 286 | write_jsonl_file(data3, f"{root_dir}/evalplus_v2/mbpp.jsonl") 287 | data4 = convert(mbpp_data.values(), test_set="plus_input", task_name="evalplus/mbpp_plus") 288 | write_jsonl_file(data4, f"{root_dir}/evalplus_v2/mbpp_plus.jsonl") 289 | 290 | all_data = data1 + data2 + data3 + data4 291 | write_jsonl_file(all_data, f"{root_dir}/evalplus_v2/evalplus.jsonl") 292 | 293 | all_data = np.random.choice(all_data, 10) 294 | write_jsonl_file(all_data, f"{root_dir}/evalplus_v2/evalplus.jsonl.sampled") 295 | 296 | 297 | if __name__ == "__main__": 298 | convert_file(root_dir="data/eval/code/") 299 | -------------------------------------------------------------------------------- /eval_plus/exclude_patterns.txt: -------------------------------------------------------------------------------- 1 | .venv 2 | __pycache__ 3 | output*/ 4 | -------------------------------------------------------------------------------- /eval_plus/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import PathLike 4 | import sys 5 | eval_plus_path = os.path.dirname(os.path.abspath(__file__)) + "/evalplus/" 6 | sys.path = [eval_plus_path] + sys.path 7 | from model import DecoderBase, make_model 8 | from rich.progress import ( 9 | BarColumn, 10 | MofNCompleteColumn, 11 | Progress, 12 | TextColumn, 13 | TimeElapsedColumn, 14 | ) 15 | 16 | 17 | MODEL_MAPPING = { 18 | # Can be either repo's name or /path/to/model 19 | "codeqwen": { 20 | "base": "Qwen/CodeQwen1.5-7B", 21 | "chat": "Qwen/CodeQwen1.5-7B-Chat", 22 | "chat-awq": "Qwen/CodeQwen1.5-7B-Chat-AWQ", 23 | }, 24 | "qwen2": { 25 | "chat": "Qwen/CodeQwen1.5-7B-Chat", 26 | }, 27 | } 28 | 29 | 30 | def construct_contract_prompt(prompt: str, contract_type: str, contract: str) -> str: 31 | if contract_type == "none": 32 | return prompt 33 | elif contract_type == "docstring": 34 | # embed within the docstring 35 | sep = "" 36 | if '"""' in prompt: 37 | sep = '"""' 38 | elif "'''" in prompt: 39 | sep = "'''" 40 | assert sep != "" 41 | l = prompt.split(sep) 42 | contract = "\n".join([x.split("#")[0] for x in contract.splitlines()]) 43 | l[1] = l[1] + contract + "\n" + " " * (len(contract) - len(contract.lstrip()) - 1) 44 | return sep.join(l) 45 | elif contract_type == "code": 46 | # at the beginning of the function 47 | contract = "\n".join([x.split("#")[0] for x in contract.splitlines()]) 48 | return prompt + contract 49 | 50 | 51 | def code_generate(args, workdir: PathLike, model: DecoderBase, id_range=None): 52 | with Progress( 53 
| TextColumn(f"{args.dataset} •" + "[progress.percentage]{task.percentage:>3.0f}%"), 54 | BarColumn(), 55 | MofNCompleteColumn(), 56 | TextColumn("•"), 57 | TimeElapsedColumn(), 58 | ) as p: 59 | if args.dataset == "humaneval": 60 | from evalplus.data import get_human_eval_plus 61 | dataset = get_human_eval_plus() 62 | elif args.dataset == "mbpp": 63 | from evalplus.data import get_mbpp_plus 64 | dataset = get_mbpp_plus() 65 | 66 | for task_id, task in p.track(dataset.items()): 67 | if id_range is not None: 68 | id_num = int(task_id.split("/")[1]) 69 | low, high = id_range 70 | if id_num < low or id_num >= high: 71 | p.console.print(f"Skipping {task_id} as it is not in {id_range}") 72 | continue 73 | 74 | p_name = task_id.replace("/", "_") 75 | if args.contract_type != "none" and task["contract"] == "": 76 | continue 77 | os.makedirs(os.path.join(workdir, p_name), exist_ok=True) 78 | log = f"Codegen: {p_name} @ {model}" 79 | n_existing = 0 80 | if args.resume: 81 | # count existing .py files 82 | n_existing = len([f for f in os.listdir(os.path.join(workdir, p_name)) if f.endswith(".py")]) 83 | if n_existing > 0: 84 | log += f" (resuming from {n_existing})" 85 | 86 | nsamples = args.n_samples - n_existing 87 | p.console.print(log) 88 | 89 | sidx = args.n_samples - nsamples 90 | while sidx < args.n_samples: 91 | model.dataset = args.dataset 92 | outputs = model.codegen( 93 | construct_contract_prompt(task["prompt"], args.contract_type, task["contract"]).strip(), 94 | do_sample=not args.greedy, 95 | num_samples=args.n_samples - sidx, 96 | ) 97 | assert outputs, "No outputs from model!" 98 | for impl in outputs: 99 | if "```" in impl: 100 | impl = impl.split("```")[0] 101 | print("``` exist in generation. Please check the generation results.") 102 | 103 | try: 104 | with open( 105 | os.path.join(workdir, p_name, f"{sidx}.py"), 106 | "w", 107 | encoding="utf-8", 108 | ) as f: 109 | if model.direct_completion: 110 | f.write(task["prompt"] + impl) 111 | else: 112 | f.write(impl) 113 | except UnicodeEncodeError: 114 | continue 115 | sidx += 1 116 | 117 | 118 | def main(): 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument("--model_type", required=True, type=str, choices=MODEL_MAPPING.keys()) 121 | parser.add_argument("--model_path", type=str, default=None) 122 | parser.add_argument("--model_size", required=True, type=str) 123 | parser.add_argument("--bs", default=1, type=int) 124 | parser.add_argument("--temperature", default=0.0, type=float) 125 | parser.add_argument("--dataset", required=True, type=str, choices=["humaneval", "mbpp"]) 126 | parser.add_argument("--root", type=str, required=True) 127 | parser.add_argument("--n_samples", default=1, type=int) 128 | parser.add_argument("--resume", action="store_true") 129 | parser.add_argument("--output", type=str) 130 | parser.add_argument("--tensor-parallel-size", default=1, type=int) 131 | parser.add_argument( 132 | "--contract-type", 133 | default="none", 134 | type=str, 135 | choices=["none", "code", "docstring"], 136 | ) 137 | parser.add_argument("--greedy", action="store_true") 138 | # id_range is list 139 | parser.add_argument("--id-range", default=None, nargs="+", type=int) 140 | args = parser.parse_args() 141 | print(args) 142 | assert args.model_size in MODEL_MAPPING[args.model_type] 143 | 144 | model_path = MODEL_MAPPING[args.model_type][args.model_size] 145 | if args.model_path is not None: 146 | model_path = args.model_path 147 | print(f"Loading model from {model_path}") 148 | 149 | print(f"Running 
model={args.model_type}, size={args.model_size}")
150 | print(f"\tLoad from `{model_path}`")
151 | 
152 | if args.greedy and (args.temperature != 0 or args.bs != 1 or args.n_samples != 1):
153 | args.temperature = 0
154 | args.bs = 1
155 | args.n_samples = 1
156 | print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")
157 | 
158 | if args.id_range is not None:
159 | assert len(args.id_range) == 2, "id_range must be a list of length 2"
160 | assert args.id_range[0] < args.id_range[1], "id_range must be increasing"
161 | args.id_range = tuple(args.id_range)
162 | 
163 | # Make project dir
164 | os.makedirs(args.root, exist_ok=True)
165 | # Make dataset dir
166 | os.makedirs(os.path.join(args.root, args.dataset), exist_ok=True)
167 | # Make dir for codes generated by each model
168 | 
169 | model = make_model(
170 | model_type=args.model_type,
171 | model_size=args.model_size,
172 | model_path=model_path,
173 | batch_size=args.bs,
174 | temperature=args.temperature,
175 | dataset=args.dataset,
176 | tensor_parallel_size=args.tensor_parallel_size
177 | )
178 | workdir = os.path.join(
179 | args.root,
180 | args.dataset,
181 | args.model_type
182 | + f"_{args.model_size}"
183 | + f"_temp_{args.temperature}"
184 | + ("" if args.contract_type == "none" else f"-contract-{args.contract_type}"),
185 | )
186 | os.makedirs(workdir, exist_ok=True)
187 | print(f"Working dir: {workdir}")
188 | 
189 | with open(os.path.join(workdir, "args.txt"), "w") as f:
190 | f.write(str(args))
191 | 
192 | print(f"Model cls: {model.__class__}")
193 | print(f"EOS tokens: {model.eos}")
194 | code_generate(args, workdir=workdir, model=model, id_range=args.id_range)
195 | 
196 | 
197 | if __name__ == "__main__":
198 | main()
199 | 
-------------------------------------------------------------------------------- /eval_plus/model.py: --------------------------------------------------------------------------------
1 | import os
2 | from abc import ABC, abstractmethod
3 | from typing import List
4 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
5 | os.environ["HF_HOME"] = os.environ.get("HF_HOME", "./hf_home")
6 | 
7 | import torch
8 | from stop_sequencer import StopSequencer
9 | from transformers import AutoModelForCausalLM, AutoTokenizer
10 | from vllm import LLM, SamplingParams
11 | 
12 | 
13 | EOS = [
14 | "<|endoftext|>",
15 | "<|endofmask|>",
16 | "</s>",
17 | "\nif __name__",
18 | "\ndef main(",
19 | "\nprint(",
20 | "\n#"
21 | ]
22 | 
23 | 
24 | class DecoderBase(ABC):
25 | def __init__(
26 | self,
27 | name: str,
28 | batch_size: int = 1,
29 | temperature: float = 0.8,
30 | max_new_tokens: int = 512,
31 | direct_completion: bool = True,
32 | dtype: str = "bfloat16", # default
33 | trust_remote_code: bool = False,
34 | dataset: str = None,
35 | ) -> None:
36 | print("Initializing a decoder model: {} ...".format(name))
37 | self.name = name
38 | self.batch_size = batch_size
39 | self.temperature = temperature
40 | self.eos = EOS
41 | self.skip_special_tokens = False
42 | self.max_new_tokens = max_new_tokens
43 | self.direct_completion = direct_completion
44 | self.dtype = dtype
45 | self.trust_remote_code = trust_remote_code
46 | 
47 | if direct_completion:
48 | if dataset.lower() == "humaneval":
49 | self.eos += ["\ndef", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]
50 | elif dataset.lower() == "mbpp":
51 | self.eos += ['\n"""', "\nassert"]
52 | 
53 | @abstractmethod
54 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
55 | pass
56 | 
57 | def __repr__(self)
-> str: 58 | return self.name 59 | 60 | def __str__(self) -> str: 61 | return self.name 62 | 63 | 64 | class VLlmDecoder(DecoderBase): 65 | def __init__(self, name: str, tensor_parallel_size = 1, **kwargs) -> None: 66 | super().__init__(name, **kwargs) 67 | 68 | kwargs = { 69 | "tensor_parallel_size": tensor_parallel_size, #int(os.getenv("VLLM_N_GPUS", "1")) 70 | "dtype": self.dtype, 71 | "trust_remote_code": self.trust_remote_code, 72 | "enforce_eager": True, 73 | "gpu_memory_utilization": 0.7 74 | } 75 | print(kwargs) 76 | self.llm = LLM(model=name, max_model_len=1536, **kwargs) 77 | 78 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]: 79 | if do_sample: 80 | assert self.temperature > 0, "Temperature must be greater than 0!" 81 | batch_size = min(self.batch_size, num_samples) 82 | 83 | vllm_outputs = self.llm.generate( 84 | [prompt] * batch_size, 85 | SamplingParams( 86 | temperature=self.temperature, 87 | max_tokens=self.max_new_tokens, 88 | top_p=0.95 if do_sample else 1.0, 89 | stop=self.eos, 90 | ), 91 | use_tqdm=False, 92 | ) 93 | 94 | gen_strs = [x.outputs[0].text.replace("\t", " ") for x in vllm_outputs] 95 | return gen_strs 96 | 97 | 98 | class VLlmAWQDecoder(DecoderBase): 99 | def __init__(self, name: str, **kwargs) -> None: 100 | super().__init__(name, **kwargs) 101 | 102 | kwargs = { 103 | "tensor_parallel_size": int(os.getenv("VLLM_N_GPUS", "1")), 104 | "dtype": torch.float16, 105 | "trust_remote_code": self.trust_remote_code, 106 | "quantization": "AWQ", 107 | } 108 | 109 | self.llm = LLM(model=name, max_model_len=2048, **kwargs) 110 | 111 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]: 112 | if do_sample: 113 | assert self.temperature > 0, "Temperature must be greater than 0!" 114 | batch_size = min(self.batch_size, num_samples) 115 | 116 | vllm_outputs = self.llm.generate( 117 | [prompt] * batch_size, 118 | SamplingParams( 119 | temperature=self.temperature, 120 | max_tokens=self.max_new_tokens, 121 | top_p=0.95 if do_sample else 1.0, 122 | stop=self.eos, 123 | ), 124 | use_tqdm=False, 125 | ) 126 | 127 | gen_strs = [x.outputs[0].text.replace("\t", " ") for x in vllm_outputs] 128 | return gen_strs 129 | 130 | 131 | class AWQChatML(VLlmAWQDecoder): 132 | def __init__(self, name: str, tensor_parallel_size, **kwargs) -> None: 133 | kwargs["direct_completion"] = False 134 | super().__init__(name, **kwargs) 135 | self.eos += ["\n```"] 136 | 137 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]: 138 | if do_sample: 139 | assert self.temperature > 0, "Temperature must be greater than 0!" 140 | 141 | input = f"""<|im_start|>system 142 | You are an intelligent programming assistant to produce Python algorithmic solutions<|im_end|> 143 | <|im_start|>user 144 | Can you complete the following Python function? 145 | ```python 146 | {prompt} 147 | ``` 148 | <|im_end|> 149 | <|im_start|>assistant 150 | ```python 151 | """ 152 | return VLlmDecoder.codegen(self, input, do_sample, num_samples) 153 | 154 | 155 | class ChatML(VLlmDecoder): 156 | def __init__(self, name: str, tensor_parallel_size, **kwargs) -> None: 157 | kwargs["direct_completion"] = False 158 | super().__init__(name, tensor_parallel_size, **kwargs) 159 | self.eos += ["\n```"] 160 | 161 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]: 162 | if do_sample: 163 | assert self.temperature > 0, "Temperature must be greater than 0!" 
164 | 
165 | input = f"""<|im_start|>system
166 | You are an intelligent programming assistant to produce Python algorithmic solutions<|im_end|>
167 | <|im_start|>user
168 | Can you complete the following Python function?
169 | ```python
170 | {prompt}
171 | ```
172 | <|im_end|>
173 | <|im_start|>assistant
174 | ```python
175 | """
176 | return VLlmDecoder.codegen(self, input, do_sample, num_samples)
177 | 
178 | 
179 | def make_model(
180 | model_type: str,
181 | model_size: str,
182 | model_path: str,
183 | batch_size: int = 1,
184 | temperature: float = 0.8,
185 | dataset: str = None,
186 | tensor_parallel_size=1
187 | ):
188 | if model_type == "codeqwen" or model_type == "qwen2":
189 | if "chat" in model_size.lower():
190 | if "awq" in model_size.lower():
191 | return AWQChatML(
192 | batch_size=batch_size,
193 | name=model_path,
194 | temperature=temperature,
195 | max_new_tokens=2048,
196 | tensor_parallel_size=tensor_parallel_size
197 | )
198 | else:
199 | return ChatML(
200 | batch_size=batch_size,
201 | name=model_path,
202 | temperature=temperature,
203 | max_new_tokens=2048,
204 | tensor_parallel_size=tensor_parallel_size
205 | )
206 | else:
207 | return VLlmDecoder(
208 | batch_size=batch_size,
209 | name=model_path,
210 | temperature=temperature,
211 | dataset=dataset,
212 | tensor_parallel_size=tensor_parallel_size
213 | )
214 | else:
215 | raise ValueError(f"Invalid model name: {model_type}@{model_size}")
216 | 
-------------------------------------------------------------------------------- /eval_plus/readme.md: --------------------------------------------------------------------------------
1 | Sourced from the [Qwen2.5-Coder repository](https://github.com/QwenLM/Qwen2.5-Coder/tree/main/qwencoder-eval/instruct/eval_plus) with updated dependencies for better reproducibility.
2 | 
3 | ## Evaluation for HumanEval(+) and MBPP(+)
4 | 
5 | This folder contains the code and scripts to evaluate the performance of the **QwenCoder-2.5** series models on [**EvalPlus**](https://github.com/evalplus/evalplus) benchmark, which includes HumanEval(+) and MBPP(+) datasets. These datasets are designed to test code generation capabilities under varied conditions.
6 | 
7 | 
8 | ### 1. Setup
9 | 
10 | Please refer to [**EvalPlus**](https://github.com/evalplus/evalplus) for detailed setup instructions. Install the required packages using:
11 | 
12 | ```bash
13 | pip install evalplus --upgrade
14 | pip install -r requirements.txt
15 | ```
16 | 
17 | ### 2. Inference and Evaluation
18 | 
19 | We utilize 8xA100 GPUs for this benchmark.
The following scripts are used to run the inference and evaluations: 20 | 21 | ```bash 22 | bash test.sh {path_to_your_local_model_checkpoint} {tensor_parallel_size} {output_dir} 23 | ``` 24 | -------------------------------------------------------------------------------- /eval_plus/requirements.txt: -------------------------------------------------------------------------------- 1 | datamodel_code_generator 2 | anthropic 3 | mistralai 4 | google-generativeai 5 | rich 6 | accelerate 7 | vllm==0.6.6 8 | stop-sequencer 9 | evalplus 10 | setuptools 11 | scipy 12 | hf_transfer -------------------------------------------------------------------------------- /eval_plus/test.sh: -------------------------------------------------------------------------------- 1 | mkdir -p results/humaneval 2 | # export HF_ENDPOINT=https://hf-mirror.com 3 | export PATH=./vllm/bin:$PATH 4 | export PYTHONPATH=$PYTHONPATH:./eval_plus/evalplus 5 | MODEL_DIR=${1} 6 | # uv pip install datamodel_code_generator anthropic mistralai google-generativeai 7 | 8 | 9 | MODEL_DIR=${MODEL_DIR:-"/path/to/pretrained_models/"} 10 | TP=${2} 11 | TP=${TP:-1} 12 | OUTPUT_DIR=${3} 13 | OUTPUT_DIR=${OUTPUT_DIR:-"/path/to/"} 14 | mkdir -p ${OUTPUT_DIR} 15 | 16 | echo "EvalPlus: ${MODEL_DIR}, OUTPUT_DIR ${OUTPUT_DIR}" 17 | 18 | python generate.py \ 19 | --model_type qwen2 \ 20 | --model_size chat \ 21 | --model_path ${MODEL_DIR} \ 22 | --bs 1 \ 23 | --temperature 0 \ 24 | --n_samples 1 \ 25 | --greedy \ 26 | --root ${OUTPUT_DIR} \ 27 | --dataset humaneval \ 28 | --tensor-parallel-size ${TP} 29 | 30 | echo "Generated samples: ${OUTPUT_DIR}/humaneval/qwen2_chat_temp_0.0" 31 | 32 | echo "Sanitizing samples" 33 | python -m evalplus.sanitize --samples ${OUTPUT_DIR}/humaneval/qwen2_chat_temp_0.0 34 | 35 | echo "Evaluating humaneval raw" 36 | evalplus.evaluate \ 37 | --dataset humaneval \ 38 | --samples ${OUTPUT_DIR}/humaneval/qwen2_chat_temp_0.0 > ${OUTPUT_DIR}/raw_humaneval_results.txt 39 | echo "Finished evaluating humaneval file at ${OUTPUT_DIR}/raw_humaneval_results.txt" 40 | echo "Evaluating humaneval sanitized" 41 | evalplus.evaluate \ 42 | --dataset humaneval \ 43 | --samples ${OUTPUT_DIR}/humaneval/qwen2_chat_temp_0.0-sanitized > ${OUTPUT_DIR}/humaneval_results.txt 44 | echo "Finished evaluating humaneval file at ${OUTPUT_DIR}/humaneval_results.txt" 45 | 46 | python generate.py \ 47 | --model_type qwen2 \ 48 | --model_size chat \ 49 | --model_path ${MODEL_DIR} \ 50 | --bs 1 \ 51 | --temperature 0 \ 52 | --n_samples 1 \ 53 | --greedy \ 54 | --root ${OUTPUT_DIR} \ 55 | --dataset mbpp \ 56 | --tensor-parallel-size ${TP} 57 | 58 | echo "Sanitizing mbpp" 59 | python -m evalplus.sanitize --samples ${OUTPUT_DIR}/mbpp/qwen2_chat_temp_0.0 60 | 61 | echo "Evaluating mbpp raw" 62 | evalplus.evaluate \ 63 | --dataset mbpp \ 64 | --samples ${OUTPUT_DIR}/mbpp/qwen2_chat_temp_0.0 > ${OUTPUT_DIR}/raw_mbpp_results.txt 65 | echo "Finished evaluating mbpp file at ${OUTPUT_DIR}/raw_mbpp_results.txt" 66 | echo "Evaluating mbpp sanitized" 67 | evalplus.evaluate \ 68 | --dataset mbpp \ 69 | --samples ${OUTPUT_DIR}/mbpp/qwen2_chat_temp_0.0-sanitized > ${OUTPUT_DIR}/mbpp_results.txt 70 | echo "Finished evaluating mbpp file at ${OUTPUT_DIR}/mbpp_results.txt" -------------------------------------------------------------------------------- /grpo_code/__init__.py: -------------------------------------------------------------------------------- 1 | from .rewards import ( 2 | code_execution_reward_func, 3 | answer_execution_reward_func, 4 | soft_format_reward_func, 5 | 
)
6 | from .transforms import axolotl_acecode_transform
7 | 
8 | __all__ = [
9 | "code_execution_reward_func",
10 | "answer_execution_reward_func",
11 | "soft_format_reward_func",
12 | "axolotl_acecode_transform",
13 | ]
14 | 
-------------------------------------------------------------------------------- /grpo_code/executor.py: --------------------------------------------------------------------------------
1 | from concurrent.futures import ProcessPoolExecutor
2 | 
3 | from grpo_code.wasm import does_code_run, PythonWasmEnvironment
4 | 
5 | _executor, worker_env = None, None
6 | 
7 | 
8 | def get_executor(
9 | max_processes: int, wasm_path: str, fuel: int
10 | ) -> None | ProcessPoolExecutor:
11 | """
12 | Get the executor for the given number of processes and WASM environment.
13 | For parallel execution, we use a multiprocessing executor, where a pool
14 | of workers, each with their own WASM environment, execute the tasks in parallel.
15 | 
16 | For single process execution, we instead initialize a global WASM environment.
17 | 
18 | Args:
19 | max_processes (int): The maximum number of processes to use.
20 | wasm_path (str): The path to the .wasm file.
21 | fuel (int): The amount of fuel to use for the WASM environment.
22 | 
23 | Returns:
24 | executor (None | ProcessPoolExecutor): If parallel execution is requested,
25 | we return a ProcessPoolExecutor, otherwise we return None.
26 | """
27 | global _executor, worker_env
28 | if max_processes > 1:
29 | from grpo_code.parallel_executor import get_multiprocessing_executor
30 | 
31 | _executor = get_multiprocessing_executor(max_processes, wasm_path, fuel)
32 | return _executor
33 | if worker_env is None:
34 | import grpo_code.wasm as wasm
35 | 
36 | # Cache the environment both in this module and on the wasm module
37 | # (where does_code_run looks it up), so it is only created once.
38 | worker_env = PythonWasmEnvironment(wasm_path, fuel)
39 | wasm.worker_env = worker_env
40 | 
41 | 
42 | def execute_tasks(
43 | tasks: list[str], max_processes: int, wasm_path: str, fuel: int, task_timeout: int
44 | ):
45 | """
46 | Run a list of code snippets in a WASM environment.
47 | 
48 | Args:
49 | tasks (list[str]): The list of code snippets to run.
50 | max_processes (int): The maximum number of processes to use.
51 | wasm_path (str): The path to the .wasm file.
52 | fuel (int): The amount of fuel to use for the WASM environment.
53 | task_timeout (int): If using multiprocessing, the timeout for each task.
54 | 
55 | Returns:
56 | list[bool]: The list of results from running the code snippets.
57 | """
58 | executor = get_executor(max_processes, wasm_path, fuel)
59 | if max_processes > 1:
60 | from grpo_code.parallel_executor import run_tasks_with_multiprocessing_executor
61 | 
62 | return run_tasks_with_multiprocessing_executor(executor, tasks, task_timeout)
63 | else:
64 | return list(map(does_code_run, tasks))
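
For orientation, a minimal usage sketch of `execute_tasks` (the snippet strings and fuel value are illustrative, and it assumes the WASM runtime from the README has been downloaded to `./wasm/`):

```python
# Hypothetical driver for grpo_code.executor.execute_tasks.
from grpo_code.executor import execute_tasks

snippets = [
    "assert sum(range(5)) == 10",  # executes cleanly
    "raise ValueError('boom')",    # raises, so it counts as a failure
]
results = execute_tasks(
    snippets,
    max_processes=1,   # single-process path: a global WASM environment is reused
    wasm_path="./wasm/python-3.12.0.wasm",
    fuel=1_000_000_000,
    task_timeout=1,    # only enforced on the multiprocessing path
)
print(results)  # one truthy result per snippet that ran to completion, e.g. [1.0, 0.0]
```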
-------------------------------------------------------------------------------- /grpo_code/parallel_executor.py: --------------------------------------------------------------------------------
1 | import atexit
2 | import signal
3 | import sys
4 | from concurrent.futures import FIRST_EXCEPTION, ProcessPoolExecutor, wait
5 | 
6 | from grpo_code.wasm import does_code_run, PythonWasmEnvironment
7 | 
8 | _executor = None
9 | 
10 | 
11 | def worker_init(wasm_path: str, fuel: int):
12 | """
13 | Initialize a WASM environment for a worker process.
14 | 
15 | Args:
16 | wasm_path (str): The path to the .wasm file.
17 | fuel (int): The amount of fuel available to the WASM environment.
18 | """
19 | import grpo_code.wasm as wasm
20 | 
21 | wasm.worker_env = PythonWasmEnvironment(wasm_path, fuel)
22 | 
23 | 
24 | def cleanup_executor():
25 | """
26 | Cleanup any running `ProcessPoolExecutor`.
27 | """
28 | global _executor
29 | if _executor is not None:
30 | _executor.shutdown(wait=False)
31 | _executor = None
32 | 
33 | 
34 | def cleanup_and_exit(signum=None, frame=None):
35 | """
36 | Gracefully shuts down any running `ProcessPoolExecutor` on
37 | receiving a termination signal.
38 | """
39 | cleanup_executor()
40 | sys.exit(0)
41 | 
42 | 
43 | def get_multiprocessing_executor(max_processes: int, wasm_path: str, fuel: int):
44 | """
45 | Initialize a reusable `ProcessPoolExecutor` instance.
46 | 
47 | Args:
48 | max_processes (int): The maximum number of processes to use.
49 | wasm_path (str): The path to the .wasm file.
50 | fuel (int): The amount of fuel available to the WASM environment.
51 | 
52 | Returns:
53 | ProcessPoolExecutor: A `ProcessPoolExecutor` for parallel execution.
54 | """
55 | global _executor
56 | if _executor is None:
57 | _executor = ProcessPoolExecutor(
58 | max_workers=max_processes,
59 | initializer=worker_init,
60 | initargs=(wasm_path, fuel),
61 | )
62 | atexit.register(cleanup_executor)
63 | signal.signal(signal.SIGINT, cleanup_and_exit)
64 | signal.signal(signal.SIGTERM, cleanup_and_exit)
65 | return _executor
66 | 
67 | 
68 | def run_tasks_with_multiprocessing_executor(
69 | executor: ProcessPoolExecutor, tasks: list[str], timeout: int
70 | ):
71 | """
72 | Run a list of code snippets in parallel by dispatching them to workers
73 | in a `ProcessPoolExecutor`.
74 | 
75 | This function will gracefully handle timeout errors caused by workers, and
76 | recreate the `ProcessPoolExecutor` if necessary to avoid interrupting training.
77 | 
78 | Args:
79 | executor (ProcessPoolExecutor): The executor to use.
80 | tasks (list[str]): The list of code snippets to run.
81 | timeout (int): The timeout for each task.
82 | 
83 | Returns:
84 | list[float]: The list of results from running the code snippets.
85 | """
86 | futures_to_index = {
87 | executor.submit(does_code_run, task): i for i, task in enumerate(tasks)
88 | }
89 | futures = list(futures_to_index)
90 | results = [0.0] * len(tasks)
91 | 
92 | while futures:
93 | done, futures = wait(futures, timeout=timeout, return_when=FIRST_EXCEPTION)
94 | 
95 | if futures and len(done) == 0:
96 | print(
97 | f"WARNING: Tasks timed out after {timeout} seconds, recreating process pool..."
98 | )
99 | cleanup_executor()
100 | break
101 | 
102 | for future in done:
103 | task_index = futures_to_index[future]
104 | results[task_index] = future.result()
105 | 
106 | return results
107 | 
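
A sketch of the multiprocessing path in isolation, showing the timeout semantics described above (worker count, paths, and fuel value are illustrative):

```python
# Hypothetical example: timed-out tasks keep their default 0.0 result, and the
# pool is recreated so a single runaway completion cannot stall training.
from grpo_code.parallel_executor import (
    get_multiprocessing_executor,
    run_tasks_with_multiprocessing_executor,
)

executor = get_multiprocessing_executor(
    max_processes=4,
    wasm_path="./wasm/python-3.12.0.wasm",
    fuel=1_000_000_000,
)
tasks = [
    "print('hello')",    # finishes immediately
    "while True: pass",  # spins until the timeout (or its fuel budget) cuts it off
]
print(run_tasks_with_multiprocessing_executor(executor, tasks, timeout=1))
```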
98 | ) 99 | cleanup_executor() 100 | break 101 | 102 | for future in done: 103 | task_index = futures_to_index[future] 104 | results[task_index] = future.result() 105 | 106 | return results 107 | -------------------------------------------------------------------------------- /grpo_code/rewards.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import re 4 | from pathlib import Path 5 | 6 | import grpo_code 7 | from grpo_code.executor import execute_tasks 8 | 9 | WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) 10 | MAX_WORKERS = max(1, int(os.environ.get("MAX_WORKERS", 1)) // WORLD_SIZE) 11 | TASK_TIMEOUT = int(os.environ.get("TASK_TIMEOUT", 1)) 12 | WASM_PATH = os.environ.get( 13 | "WASM_PATH", Path(grpo_code.__file__).parent.parent / "wasm" / "python-3.12.0.wasm" 14 | ) 15 | FUEL = int(os.environ.get("FUEL", 1_000_000_000)) 16 | 17 | if not os.path.exists(WASM_PATH): 18 | raise FileNotFoundError(f"WASM file not found at {WASM_PATH}") 19 | 20 | 21 | def extract_xml_answer(text: str) -> str: 22 | """ 23 | Extract text between <answer> and </answer> tags. 24 | 25 | Args: 26 | text (str): The text to extract the answer from. 27 | Returns: 28 | str: The answer extracted from the text. "" if no answer is found. 29 | 30 | """ 31 | match = re.search(r"<answer>(.*?)</answer>", text, re.S) 32 | return match.group(1).strip() if match else "" 33 | 34 | 35 | def code_execution_reward_func(completions: list[list[dict]], **kwargs) -> list[float]: 36 | """ 37 | Reward function for code execution. 38 | 39 | Args: 40 | completions (list[list[dict]]): The predicted code completions to execute. This takes the format 41 | [ 42 | [ 43 | {"role": "user", "content": "......"} 44 | ] 45 | ] 46 | Returns: 47 | list[float]: The rewards for the completions. Each completion is rewarded 0.5 if the code executes, -0.25 otherwise. 48 | """ 49 | model_answers = [ 50 | extract_xml_answer(completion[0]["content"]) for completion in completions 51 | ] 52 | task_results = execute_tasks( 53 | model_answers, MAX_WORKERS, WASM_PATH, FUEL, TASK_TIMEOUT 54 | ) 55 | return [0.5 if result == 1.0 else -0.25 for result in task_results] 56 | 57 | 58 | def answer_execution_reward_func( 59 | completions: list[list[dict]], answers: list[list[str]], **kwargs 60 | ) -> list[float]: 61 | """ 62 | Reward function for answer execution. 63 | 64 | Args: 65 | completions (list[list[dict]]): The predicted code completions to execute. This takes the format 66 | [ 67 | [ 68 | {"role": "user", "content": "......"} 69 | ] 70 | ] 71 | answers (list[list[str]]): The expected answers to the code completions. These take the form of executable 72 | assert statements, e.g. 73 | [ 74 | [ 75 | "assert foo(1) == 2", 76 | "assert foo(2) == 3", 77 | ] 78 | ] 79 | Returns: 80 | list[float]: The accuracy rewards for the completions. Each completion is rewarded 81 | (accuracy)^3 * 2, where accuracy is the proportion of test cases that pass.
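As an illustrative example, a completion passing 3 of 4 asserts has accuracy 0.75 and earns 0.75 ** 3 * 2 = 0.84375; the cubic exponent sharply discounts partially correct answers.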
82 | """ 83 | model_answers = [ 84 | extract_xml_answer(completion[0]["content"]) for completion in completions 85 | ] 86 | tasks = [] 87 | test_indices = [] 88 | for i, (code, tests) in enumerate(zip(model_answers, answers)): 89 | for test in tests: 90 | tasks.append(code + "\n" + test) 91 | test_indices.append(i) 92 | 93 | task_results = execute_tasks(tasks, MAX_WORKERS, WASM_PATH, FUEL, TASK_TIMEOUT) 94 | 95 | completion_results = {} 96 | for idx, result in zip(test_indices, task_results): 97 | if idx not in completion_results: 98 | completion_results[idx] = [] 99 | completion_results[idx].append(result) 100 | 101 | rewards = [] 102 | for i in range(len(completions)): 103 | if i in completion_results: 104 | test_results = completion_results[i] 105 | accuracy = sum(test_results) / len(test_results) 106 | reward = math.pow(accuracy, 3) * 2 107 | else: 108 | reward = 0.0 109 | rewards.append(reward) 110 | return rewards 111 | 112 | 113 | def soft_format_reward_func(completions, **kwargs) -> list[float]: 114 | """ 115 | Reward function for soft format checking. 116 | 117 | Args: 118 | completions (list[list[dict]]): The predicted code completions to execute. This takes the format 119 | [ 120 | [ 121 | {"role": "user", "content": content} 122 | ] 123 | ] 124 | Returns: 125 | list[float]: The rewards for the completions. Each completion is rewarded 0.25 if the format is correct, 0.0 otherwise. 126 | """ 127 | 128 | responses = [completion[0]["content"] for completion in completions] 129 | rewards = [] 130 | for response in responses: 131 | if re.match( 132 | r"<think>.*?</think>\s*<answer>.*?</answer>.*", response, re.S 133 | ): 134 | rewards.append(0.25) 135 | else: 136 | rewards.append(0.0) 137 | return rewards 138 | -------------------------------------------------------------------------------- /grpo_code/transforms.py: -------------------------------------------------------------------------------- 1 | SYSTEM_PROMPT = """ 2 | Respond in the following format: 3 | <think> 4 | ... 5 | </think> 6 | <answer> 7 | ... 8 | </answer> 9 | 10 | Additionally, you may optionally use the following imports: 11 | 12 | import time 13 | import itertools 14 | from itertools import accumulate, product, permutations, combinations 15 | import collections 16 | from collections import Counter, OrderedDict, deque, defaultdict, ChainMap 17 | from functools import lru_cache 18 | import math 19 | from typing import List, Dict, Tuple, Optional, Any 20 | 21 | If you choose to use any of these imports, ensure they are included 22 | inside the <answer> tags, e.g.: 23 | 24 | <answer> 25 | import time 26 | import math 27 | ... 28 | </answer> 29 | 30 | You may not utilise any other imports or filesystem operations. 31 | """ 32 | 33 | 34 | def axolotl_acecode_transform(cfg, *args, **kwargs): 35 | def transform_fn(example, tokenizer=None): 36 | return { 37 | "prompt": [ 38 | { 39 | "role": "user", 40 | "content": example["question"] + "\n\n" + SYSTEM_PROMPT, 41 | } 42 | ], 43 | "answers": example["test_cases"], 44 | } 45 | 46 | return transform_fn, {"remove_columns": ["question", "test_cases"]} 47 | -------------------------------------------------------------------------------- /grpo_code/wasm.py: -------------------------------------------------------------------------------- 1 | from wasmtime import Config, Engine, Linker, Module, Store, WasiConfig 2 | 3 | worker_env = None 4 | 5 | 6 | class PythonWasmEnvironment: 7 | """A reusable WASM environment for running Python code. 8 | 9 | Args: 10 | wasm_path (str): The path to the .wasm file. 11 | fuel (int): The amount of fuel to use for the WASM environment.
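Example (an illustrative sketch, using this repo's bundled interpreter and its default fuel budget): env = PythonWasmEnvironment("wasm/python-3.12.0.wasm", 1_000_000_000); env.run_code("print('hello')")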
12 | """ 13 | 14 | def __init__(self, wasm_path: str, fuel: int): 15 | self.wasm_path = wasm_path 16 | self.fuel = fuel 17 | 18 | # Set up the engine and linker 19 | engine_cfg = Config() 20 | engine_cfg.consume_fuel = True 21 | engine_cfg.cache = True 22 | 23 | self.engine = Engine(engine_cfg) 24 | self.linker = Linker(self.engine) 25 | self.linker.define_wasi() 26 | 27 | # Load the Python module 28 | self.python_module = Module.from_file(self.engine, self.wasm_path) 29 | 30 | def run_code(self, code: str): 31 | """Run Python code in the WASM environment subject to fuel limits. 32 | 33 | Args: 34 | code (str): The Python code to run. 35 | 36 | """ 37 | config = WasiConfig() 38 | config.argv = ("python", "-c", code) 39 | config.inherit_env = False 40 | 41 | store = Store(self.engine) 42 | store.set_fuel(self.fuel) 43 | store.set_wasi(config) 44 | 45 | instance = self.linker.instantiate(store, self.python_module) 46 | start = instance.exports(store)["_start"] 47 | start(store) 48 | 49 | 50 | def does_code_run(code: str) -> bool: 51 | """Execute code in the worker's WASM environment and check if it runs without errors. 52 | 53 | Args: 54 | code (str): The Python code to run. 55 | 56 | Returns: 57 | bool: True if the code runs without errors, False otherwise. 58 | """ 59 | global worker_env 60 | try: 61 | worker_env.run_code(code) 62 | return True 63 | except Exception: 64 | return False 65 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "grpo_code" 3 | dependencies = [ 4 | "wasmtime", 5 | ] 6 | version = "0.1.0" 7 | 8 | [build-system] 9 | requires = ["setuptools>=64", "wheel"] 10 | build-backend = "setuptools.build_meta" 11 | 12 | 13 | [tool.setuptools.packages.find] 14 | where = [""] 15 | include = ["grpo_code*"] 16 | -------------------------------------------------------------------------------- /r1_acecode.yaml: -------------------------------------------------------------------------------- 1 | base_model: Qwen/Qwen2.5-3B-Instruct 2 | # Automatically upload checkpoint and final model to HF 3 | # hub_model_id: username/custom_model_name 4 | 5 | load_in_8bit: false 6 | load_in_4bit: false 7 | strict: false 8 | 9 | torch_compile: true 10 | 11 | vllm: 12 | host: 0.0.0.0 13 | port: 8000 14 | tensor_parallel_size: 2 15 | gpu_memory_utilization: 0.85 16 | dtype: auto 17 | 18 | rl: grpo 19 | trl: 20 | beta: 0.001 21 | use_vllm: true 22 | vllm_server_host: 0.0.0.0 23 | vllm_server_port: 8000 24 | vllm_server_timeout: 300 25 | reward_funcs: 26 | - grpo_code.soft_format_reward_func 27 | - grpo_code.code_execution_reward_func 28 | - grpo_code.answer_execution_reward_func 29 | 30 | num_generations: 16 31 | max_completion_length: 512 32 | log_completions: false 33 | 34 | chat_template: qwen_25 35 | datasets: 36 | - path: axolotl-ai-co/AceCode-87K 37 | type: grpo_code.axolotl_acecode_transform 38 | split: train 39 | 40 | dataset_prepared_path: /workspace/data/last_run_prepared 41 | dataset_processes: 42 | skip_prepare_dataset: true 43 | val_set_size: 0.0 44 | output_dir: /workspace/data/axolotl-artifacts/r1-outputs/1403 45 | 46 | dataloader_prefetch_factor: 32 47 | dataloader_num_workers: 2 48 | dataloader_pin_memory: true 49 | 50 | gc_steps: 1 51 | sequence_len: 1024 52 | sample_packing: false 53 | eval_sample_packing: false 54 | pad_to_sequence_len: false 55 | 56 | gradient_accumulation_steps: 2 57 | micro_batch_size: 32 58 | num_epochs: 1 59 | 
max_steps: 2500 60 | 61 | 62 | optimizer: adamw_torch_fused 63 | lr_scheduler: warmup_stable_decay 64 | lr_scheduler_kwargs: 65 | num_stable_steps: 1500 66 | num_decay_steps: 500 67 | min_lr_ratio: 0.1 68 | num_cycles: 0.5 69 | 70 | learning_rate: 5.3e-6 71 | max_grad_norm: 1.0 72 | 73 | train_on_inputs: false 74 | group_by_length: false 75 | 76 | bf16: true 77 | tf32: true 78 | early_stopping_patience: 79 | resume_from_checkpoint: 80 | local_rank: 81 | logging_steps: 1 82 | gradient_checkpointing: true 83 | gradient_checkpointing_kwargs: 84 | use_reentrant: false 85 | flash_attention: true 86 | 87 | warmup_steps: 500 88 | evals_per_epoch: 0 89 | saves_per_epoch: 0 90 | save_steps: 0.5 91 | 92 | 93 | # wandb_project: 94 | # wandb_entity: 95 | # wandb_name: passk 96 | # hub_model_id: 97 | -------------------------------------------------------------------------------- /wasm/python-3.12.0.wasm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axolotl-ai-cloud/grpo_code/148ea79321f34bbed79b3b55f04c0a7de002665d/wasm/python-3.12.0.wasm -------------------------------------------------------------------------------- /wasm/python-3.12.0.wasm.sha256sum: -------------------------------------------------------------------------------- 1 | e5dc5a398b07b54ea8fdb503bf68fb583d533f10ec3f930963e02b9505f7a763 wasm/python-3.12.0.wasm 2 | --------------------------------------------------------------------------------
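# Illustrative note (assumes GNU coreutils is available): the bundled interpreter can be verified against the checksum above by running, from the repository root: sha256sum -c wasm/python-3.12.0.wasm.sha256sum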