├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── eval_plus
│   ├── convert_data.py
│   ├── data
│   │   ├── HumanEval.jsonl
│   │   ├── HumanEvalPlus-v0.1.10.jsonl
│   │   ├── HumanEvalPlus-v0.1.9.jsonl
│   │   └── MbppPlus-v0.1.0.jsonl
│   ├── exclude_patterns.txt
│   ├── generate.py
│   ├── model.py
│   ├── readme.md
│   ├── requirements.txt
│   └── test.sh
├── grpo_code
│   ├── __init__.py
│   ├── executor.py
│   ├── parallel_executor.py
│   ├── rewards.py
│   ├── transforms.py
│   └── wasm.py
├── pyproject.toml
├── r1_acecode.yaml
└── wasm
    ├── python-3.12.0.wasm
    └── python-3.12.0.wasm.sha256sum
/.gitignore:
--------------------------------------------------------------------------------
1 | **/axolotl.egg-info
2 | configs
3 | last_run_prepared/
4 | outputs
5 | .vscode
6 | _site/
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/#use-with-ide
117 | .pdm.toml
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 | venv3.10/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | .idea/
169 |
170 | # WandB
171 | # wandb creates a folder to store logs for training runs
172 | wandb
173 |
174 | # Runs
175 | lora-out/*
176 | qlora-out/*
177 | mlruns/*
178 |
179 | /.quarto/
180 | prepared-datasets/
181 | submit.sh
182 | *.out*
183 |
184 | typings/
185 | out/
186 |
187 | # vim
188 | *.swp
189 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axolotl-ai-cloud/grpo_code/148ea79321f34bbed79b3b55f04c0a7de002665d/CONTRIBUTING.md
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | > [!NOTE]
3 | > Check out our [blog post](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) for more details and benchmarks!
4 |
5 | ## Installation
6 |
7 | ```bash
8 | git clone https://github.com/axolotl-ai-cloud/grpo_code.git
9 | cd grpo_code
10 | pip install -e .
11 | pip install "axolotl[vllm,flash-attn]==0.8.0"
12 | ```
13 |
14 | ## Training
15 |
16 | The following environment variables can be used to modify the behaviour of the reward functions (as read in `grpo_code/rewards.py`; an example follows the list):
17 | - `FUEL` - Controls the amount of fuel (the computation budget) allocated to the WASM environment (default: 1000000000)
18 | - `WASM_PATH` - Path to the Python WASM runtime file (default: "./wasm/python-3.12.0.wasm")
19 | - `TASK_TIMEOUT` - Maximum execution time in seconds for code evaluation (default: 1)
20 | - `MAX_WORKERS` - Number of parallel workers for multiprocessing reward functions, divided evenly across training processes (default: 1)
21 |
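For example, to allow each snippet two seconds to run and to use 32 workers (illustrative values):

```bash
export FUEL=1000000000
export TASK_TIMEOUT=2
export MAX_WORKERS=32
```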
22 | First, spin up a `vLLM` instance:
23 |
24 | ```bash
25 | CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve r1_acecode.yaml
26 | ```
27 |
28 | Then, in another terminal, kick off the training process:
29 |
30 | ```bash
31 | CUDA_VISIBLE_DEVICES=0,1 MAX_WORKERS=64 axolotl train r1_acecode.yaml --num-processes 2
32 | ```
33 |
34 | This example uses 4 A100 GPUs; adjust `CUDA_VISIBLE_DEVICES`, `MAX_WORKERS`, `cfg.micro_batch_size`, and `cfg.gradient_accumulation_steps` as necessary to match your hardware.
35 |
36 | ## Python WASM Runtime
37 |
38 | This project uses Python 3.12.0 compiled to WebAssembly from VMware Labs.
39 |
40 | ### Verify an Existing Download
41 | If you already have the WASM file and want to verify its integrity:
42 |
43 | 1. Ensure you have both `python-3.12.0.wasm` and `python-3.12.0.wasm.sha256sum` in the `wasm` directory.
44 | 2. Run the verification command:
45 |
46 | **Linux/macOS:**
47 | ```bash
48 | sha256sum -c ./wasm/python-3.12.0.wasm.sha256sum
49 | ```
50 |
51 | ### Manual Download
52 | To download the runtime files yourself:
53 |
54 | 1. Download the Python WASM runtime:
55 | ```bash
56 | curl -L https://github.com/vmware-labs/webassembly-language-runtimes/releases/download/python%2F3.12.0%2B20231211-040d5a6/python-3.12.0.wasm -o ./wasm/python-3.12.0.wasm
57 | ```
58 |
59 | 2. Download the SHA256 checksum file:
60 | ```bash
61 | curl -L https://github.com/vmware-labs/webassembly-language-runtimes/releases/download/python%2F3.12.0%2B20231211-040d5a6/python-3.12.0.wasm.sha256sum -o ./wasm/python-3.12.0.wasm.sha256sum
62 | ```
63 |
64 | 3. Verify the download:
65 | ```bash
66 | sha256sum -c ./wasm/python-3.12.0.wasm.sha256sum
67 | ```
68 |
69 | 4. Place both files in the `wasm` directory, or point `WASM_PATH` at their location.
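Steps 1-3 can be combined into a short sketch (run from the repository root; assumes `curl` and `sha256sum` are available):

```bash
mkdir -p wasm && cd wasm
curl -LO https://github.com/vmware-labs/webassembly-language-runtimes/releases/download/python%2F3.12.0%2B20231211-040d5a6/python-3.12.0.wasm
curl -LO https://github.com/vmware-labs/webassembly-language-runtimes/releases/download/python%2F3.12.0%2B20231211-040d5a6/python-3.12.0.wasm.sha256sum
sha256sum -c python-3.12.0.wasm.sha256sum
```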
70 |
--------------------------------------------------------------------------------
/eval_plus/convert_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jsonlines
3 | import os
4 | import sys
5 | import time  # used by trusted_exec when record_time=True
6 | eval_plus_path = os.path.dirname(os.path.abspath(__file__)) + "/evalplus/"
7 | sys.path = [eval_plus_path] + sys.path
8 | from evalplus.data import get_human_eval_plus, get_mbpp_plus
9 |
10 | MBPP_OUTPUT_SET_EQ_TASKS = [
11 | "similar_elements", # Mbpp/2
12 | "find_char_long", # Mbpp/7
13 | "common_in_nested_lists", # Mbpp/111
14 | "extract_singly", # Mbpp/140
15 | "larg_nnum", # Mbpp/232
16 | "intersection_array", # Mbpp/249
17 | "find_dissimilar", # Mbpp/579
18 | "Diff", # Mbpp/769
19 | ]
20 | MBPP_OUTPUT_NOT_NONE_TASKS = ["check_str", "text_match_three", "text_starta_endb"]
21 |
22 |
23 | def convert_file(root_dir):
24 | import jsonlines
25 | from copy import deepcopy
26 | import sys
27 | import tqdm
28 | import re
29 | sys.set_int_max_str_digits(10000000)
30 |
31 | def write_jsonl_file(objs, target_path):
32 | os.makedirs(os.path.dirname(target_path), exist_ok=True)
33 | with jsonlines.open(target_path, "w") as w:
34 | for obj in objs:
35 | w.write(obj)
36 | print(f"Successfully saving to {target_path}")
37 |
38 | # def get_humaneval_prompt(doc, language):
39 | # language = language.lower()
40 | # question = doc["prompt"].strip()
41 | # return """
42 | # Please continue to complete the function and return all completed code in a codeblock. Here is the given code to do completion:
43 | # ```{}
44 | # {}
45 | # ```
46 | # """.strip().format(
47 | # language.lower(), question.strip()
48 | # )
49 |
50 | def get_prompt(doc, language):
51 | language = language.lower()
52 | question = doc["prompt"].strip()
53 | return """
54 | Can you complete the following Python function?
55 | ```{}
56 | {}
57 | ```
58 | """.strip().format(language.lower(), question.strip())
59 |
60 | def create_high_accuracy_function(code, entry_point):
61 | high_accuracy = """
62 | from decimal import Decimal, getcontext
63 | from functools import wraps
64 | getcontext().prec = 100
65 |
66 | def convert_to_decimal(value):
67 | if isinstance(value, float):
68 | return Decimal(str(value))
69 | elif isinstance(value, list):
70 | return [convert_to_decimal(item) for item in value]
71 | elif isinstance(value, dict):
72 | return {k: convert_to_decimal(v) for k, v in value.items()}
73 | return value
74 |
75 | def float_to_decimal(func):
76 | @wraps(func)
77 | def wrapper(*args, **kwargs):
78 | new_args = [convert_to_decimal(arg) for arg in args]
79 | new_kwargs = {k: convert_to_decimal(v) for k, v in kwargs.items()}
80 | result = func(*new_args, **new_kwargs)
81 | return result
82 | return wrapper
83 |
84 | def convert_to_float(value):
85 | if isinstance(value, Decimal):
86 | return float(value)
87 | elif isinstance(value, list):
88 | return [convert_to_float(item) for item in value]
89 | elif isinstance(value, dict):
90 | return {k: convert_to_float(v) for k, v in value.items()}
91 | return value
92 |
93 | def decimal_to_float(func):
94 | @wraps(func)
95 | def wrapper(*args, **kwargs):
96 | # Execute the wrapped function
97 | result = func(*args, **kwargs)
98 |
99 | # Convert the result back to float, if necessary
100 | result = convert_to_float(result)
101 | return result
102 | return wrapper
103 | """
104 | """Execute trusted code in place."""
105 | code = high_accuracy + code
106 | code = code.split("\n")
107 | new_code = []
108 | cnt = 0
109 | for c in code:
110 | if re.search(rf"def {entry_point}\(.*?\)", c) is not None:
111 | cnt += 1
112 | new_code.append("@float_to_decimal")
113 | new_code.append("@decimal_to_float")
114 | new_code.append(c)
115 | code = "\n".join(new_code)
116 | return code
117 |
118 | def trusted_exec(code, inputs, entry_point, record_time=False, output_not_none=False):
119 | exec_globals = {}
120 | # if entry_point not in ["triangle_area", "angle_complex", "volume_sphere"]: # avoid special case (a ** b)
121 | # code = create_high_accuracy_function(code, entry_point)
122 | if "**" not in code and entry_point not in ["triangle_area", "angle_complex", "volume_sphere"]:
123 | code = create_high_accuracy_function(code, entry_point)
124 | #print(code)
125 | exec(code, exec_globals)
126 | fn = exec_globals[entry_point]
127 |
128 | rtime = []
129 | ret = []
130 | for inp in inputs:
131 | inp = deepcopy(inp)
132 | if record_time:
133 | start = time.time()
134 | ret.append(fn(*inp))
135 | rtime.append(time.time() - start)
136 | else:
137 | ret.append(fn(*inp))
138 |
139 | if output_not_none:
140 | ret = [i is not None for i in ret]
141 |
142 | if record_time:
143 | return ret, rtime
144 | else:
145 | return ret
146 |
147 |     def convert(objs, test_set="base_input", task_name="evalplus/humaneval"):
148 |
149 | data = []
150 | for obj in tqdm.tqdm(objs):
151 | prompt = get_prompt(obj, language="python")
152 | if test_set == "base_input":
153 | inputs = obj["base_input"]
154 | else:
155 | inputs = obj["base_input"] + obj["plus_input"] if not isinstance(obj["plus_input"], dict) else obj["base_input"]
156 | #outputs = trusted_exec(code = obj["prompt"] + obj["canonical_solution"], inputs = obj["base_input"], entry_point = obj["entry_point"])
157 | #tests = create_check_function(test_cases = inputs, entry_point=obj["entry_point"], outputs = outputs)
158 | outputs = trusted_exec(code=obj["prompt"] + obj["canonical_solution"], inputs=[obj["base_input"][0]], entry_point=obj["entry_point"])
159 | atol = obj["atol"]
160 | if atol == 0:
161 | atol = 1e-6 # enforce atol for float comparison
162 | #```python
163 | if obj["entry_point"] == "find_zero": #humaneval
164 | tests = create_dynamic_check_function_find_zero(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], atol=atol)
165 | elif obj["entry_point"] in MBPP_OUTPUT_NOT_NONE_TASKS:
166 | tests = create_dynamic_check_function(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], check_style="not_none", atol=atol)
167 | elif obj["entry_point"] in MBPP_OUTPUT_SET_EQ_TASKS: # mbpp
168 | tests = create_dynamic_check_function(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], check_style="set", atol=atol)
169 | elif obj["entry_point"] == "are_equivalent": # mbpp
170 | tests = create_dynamic_check_function_are_equivalent(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], atol=atol)
171 | elif obj["entry_point"] == "sum_div": # mbpp
172 | tests = create_dynamic_check_function_sum_div(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], atol=atol)
173 | elif isinstance(outputs[0], float) or (isinstance(outputs[0], list) and len(outputs[0]) > 0 and isinstance(outputs[0][0], float)):
174 |                 tests = create_dynamic_check_function(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], check_style="np.allclose", atol=atol)
175 | else:
176 | tests = create_dynamic_check_function(test_cases=inputs, entry_point=obj["entry_point"], prompt=obj["prompt"], correct_solution=obj["canonical_solution"], check_style="==", atol=atol)
177 | data.append({
178 | "prompt": prompt,
179 | "test": tests,
180 | "entry_point": obj["entry_point"],
181 | "tags": f"coding,en,python,core",
182 | "task": task_name,
183 | "source": f"evalplus",
184 | "eval_args": {
185 | "greedy": True,
186 | #"seed": 1234,
187 | "out_seq_length": 1024,
188 | "repetition_penalty": 1.0,
189 | "temperature": 1.0,
190 | "top_k": -1,
191 | "top_p": 0.95,
192 | "presence_penalty": 0,
193 | "system_str": "You are an intelligent programming assistant to produce Python algorithmic solutions",
194 | },
195 | "extra_response_prefix": "```python\n"
196 | })
197 | return data
198 |
199 | def create_check_function(test_cases, entry_point, outputs):
200 | test_cases_str = "def check():\n"
201 | for case, output in zip(test_cases, outputs):
202 | for i in range(len(case)):
203 | if isinstance(case[i], str) and "\n" in case[i]:
204 | case[i] = case[i].replace("\n", "\\n")
205 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case])
206 | output = str(output) if not isinstance(output, str) else f"'{output}'"
207 | single_test_case_str = f"\tassert {entry_point}({input_params}) == {output}\n"
208 | test_cases_str += single_test_case_str
209 | test_cases_str += "check()"
210 | return test_cases_str
211 |
212 | def create_dynamic_check_function_are_equivalent(test_cases, entry_point, prompt, correct_solution, check_style="np.allclose", atol=0):
213 | test_cases_str = "import numpy as np\n" + prompt + correct_solution
214 | test_cases_str = test_cases_str.replace(f"def {entry_point}(", f"def {entry_point}_ground_truth(")
215 | test_cases_str += "def check():\n"
216 | for case in test_cases:
217 | for i in range(len(case)):
218 | if isinstance(case[i], str) and "\n" in case[i]:
219 | case[i] = case[i].replace("\n", "\\n")
220 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case])
221 | single_test_case_str = f"\tassert {entry_point}({input_params}) == {entry_point}_ground_truth({input_params}) or {entry_point}({input_params}) == 0\n"
222 | test_cases_str += single_test_case_str
223 | test_cases_str += "check()"
224 | return test_cases_str
225 |
226 | def create_dynamic_check_function_sum_div(test_cases, entry_point, prompt, correct_solution, check_style="np.allclose", atol=0):
227 | test_cases_str = "import numpy as np\n" + prompt + correct_solution
228 | test_cases_str = test_cases_str.replace(f"def {entry_point}(", f"def {entry_point}_ground_truth(")
229 | test_cases_str += "def check():\n"
230 | for case in test_cases:
231 | for i in range(len(case)):
232 | if isinstance(case[i], str) and "\n" in case[i]:
233 | case[i] = case[i].replace("\n", "\\n")
234 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case])
235 | single_test_case_str = f"\tassert {entry_point}({input_params}) == {entry_point}_ground_truth({input_params}) or {entry_point}({input_params}) == 0\n"
236 | test_cases_str += single_test_case_str
237 | test_cases_str += "check()"
238 | return test_cases_str
239 |
240 | def create_dynamic_check_function(test_cases, entry_point, prompt, correct_solution, check_style="np.allclose", atol=0):
241 | test_cases_str = "import numpy as np\n" + prompt + correct_solution
242 | test_cases_str = test_cases_str.replace(f"def {entry_point}(", f"def {entry_point}_ground_truth(")
243 | test_cases_str += "def check():\n"
244 | for case in test_cases:
245 | for i in range(len(case)):
246 | if isinstance(case[i], str) and "\n" in case[i]:
247 | case[i] = case[i].replace("\n", "\\n")
248 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case])
249 |             if check_style == "np.allclose":
250 | single_test_case_str = f"\tassert np.allclose({entry_point}({input_params}), {entry_point}_ground_truth({input_params}), rtol=1e-07, atol={atol})\n"
251 | elif check_style == "==":
252 | single_test_case_str = f"\tassert {entry_point}({input_params}) == {entry_point}_ground_truth({input_params})\n"
253 | elif check_style == "set":
254 | single_test_case_str = f"\tassert set({entry_point}({input_params})) == set({entry_point}_ground_truth({input_params}))\n"
255 | elif check_style == "not_none":
256 | single_test_case_str = f"\tif isinstance({entry_point}({input_params}), bool):\n"
257 | single_test_case_str += f"\t\tassert {entry_point}({input_params}) == ({entry_point}_ground_truth({input_params}) is not None)\n"
258 | single_test_case_str += f"\telse:\n"
259 | single_test_case_str += f"\t\tassert ({entry_point}({input_params}) is not None) == ({entry_point}_ground_truth({input_params}) is not None)\n"
260 | test_cases_str += single_test_case_str
261 | test_cases_str += "check()"
262 | return test_cases_str
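    # Illustration (hypothetical entry point `add`): for check_style "==",
    # the string returned above is a runnable program of the form
    #
    #     import numpy as np
    #     def add_ground_truth(a, b): ...   # prompt + canonical solution, renamed
    #     def check():
    #         assert add(1, 2) == add_ground_truth(1, 2)
    #     check()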
263 |
264 | def create_dynamic_check_function_find_zero(test_cases, entry_point, prompt, correct_solution, atol=0):
265 | test_cases_str = "import numpy as np\n" + prompt + correct_solution
266 | test_cases_str = test_cases_str.replace(f"def {entry_point}(", f"def {entry_point}_ground_truth(")
267 | test_cases_str += "def check():\n"
268 | for case in test_cases:
269 | for i in range(len(case)):
270 | if isinstance(case[i], str) and "\n" in case[i]:
271 | case[i] = case[i].replace("\n", "\\n")
272 | input_params = ", ".join([str(c) if not isinstance(c, str) else f"'{c}'" for c in case])
273 | single_test_case_str = f"\tassert abs(poly({input_params}, {entry_point}({input_params}))) <= {atol}\n"
274 | test_cases_str += single_test_case_str
275 | test_cases_str += "check()"
276 | return test_cases_str
277 |
278 | humaneval_data = get_human_eval_plus()
279 | data1 = convert(humaneval_data.values(), test_set="base_input", task_name="evalplus/humaneval")
280 | write_jsonl_file(data1, f"{root_dir}/evalplus_v2/humaneval.jsonl")
281 | data2 = convert(humaneval_data.values(), test_set="plus_input", task_name="evalplus/humaneval_plus")
282 | write_jsonl_file(data2, f"{root_dir}/evalplus_v2/humaneval_plus.jsonl")
283 |
284 | mbpp_data = get_mbpp_plus()
285 | data3 = convert(mbpp_data.values(), test_set="base_input", task_name="evalplus/mbpp")
286 | write_jsonl_file(data3, f"{root_dir}/evalplus_v2/mbpp.jsonl")
287 | data4 = convert(mbpp_data.values(), test_set="plus_input", task_name="evalplus/mbpp_plus")
288 | write_jsonl_file(data4, f"{root_dir}/evalplus_v2/mbpp_plus.jsonl")
289 |
290 | all_data = data1 + data2 + data3 + data4
291 | write_jsonl_file(all_data, f"{root_dir}/evalplus_v2/evalplus.jsonl")
292 |
293 |     all_data = np.random.choice(all_data, 10, replace=False)  # sample 10 distinct examples
294 | write_jsonl_file(all_data, f"{root_dir}/evalplus_v2/evalplus.jsonl.sampled")
295 |
296 |
297 | if __name__ == "__main__":
298 | convert_file(root_dir="data/eval/code/")
299 |
--------------------------------------------------------------------------------
/eval_plus/exclude_patterns.txt:
--------------------------------------------------------------------------------
1 | .venv
2 | __pycache__
3 | output*/
4 |
--------------------------------------------------------------------------------
/eval_plus/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import PathLike
4 | import sys
5 | eval_plus_path = os.path.dirname(os.path.abspath(__file__)) + "/evalplus/"
6 | sys.path = [eval_plus_path] + sys.path
7 | from model import DecoderBase, make_model
8 | from rich.progress import (
9 | BarColumn,
10 | MofNCompleteColumn,
11 | Progress,
12 | TextColumn,
13 | TimeElapsedColumn,
14 | )
15 |
16 |
17 | MODEL_MAPPING = {
18 | # Can be either repo's name or /path/to/model
19 | "codeqwen": {
20 | "base": "Qwen/CodeQwen1.5-7B",
21 | "chat": "Qwen/CodeQwen1.5-7B-Chat",
22 | "chat-awq": "Qwen/CodeQwen1.5-7B-Chat-AWQ",
23 | },
24 | "qwen2": {
25 | "chat": "Qwen/CodeQwen1.5-7B-Chat",
26 | },
27 | }
28 |
29 |
30 | def construct_contract_prompt(prompt: str, contract_type: str, contract: str) -> str:
31 | if contract_type == "none":
32 | return prompt
33 | elif contract_type == "docstring":
34 | # embed within the docstring
35 | sep = ""
36 | if '"""' in prompt:
37 | sep = '"""'
38 | elif "'''" in prompt:
39 | sep = "'''"
40 | assert sep != ""
41 | l = prompt.split(sep)
42 | contract = "\n".join([x.split("#")[0] for x in contract.splitlines()])
43 | l[1] = l[1] + contract + "\n" + " " * (len(contract) - len(contract.lstrip()) - 1)
44 | return sep.join(l)
45 | elif contract_type == "code":
46 | # at the beginning of the function
47 | contract = "\n".join([x.split("#")[0] for x in contract.splitlines()])
48 | return prompt + contract
49 |
50 |
51 | def code_generate(args, workdir: PathLike, model: DecoderBase, id_range=None):
52 | with Progress(
53 | TextColumn(f"{args.dataset} •" + "[progress.percentage]{task.percentage:>3.0f}%"),
54 | BarColumn(),
55 | MofNCompleteColumn(),
56 | TextColumn("•"),
57 | TimeElapsedColumn(),
58 | ) as p:
59 | if args.dataset == "humaneval":
60 | from evalplus.data import get_human_eval_plus
61 | dataset = get_human_eval_plus()
62 | elif args.dataset == "mbpp":
63 | from evalplus.data import get_mbpp_plus
64 | dataset = get_mbpp_plus()
65 |
66 | for task_id, task in p.track(dataset.items()):
67 | if id_range is not None:
68 | id_num = int(task_id.split("/")[1])
69 | low, high = id_range
70 | if id_num < low or id_num >= high:
71 | p.console.print(f"Skipping {task_id} as it is not in {id_range}")
72 | continue
73 |
74 | p_name = task_id.replace("/", "_")
75 | if args.contract_type != "none" and task["contract"] == "":
76 | continue
77 | os.makedirs(os.path.join(workdir, p_name), exist_ok=True)
78 | log = f"Codegen: {p_name} @ {model}"
79 | n_existing = 0
80 | if args.resume:
81 | # count existing .py files
82 | n_existing = len([f for f in os.listdir(os.path.join(workdir, p_name)) if f.endswith(".py")])
83 | if n_existing > 0:
84 | log += f" (resuming from {n_existing})"
85 |
86 | nsamples = args.n_samples - n_existing
87 | p.console.print(log)
88 |
89 | sidx = args.n_samples - nsamples
90 | while sidx < args.n_samples:
91 | model.dataset = args.dataset
92 | outputs = model.codegen(
93 | construct_contract_prompt(task["prompt"], args.contract_type, task["contract"]).strip(),
94 | do_sample=not args.greedy,
95 | num_samples=args.n_samples - sidx,
96 | )
97 | assert outputs, "No outputs from model!"
98 | for impl in outputs:
99 | if "```" in impl:
100 | impl = impl.split("```")[0]
101 | print("``` exist in generation. Please check the generation results.")
102 |
103 | try:
104 | with open(
105 | os.path.join(workdir, p_name, f"{sidx}.py"),
106 | "w",
107 | encoding="utf-8",
108 | ) as f:
109 | if model.direct_completion:
110 | f.write(task["prompt"] + impl)
111 | else:
112 | f.write(impl)
113 | except UnicodeEncodeError:
114 | continue
115 | sidx += 1
116 |
117 |
118 | def main():
119 | parser = argparse.ArgumentParser()
120 | parser.add_argument("--model_type", required=True, type=str, choices=MODEL_MAPPING.keys())
121 | parser.add_argument("--model_path", type=str, default=None)
122 | parser.add_argument("--model_size", required=True, type=str)
123 | parser.add_argument("--bs", default=1, type=int)
124 | parser.add_argument("--temperature", default=0.0, type=float)
125 | parser.add_argument("--dataset", required=True, type=str, choices=["humaneval", "mbpp"])
126 | parser.add_argument("--root", type=str, required=True)
127 | parser.add_argument("--n_samples", default=1, type=int)
128 | parser.add_argument("--resume", action="store_true")
129 | parser.add_argument("--output", type=str)
130 | parser.add_argument("--tensor-parallel-size", default=1, type=int)
131 | parser.add_argument(
132 | "--contract-type",
133 | default="none",
134 | type=str,
135 | choices=["none", "code", "docstring"],
136 | )
137 | parser.add_argument("--greedy", action="store_true")
138 | # id_range is list
139 | parser.add_argument("--id-range", default=None, nargs="+", type=int)
140 | args = parser.parse_args()
141 | print(args)
142 | assert args.model_size in MODEL_MAPPING[args.model_type]
143 |
144 | model_path = MODEL_MAPPING[args.model_type][args.model_size]
145 | if args.model_path is not None:
146 | model_path = args.model_path
147 | print(f"Loading model from {model_path}")
148 |
149 | print(f"Running model={args.model_type}, size={args.model_size}")
150 | print(f"\tLoad from `{model_path}`")
151 |
152 | if args.greedy and (args.temperature != 0 or args.bs != 1 or args.n_samples != 1):
153 | args.temperature = 0
154 | args.bs = 1
155 | args.n_samples = 1
156 | print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")
157 |
158 | if args.id_range is not None:
159 | assert len(args.id_range) == 2, "id_range must be a list of length 2"
160 | assert args.id_range[0] < args.id_range[1], "id_range must be increasing"
161 | args.id_range = tuple(args.id_range)
162 |
163 | # Make project dir
164 | os.makedirs(args.root, exist_ok=True)
165 | # Make dataset dir
166 | os.makedirs(os.path.join(args.root, args.dataset), exist_ok=True)
167 | # Make dir for codes generated by each model
168 |
169 | model = make_model(
170 | model_type=args.model_type,
171 | model_size=args.model_size,
172 | model_path=model_path,
173 | batch_size=args.bs,
174 | temperature=args.temperature,
175 | dataset=args.dataset,
176 | tensor_parallel_size=args.tensor_parallel_size
177 | )
178 | workdir = os.path.join(
179 | args.root,
180 | args.dataset,
181 | args.model_type
182 | + f"_{args.model_size}"
183 | + f"_temp_{args.temperature}"
184 | + ("" if args.contract_type == "none" else f"-contract-{args.contract_type}"),
185 | )
186 | os.makedirs(workdir, exist_ok=True)
187 | print(f"Working dir: {workdir}")
188 |
189 | with open(os.path.join(workdir, "args.txt"), "w") as f:
190 | f.write(str(args))
191 |
192 | print(f"Model cls: {model.__class__}")
193 | print(f"EOS tokens: {model.eos}")
194 | code_generate(args, workdir=workdir, model=model, id_range=args.id_range)
195 |
196 |
197 | if __name__ == "__main__":
198 | main()
199 |
--------------------------------------------------------------------------------
/eval_plus/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import ABC, abstractmethod
3 | from typing import List
4 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
5 | os.environ["HF_HOME"] = os.environ.get("HF_HOME", "./hf_home")
6 |
7 | import torch
8 | from stop_sequencer import StopSequencer
9 | from transformers import AutoModelForCausalLM, AutoTokenizer
10 | from vllm import LLM, SamplingParams
11 |
12 |
13 | EOS = [
14 | "<|endoftext|>",
15 | "<|endofmask|>",
16 | "",
17 | "\nif __name__",
18 | "\ndef main(",
19 | "\nprint(",
20 | "\n#"
21 | ]
22 |
23 |
24 | class DecoderBase(ABC):
25 | def __init__(
26 | self,
27 | name: str,
28 | batch_size: int = 1,
29 | temperature: float = 0.8,
30 | max_new_tokens: int = 512,
31 | direct_completion: bool = True,
32 | dtype: str = "bfloat16", # default
33 | trust_remote_code: bool = False,
34 | dataset: str = None,
35 | ) -> None:
36 | print("Initializing a decoder model: {} ...".format(name))
37 | self.name = name
38 | self.batch_size = batch_size
39 | self.temperature = temperature
40 | self.eos = EOS
41 | self.skip_special_tokens = False
42 | self.max_new_tokens = max_new_tokens
43 | self.direct_completion = direct_completion
44 | self.dtype = dtype
45 | self.trust_remote_code = trust_remote_code
46 |
47 | if direct_completion:
48 | if dataset.lower() == "humaneval":
49 | self.eos += ["\ndef", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]
50 | elif dataset.lower() == "mbpp":
51 | self.eos += ['\n"""', "\nassert"]
52 |
53 | @abstractmethod
54 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
55 | pass
56 |
57 | def __repr__(self) -> str:
58 | return self.name
59 |
60 | def __str__(self) -> str:
61 | return self.name
62 |
63 |
64 | class VLlmDecoder(DecoderBase):
65 | def __init__(self, name: str, tensor_parallel_size = 1, **kwargs) -> None:
66 | super().__init__(name, **kwargs)
67 |
68 | kwargs = {
69 | "tensor_parallel_size": tensor_parallel_size, #int(os.getenv("VLLM_N_GPUS", "1"))
70 | "dtype": self.dtype,
71 | "trust_remote_code": self.trust_remote_code,
72 | "enforce_eager": True,
73 | "gpu_memory_utilization": 0.7
74 | }
75 | print(kwargs)
76 | self.llm = LLM(model=name, max_model_len=1536, **kwargs)
77 |
78 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
79 | if do_sample:
80 | assert self.temperature > 0, "Temperature must be greater than 0!"
81 | batch_size = min(self.batch_size, num_samples)
82 |
83 | vllm_outputs = self.llm.generate(
84 | [prompt] * batch_size,
85 | SamplingParams(
86 | temperature=self.temperature,
87 | max_tokens=self.max_new_tokens,
88 | top_p=0.95 if do_sample else 1.0,
89 | stop=self.eos,
90 | ),
91 | use_tqdm=False,
92 | )
93 |
94 | gen_strs = [x.outputs[0].text.replace("\t", " ") for x in vllm_outputs]
95 | return gen_strs
96 |
97 |
98 | class VLlmAWQDecoder(DecoderBase):
99 | def __init__(self, name: str, **kwargs) -> None:
100 | super().__init__(name, **kwargs)
101 |
102 | kwargs = {
103 | "tensor_parallel_size": int(os.getenv("VLLM_N_GPUS", "1")),
104 | "dtype": torch.float16,
105 | "trust_remote_code": self.trust_remote_code,
106 | "quantization": "AWQ",
107 | }
108 |
109 | self.llm = LLM(model=name, max_model_len=2048, **kwargs)
110 |
111 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
112 | if do_sample:
113 | assert self.temperature > 0, "Temperature must be greater than 0!"
114 | batch_size = min(self.batch_size, num_samples)
115 |
116 | vllm_outputs = self.llm.generate(
117 | [prompt] * batch_size,
118 | SamplingParams(
119 | temperature=self.temperature,
120 | max_tokens=self.max_new_tokens,
121 | top_p=0.95 if do_sample else 1.0,
122 | stop=self.eos,
123 | ),
124 | use_tqdm=False,
125 | )
126 |
127 | gen_strs = [x.outputs[0].text.replace("\t", " ") for x in vllm_outputs]
128 | return gen_strs
129 |
130 |
131 | class AWQChatML(VLlmAWQDecoder):
132 |     def __init__(self, name: str, tensor_parallel_size=1, **kwargs) -> None:  # tensor_parallel_size is unused here; the AWQ path reads VLLM_N_GPUS
133 | kwargs["direct_completion"] = False
134 | super().__init__(name, **kwargs)
135 | self.eos += ["\n```"]
136 |
137 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
138 | if do_sample:
139 | assert self.temperature > 0, "Temperature must be greater than 0!"
140 |
141 | input = f"""<|im_start|>system
142 | You are an intelligent programming assistant to produce Python algorithmic solutions<|im_end|>
143 | <|im_start|>user
144 | Can you complete the following Python function?
145 | ```python
146 | {prompt}
147 | ```
148 | <|im_end|>
149 | <|im_start|>assistant
150 | ```python
151 | """
152 | return VLlmDecoder.codegen(self, input, do_sample, num_samples)
153 |
154 |
155 | class ChatML(VLlmDecoder):
156 | def __init__(self, name: str, tensor_parallel_size, **kwargs) -> None:
157 | kwargs["direct_completion"] = False
158 | super().__init__(name, tensor_parallel_size, **kwargs)
159 | self.eos += ["\n```"]
160 |
161 | def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
162 | if do_sample:
163 | assert self.temperature > 0, "Temperature must be greater than 0!"
164 |
165 | input = f"""<|im_start|>system
166 | You are an intelligent programming assistant to produce Python algorithmic solutions<|im_end|>
167 | <|im_start|>user
168 | Can you complete the following Python function?
169 | ```python
170 | {prompt}
171 | ```
172 | <|im_end|>
173 | <|im_start|>assistant
174 | ```python
175 | """
176 | return VLlmDecoder.codegen(self, input, do_sample, num_samples)
177 |
178 |
179 | def make_model(
180 | model_type: str,
181 | model_size: str,
182 | model_path: str,
183 | batch_size: int = 1,
184 | temperature: float = 0.8,
185 | dataset: str = None,
186 | tensor_parallel_size = 1
187 | ):
188 | if model_type == "codeqwen" or model_type == "qwen2":
189 | if "chat" in model_size.lower():
190 | if "awq" in model_size.lower():
191 | return AWQChatML(
192 | batch_size=batch_size,
193 | name=model_path,
194 | temperature=temperature,
195 | max_new_tokens=2048,
196 | tensor_parallel_size = tensor_parallel_size
197 | )
198 | else:
199 | return ChatML(
200 | batch_size=batch_size,
201 | name=model_path,
202 | temperature=temperature,
203 | max_new_tokens=2048,
204 | tensor_parallel_size = tensor_parallel_size
205 | )
206 | else:
207 | return VLlmDecoder(
208 | batch_size=batch_size,
209 | name=model_path,
210 | temperature=temperature,
211 | dataset=dataset,
212 | tensor_parallel_size = tensor_parallel_size
213 | )
214 | else:
215 | raise ValueError(f"Invalid model name: {model_type}@{model_size}")
216 |
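# Example usage (hypothetical checkpoint path):
#   model = make_model(model_type="qwen2", model_size="chat",
#                      model_path="/path/to/checkpoint", batch_size=1,
#                      temperature=0.0, dataset="humaneval", tensor_parallel_size=1)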
--------------------------------------------------------------------------------
/eval_plus/readme.md:
--------------------------------------------------------------------------------
1 | Sourced from the [Qwen2.5-Coder repository](https://github.com/QwenLM/Qwen2.5-Coder/tree/main/qwencoder-eval/instruct/eval_plus) with updated dependencies for better reproducibility.
2 |
3 | ## Evaluation for HumanEval(+) and MBPP(+)
4 |
5 | This folder contains the code and scripts to evaluate the performance of the **Qwen2.5-Coder** series models on the [**EvalPlus**](https://github.com/evalplus/evalplus) benchmark, which includes the HumanEval(+) and MBPP(+) datasets. These datasets are designed to test code generation capabilities under varied conditions.
6 |
7 |
8 | ### 1. Setup
9 |
10 | Please refer to [**EvalPlus**](https://github.com/evalplus/evalplus) for detailed setup instructions. Install the required packages using:
11 |
12 | ```bash
13 | pip install evalplus --upgrade
14 | pip install -r requirements.txt
15 | ```
16 |
17 | ### 2. Inference and Evaluation
18 |
19 | We utilize 8xA100 GPUs for this benchmark. The following script runs both inference and evaluation:
20 |
21 | ```bash
22 | bash test.sh {path_to_your_local_model_checkpoint} {tensor_parallel_size} {output_dir}
23 | ```
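For example, with a local checkpoint and two-way tensor parallelism (illustrative paths):

```bash
bash test.sh ./checkpoints/qwen2-chat 2 ./results
```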
24 |
--------------------------------------------------------------------------------
/eval_plus/requirements.txt:
--------------------------------------------------------------------------------
1 | datamodel_code_generator
2 | anthropic
3 | mistralai
4 | google-generativeai
5 | rich
6 | accelerate
7 | vllm==0.6.6
8 | stop-sequencer
9 | evalplus
10 | setuptools
11 | scipy
12 | hf_transfer
--------------------------------------------------------------------------------
/eval_plus/test.sh:
--------------------------------------------------------------------------------
1 | mkdir -p results/humaneval
2 | # export HF_ENDPOINT=https://hf-mirror.com
3 | export PATH=./vllm/bin:$PATH
4 | export PYTHONPATH=$PYTHONPATH:./eval_plus/evalplus
5 | MODEL_DIR=${1}
6 | # uv pip install datamodel_code_generator anthropic mistralai google-generativeai
7 |
8 |
9 | MODEL_DIR=${MODEL_DIR:-"/path/to/pretrained_models/"}
10 | TP=${2}
11 | TP=${TP:-1}
12 | OUTPUT_DIR=${3}
13 | OUTPUT_DIR=${OUTPUT_DIR:-"/path/to/"}
14 | mkdir -p ${OUTPUT_DIR}
15 |
16 | echo "EvalPlus: ${MODEL_DIR}, OUTPUT_DIR ${OUTPUT_DIR}"
17 |
18 | python generate.py \
19 | --model_type qwen2 \
20 | --model_size chat \
21 | --model_path ${MODEL_DIR} \
22 | --bs 1 \
23 | --temperature 0 \
24 | --n_samples 1 \
25 | --greedy \
26 | --root ${OUTPUT_DIR} \
27 | --dataset humaneval \
28 | --tensor-parallel-size ${TP}
29 |
30 | echo "Generated samples: ${OUTPUT_DIR}/humaneval/qwen2_chat_temp_0.0"
31 |
32 | echo "Sanitizing samples"
33 | python -m evalplus.sanitize --samples ${OUTPUT_DIR}/humaneval/qwen2_chat_temp_0.0
34 |
35 | echo "Evaluating humaneval raw"
36 | evalplus.evaluate \
37 | --dataset humaneval \
38 | --samples ${OUTPUT_DIR}/humaneval/qwen2_chat_temp_0.0 > ${OUTPUT_DIR}/raw_humaneval_results.txt
39 | echo "Finished evaluating humaneval file at ${OUTPUT_DIR}/raw_humaneval_results.txt"
40 | echo "Evaluating humaneval sanitized"
41 | evalplus.evaluate \
42 | --dataset humaneval \
43 | --samples ${OUTPUT_DIR}/humaneval/qwen2_chat_temp_0.0-sanitized > ${OUTPUT_DIR}/humaneval_results.txt
44 | echo "Finished evaluating humaneval file at ${OUTPUT_DIR}/humaneval_results.txt"
45 |
46 | python generate.py \
47 | --model_type qwen2 \
48 | --model_size chat \
49 | --model_path ${MODEL_DIR} \
50 | --bs 1 \
51 | --temperature 0 \
52 | --n_samples 1 \
53 | --greedy \
54 | --root ${OUTPUT_DIR} \
55 | --dataset mbpp \
56 | --tensor-parallel-size ${TP}
57 |
58 | echo "Sanitizing mbpp"
59 | python -m evalplus.sanitize --samples ${OUTPUT_DIR}/mbpp/qwen2_chat_temp_0.0
60 |
61 | echo "Evaluating mbpp raw"
62 | evalplus.evaluate \
63 | --dataset mbpp \
64 | --samples ${OUTPUT_DIR}/mbpp/qwen2_chat_temp_0.0 > ${OUTPUT_DIR}/raw_mbpp_results.txt
65 | echo "Finished evaluating mbpp file at ${OUTPUT_DIR}/raw_mbpp_results.txt"
66 | echo "Evaluating mbpp sanitized"
67 | evalplus.evaluate \
68 | --dataset mbpp \
69 | --samples ${OUTPUT_DIR}/mbpp/qwen2_chat_temp_0.0-sanitized > ${OUTPUT_DIR}/mbpp_results.txt
70 | echo "Finished evaluating mbpp file at ${OUTPUT_DIR}/mbpp_results.txt"
--------------------------------------------------------------------------------
/grpo_code/__init__.py:
--------------------------------------------------------------------------------
1 | from .rewards import (
2 | code_execution_reward_func,
3 | answer_execution_reward_func,
4 | soft_format_reward_func,
5 | )
6 | from .transforms import axolotl_acecode_transform
7 |
8 | __all__ = [
9 | "code_execution_reward_func",
10 | "answer_execution_reward_func",
11 | "soft_format_reward_func",
12 | "axolotl_acecode_transform",
13 | ]
14 |
--------------------------------------------------------------------------------
/grpo_code/executor.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import ProcessPoolExecutor
2 |
3 | from grpo_code.wasm import does_code_run, PythonWasmEnvironment
4 |
5 | _executor = None
6 |
7 |
8 | def get_executor(
9 | max_processes: int, wasm_path: str, fuel: int
10 | ) -> None | ProcessPoolExecutor:
11 | """
12 | Get the executor for the given number of processes and WASM environment.
13 | For parallel execution, we use a multiprocessing executor, where a pool
14 | of workers, each with their own WASM environment, execute the tasks in parallel.
15 |
16 | For single process execution, we instead initialize a global WASM environment.
17 |
18 | Args:
19 | max_processes (int): The maximum number of processes to use.
20 | wasm_path (str): The path to the .wasm file.
21 | fuel (int): The amount of fuel to use for the WASM environment.
22 |
23 | Returns:
24 | executor (None | ProcessPoolExecutor): If parallel execution is requested,
25 | we return a ProcessPoolExecutor, otherwise we return None.
26 | """
27 |     global _executor
28 | if max_processes > 1:
29 | from grpo_code.parallel_executor import get_multiprocessing_executor
30 |
31 | _executor = get_multiprocessing_executor(max_processes, wasm_path, fuel)
32 | return _executor
33 |     import grpo_code.wasm as wasm
34 |
35 |     # Check wasm.worker_env (the attribute does_code_run uses) so the
36 |     # environment is only initialized once per process.
37 |     if getattr(wasm, "worker_env", None) is None:
38 |         wasm.worker_env = PythonWasmEnvironment(wasm_path, fuel)
39 | def execute_tasks(
40 | tasks: list[str], max_processes: int, wasm_path: str, fuel: int, task_timeout: int
41 | ):
42 | """
43 | Run a list of code snippets in a WASM environment.
44 |
45 | Args:
46 | tasks (list[str]): The list of code snippets to run.
47 | max_processes (int): The maximum number of processes to use.
48 | wasm_path (str): The path to the .wasm file.
49 | fuel (int): The amount of fuel to use for the WASM environment.
50 | task_timeout (int): If using multiprocessing, the timeout for each task.
51 |
52 | Returns:
53 |         list[float]: The list of results from running the code snippets.
54 | """
55 | executor = get_executor(max_processes, wasm_path, fuel)
56 | if max_processes > 1:
57 | from grpo_code.parallel_executor import run_tasks_with_multiprocessing_executor
58 |
59 | return run_tasks_with_multiprocessing_executor(executor, tasks, task_timeout)
60 | else:
61 | return list(map(does_code_run, tasks))
62 |
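# Example (single-process sketch; the path and values are the repository defaults):
#   results = execute_tasks(
#       ["print('ok')", "this is not python"],
#       max_processes=1,
#       wasm_path="./wasm/python-3.12.0.wasm",
#       fuel=1_000_000_000,
#       task_timeout=1,
#   )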
--------------------------------------------------------------------------------
/grpo_code/parallel_executor.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import signal
3 | import sys
4 | from concurrent.futures import FIRST_EXCEPTION, ProcessPoolExecutor, wait
5 |
6 | from grpo_code.wasm import does_code_run, PythonWasmEnvironment
7 |
8 | _executor = None
9 |
10 |
11 | def worker_init(wasm_path: str, fuel: int):
12 | """
13 | Initialize a WASM environment for a worker process.
14 |
15 | Args:
16 | wasm_path (str): The path to the .wasm file.
17 | fuel (int): The amount of fuel available to the WASM environment.
18 | """
19 | import grpo_code.wasm as wasm
20 |
21 | wasm.worker_env = PythonWasmEnvironment(wasm_path, fuel)
22 |
23 |
24 | def cleanup_executor():
25 | """
26 | Cleanup any running `ProcessPoolExecutor`.
27 | """
28 | global _executor
29 | if _executor is not None:
30 | _executor.shutdown(wait=False)
31 | _executor = None
32 |
33 |
34 | def cleanup_and_exit():
35 | """
36 |     Gracefully shuts down any running `ProcessPoolExecutor` on
37 |     receiving a termination signal.
38 | """
39 | cleanup_executor()
40 | sys.exit(0)
41 |
42 |
43 | def get_multiprocessing_executor(max_processes: int, wasm_path: str, fuel: int):
44 | """
45 | Initialize a reusable `ProcessPoolExecutor` instance.
46 |
47 | Args:
48 | max_processes (int): The maximum number of processes to use.
49 | wasm_path (str): The path to the .wasm file.
50 | fuel (int): The amount of fuel available to the WASM environment.
51 |
52 | Returns:
53 | ProcessPoolExecutor: A `ProcessPoolExecutor` for parallel execution.
54 | """
55 | global _executor
56 | if _executor is None:
57 | _executor = ProcessPoolExecutor(
58 | max_workers=max_processes,
59 | initializer=worker_init,
60 | initargs=(wasm_path, fuel),
61 | )
62 | atexit.register(cleanup_executor)
63 | signal.signal(signal.SIGINT, cleanup_and_exit)
64 | signal.signal(signal.SIGTERM, cleanup_and_exit)
65 | return _executor
66 |
67 |
68 | def run_tasks_with_multiprocessing_executor(
69 | executor: ProcessPoolExecutor, tasks: list[str], timeout: int
70 | ):
71 | """
72 | Run a list of code snippets in parallel by dispatching them to workers
73 | in a `ProcessPoolExecutor`.
74 |
75 | This function will gracefully handle timeout errors caused by workers, and
76 | recreate the `ProcessPoolExecutor` if necessary to avoid interrupting training.
77 |
78 | Args:
79 | executor (ProcessPoolExecutor): The executor to use.
80 | tasks (list[str]): The list of code snippets to run.
81 | timeout (int): The timeout for each task.
82 |
83 | Returns:
84 | list[bool | float]: Per-task results; tasks still pending when the timeout fires keep their 0.0 placeholder.
85 | """
86 | futures_to_index = {
87 | executor.submit(does_code_run, task): i for i, task in enumerate(tasks)
88 | }
89 | futures = list(futures_to_index)
90 | results = [0.0] * len(tasks)
91 |
92 | while futures:
93 | done, futures = wait(futures, timeout=timeout, return_when=FIRST_EXCEPTION)
94 |
95 | if futures and len(done) == 0:
96 | print(
97 | f"WARNING: Tasks timed out after {timeout} seconds, recreating process pool..."
98 | )
99 | cleanup_executor()
100 | break
101 |
102 | for future in done:
103 | task_index = futures_to_index[future]
104 | results[task_index] = future.result()
105 |
106 | return results
107 |
--------------------------------------------------------------------------------
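The same machinery can be driven standalone; a sketch with hypothetical values follows. A task that outlives the wall-clock timeout keeps its 0.0 placeholder, and the pool is torn down for recreation on the next call:

from grpo_code.parallel_executor import (
    get_multiprocessing_executor,
    run_tasks_with_multiprocessing_executor,
)

executor = get_multiprocessing_executor(
    max_processes=4,
    wasm_path="wasm/python-3.12.0.wasm",
    fuel=1_000_000_000,
)
tasks = ["print('ok')", "import time; time.sleep(60)"]  # second one stalls
results = run_tasks_with_multiprocessing_executor(executor, tasks, timeout=1)
# Completed snippets report True/False from does_code_run; the stalled
# one is still pending at the timeout, so its slot stays 0.0.
print(results)
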
/grpo_code/rewards.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import re
4 | from pathlib import Path
5 |
6 | import grpo_code
7 | from grpo_code.executor import execute_tasks
8 |
9 | WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
10 | MAX_WORKERS = max(1, int(os.environ.get("MAX_WORKERS", 1)) // WORLD_SIZE)
11 | TASK_TIMEOUT = int(os.environ.get("TASK_TIMEOUT", 1))
12 | WASM_PATH = os.environ.get(
13 | "WASM_PATH", Path(grpo_code.__file__).parent.parent / "wasm" / "python-3.12.0.wasm"
14 | )
15 | FUEL = int(os.environ.get("FUEL", 1_000_000_000))
16 |
17 | if not os.path.exists(WASM_PATH):
18 | raise FileNotFoundError(f"WASM file not found at {WASM_PATH}")
19 |
20 |
21 | def extract_xml_answer(text: str) -> str:
22 | """
23 | Extract text between <answer> and </answer> tags.
24 |
25 | Args:
26 | text (str): The text to extract the answer from.
27 | Returns:
28 | str: The answer extracted from the text. "" if no answer is found.
29 |
30 | """
31 | match = re.search(r"<answer>(.*?)</answer>", text, re.S)
32 | return match.group(1).strip() if match else ""
33 |
34 |
35 | def code_execution_reward_func(completions: list[list[dict]], **kwargs) -> list[float]:
36 | """
37 | Reward function for code execution.
38 |
39 | Args:
40 | completions (list[list[dict]]): The predicted code completions to execute. This takes the format
41 | [
42 | [
43 | {"role": "user", "content": "......"}
44 | ]
45 | ]
46 | Returns:
47 | list[float]: The rewards for the completions. Each completion is rewarded 0.5 if the code executes, -0.25 otherwise.
48 | """
49 | model_answers = [
50 | extract_xml_answer(completion[0]["content"]) for completion in completions
51 | ]
52 | task_results = execute_tasks(
53 | model_answers, MAX_WORKERS, WASM_PATH, FUEL, TASK_TIMEOUT
54 | )
55 | return [0.5 if result == 1.0 else -0.25 for result in task_results]
56 |
57 |
58 | def answer_execution_reward_func(
59 | completions: list[list[dict]], answers: list[list[str]], **kwargs
60 | ) -> list[float]:
61 | """
62 | Reward function for answer execution.
63 |
64 | Args:
65 | completions (list[list[dict]]): The predicted code completions to execute. This takes the format
66 | [
67 | [
68 | {"role": "user", "content": "......"}
69 | ]
70 | ]
71 | answers (list[list[str]]): The expected answers to the code completions. These take the form of executable
72 | assert statements, e.g.
73 | [
74 | [
75 | "assert foo(1) == 2",
76 | "assert foo(2) == 3",
77 | ]
78 | ]
79 | Returns:
80 | list[float]: The accuracy rewards for the completions. Each completion is rewarded
81 | (accuracy)^3 * 2, where accuracy is the proportion of test cases that pass.
82 | """
83 | model_answers = [
84 | extract_xml_answer(completion[0]["content"]) for completion in completions
85 | ]
86 | tasks = []
87 | test_indices = []
88 | for i, (code, tests) in enumerate(zip(model_answers, answers)):
89 | for test in tests:
90 | tasks.append(code + "\n" + test)
91 | test_indices.append(i)
92 |
93 | task_results = execute_tasks(tasks, MAX_WORKERS, WASM_PATH, FUEL, TASK_TIMEOUT)
94 |
95 | completion_results = {}
96 | for idx, result in zip(test_indices, task_results):
97 | if idx not in completion_results:
98 | completion_results[idx] = []
99 | completion_results[idx].append(result)
100 |
101 | rewards = []
102 | for i in range(len(completions)):
103 | if i in completion_results:
104 | test_results = completion_results[i]
105 | accuracy = sum(test_results) / len(test_results)
106 | reward = math.pow(accuracy, 3) * 2
107 | else:
108 | reward = 0.0
109 | rewards.append(reward)
110 | return rewards
111 |
112 |
113 | def soft_format_reward_func(completions, **kwargs) -> list[float]:
114 | """
115 | Reward function for soft format checking.
116 |
117 | Args:
118 | completions (list[list[dict]]): The predicted code completions to execute. This takes the format
119 | [
120 | [
121 | {"role": "user", "content": content}
122 | ]
123 | ]
124 | Returns:
125 | list[float]: The rewards for the completions. Each completion is rewarded 0.25 if the format is correct, 0.0 otherwise.
126 | """
127 |
128 | responses = [completion[0]["content"] for completion in completions]
129 | rewards = []
130 | for response in responses:
131 | if re.match(
132 | r".*?\s*.*?.*", response, re.S
133 | ):
134 | rewards.append(0.25)
135 | else:
136 | rewards.append(0.0)
137 | return rewards
138 |
--------------------------------------------------------------------------------
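The cubic shaping in `answer_execution_reward_func` is easiest to see with concrete numbers: partial credit is deliberately small, and only near-perfect completions approach the maximum reward of 2.0.

# reward = accuracy ** 3 * 2, where accuracy = passed / total test cases
for passed, total in [(1, 4), (2, 4), (3, 4), (4, 4)]:
    accuracy = passed / total
    print(f"{passed}/{total} tests pass -> reward {accuracy ** 3 * 2:.4f}")
# 1/4 -> 0.0312, 2/4 -> 0.2500, 3/4 -> 0.8438, 4/4 -> 2.0000
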
/grpo_code/transforms.py:
--------------------------------------------------------------------------------
1 | SYSTEM_PROMPT = """
2 | Respond in the following format:
3 | <think>
4 | ...
5 | </think>
6 | <answer>
7 | ...
8 | </answer>
9 |
10 | Additionally, you may optionally use the following imports:
11 |
12 | import time
13 | import itertools
14 | from itertools import accumulate, product, permutations, combinations
15 | import collections
16 | from collections import Counter, OrderedDict, deque, defaultdict, ChainMap
17 | from functools import lru_cache
18 | import math
19 | from typing import List, Dict, Tuple, Optional, Any
20 |
21 | If you choose to use any of these imports, ensure they are included
22 | inside the <answer> tags, e.g:
23 |
24 | <answer>
25 | import time
26 | import math
27 | ...
28 | </answer>
29 |
30 | You may not utilise any other imports or filesystem operations.
31 | """
32 |
33 |
34 | def axolotl_acecode_transform(cfg, *args, **kwargs):
35 | def transform_fn(example, tokenizer=None):
36 | return {
37 | "prompt": [
38 | {
39 | "role": "user",
40 | "content": example["question"] + "\n\n" + SYSTEM_PROMPT,
41 | }
42 | ],
43 | "answers": example["test_cases"],
44 | }
45 |
46 | return transform_fn, {"remove_columns": ["question", "test_cases"]}
47 |
--------------------------------------------------------------------------------
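A sketch of what the transform emits for one row; the example record is made up, but the field names match the transform above:

from grpo_code.transforms import axolotl_acecode_transform

example = {
    "question": "Write a function foo(x) that returns x + 1.",
    "test_cases": ["assert foo(1) == 2", "assert foo(2) == 3"],
}
transform_fn, dataset_kwargs = axolotl_acecode_transform(cfg=None)
row = transform_fn(example)
# row["prompt"] is a single user turn: the question followed by SYSTEM_PROMPT.
# row["answers"] is the raw list of assert statements consumed by
# answer_execution_reward_func; dataset_kwargs drops the source columns.
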
/grpo_code/wasm.py:
--------------------------------------------------------------------------------
1 | from wasmtime import Config, Engine, Linker, Module, Store, WasiConfig
2 |
3 | worker_env = None
4 |
5 |
6 | class PythonWasmEnvironment:
7 | """A reusable WASM environment for running Python code.
8 |
9 | Args:
10 | wasm_path (str): The path to the .wasm file.
11 | fuel (int): The amount of fuel to use for the WASM environment.
12 | """
13 |
14 | def __init__(self, wasm_path: str, fuel: int):
15 | self.wasm_path = wasm_path
16 | self.fuel = fuel
17 |
18 | # Set up the engine and linker
19 | engine_cfg = Config()
20 | engine_cfg.consume_fuel = True
21 | engine_cfg.cache = True
22 |
23 | self.engine = Engine(engine_cfg)
24 | self.linker = Linker(self.engine)
25 | self.linker.define_wasi()
26 |
27 | # Load the Python module
28 | self.python_module = Module.from_file(self.engine, self.wasm_path)
29 |
30 | def run_code(self, code: str):
31 | """Run Python code in the WASM environment subject to fuel limits.
32 |
33 | Args:
34 | code (str): The Python code to run.
35 |
36 | """
37 | config = WasiConfig()
38 | config.argv = ("python", "-c", code)
39 | config.inherit_env = False
40 |
41 | store = Store(self.engine)
42 | store.set_fuel(self.fuel)
43 | store.set_wasi(config)
44 |
45 | instance = self.linker.instantiate(store, self.python_module)
46 | start = instance.exports(store)["_start"]
47 | start(store)
48 |
49 |
50 | def does_code_run(code: str) -> bool:
51 | """Execute code in the worker's WASM environment and check if it runs without errors.
52 |
53 | Args:
54 | code (str): The Python code to run.
55 |
56 | Returns:
57 | bool: True if the code runs without errors, False otherwise.
58 | """
59 | global worker_env
60 | try:
61 | worker_env.run_code(code)
62 | return True
63 | except Exception:
64 | return False
65 |
--------------------------------------------------------------------------------
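A direct sketch of the environment class, assuming `wasmtime` is installed and the interpreter lives at `wasm/python-3.12.0.wasm`. Fuel turns runaway code into an ordinary exception rather than a hang:

from grpo_code.wasm import PythonWasmEnvironment

env = PythonWasmEnvironment("wasm/python-3.12.0.wasm", fuel=1_000_000_000)
env.run_code("print('hello from wasm')")  # finishes within the fuel budget

try:
    env.run_code("while True: pass")  # burns through its fuel and traps
except Exception as exc:
    print(f"trapped as expected: {exc}")
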
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "grpo_code"
3 | dependencies = [
4 | "wasmtime",
5 | ]
6 | version = "0.1.0"
7 |
8 | [build-system]
9 | requires = ["setuptools>=64", "wheel"]
10 | build-backend = "setuptools.build_meta"
11 |
12 |
13 | [tool.setuptools.packages.find]
14 | where = [""]
15 | include = ["grpo_code*"]
16 |
--------------------------------------------------------------------------------
/r1_acecode.yaml:
--------------------------------------------------------------------------------
1 | base_model: Qwen/Qwen2.5-3B-Instruct
2 | # Automatically upload checkpoint and final model to HF
3 | # hub_model_id: username/custom_model_name
4 |
5 | load_in_8bit: false
6 | load_in_4bit: false
7 | strict: false
8 |
9 | torch_compile: true
10 |
11 | vllm:
12 | host: 0.0.0.0
13 | port: 8000
14 | tensor_parallel_size: 2
15 | gpu_memory_utilization: 0.85
16 | dtype: auto
17 |
18 | rl: grpo
19 | trl:
20 | beta: 0.001
21 | use_vllm: true
22 | vllm_server_host: 0.0.0.0
23 | vllm_server_port: 8000
24 | vllm_server_timeout: 300
25 | reward_funcs:
26 | - grpo_code.soft_format_reward_func
27 | - grpo_code.code_execution_reward_func
28 | - grpo_code.answer_execution_reward_func
29 |
30 | num_generations: 16
31 | max_completion_length: 512
32 | log_completions: false
33 |
34 | chat_template: qwen_25
35 | datasets:
36 | - path: axolotl-ai-co/AceCode-87K
37 | type: grpo_code.axolotl_acecode_transform
38 | split: train
39 |
40 | dataset_prepared_path: /workspace/data/last_run_prepared
41 | dataset_processes:
42 | skip_prepare_dataset: true
43 | val_set_size: 0.0
44 | output_dir: /workspace/data/axolotl-artifacts/r1-outputs/1403
45 |
46 | dataloader_prefetch_factor: 32
47 | dataloader_num_workers: 2
48 | dataloader_pin_memory: true
49 |
50 | gc_steps: 1
51 | sequence_len: 1024
52 | sample_packing: false
53 | eval_sample_packing: false
54 | pad_to_sequence_len: false
55 |
56 | gradient_accumulation_steps: 2
57 | micro_batch_size: 32
58 | num_epochs: 1
59 | max_steps: 2500
60 |
61 |
62 | optimizer: adamw_torch_fused
63 | lr_scheduler: warmup_stable_decay
64 | lr_scheduler_kwargs:
65 | num_stable_steps: 1500
66 | num_decay_steps: 500
67 | min_lr_ratio: 0.1
68 | num_cycles: 0.5
69 |
70 | learning_rate: 5.3e-6
71 | max_grad_norm: 1.0
72 |
73 | train_on_inputs: false
74 | group_by_length: false
75 |
76 | bf16: true
77 | tf32: true
78 | early_stopping_patience:
79 | resume_from_checkpoint:
80 | local_rank:
81 | logging_steps: 1
82 | gradient_checkpointing: true
83 | gradient_checkpointing_kwargs:
84 | use_reentrant: false
85 | flash_attention: true
86 |
87 | warmup_steps: 500
88 | evals_per_epoch: 0
89 | saves_per_epoch: 0
90 | save_steps: 0.5
91 |
92 |
93 | # wandb_project:
94 | # wandb_entity:
95 | # wandb_name: passk
96 | # hub_model_id:
97 |
--------------------------------------------------------------------------------
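Note that the sandbox knobs read by `grpo_code/rewards.py` (`MAX_WORKERS`, `TASK_TIMEOUT`, `FUEL`, `WASM_PATH`) come from the environment rather than from this YAML, so a launcher would export them before the trainer imports the reward functions. A sketch with illustrative values:

import os

# Must be set before grpo_code.rewards is imported, since that module
# reads them at import time. The values here are illustrative only.
os.environ["MAX_WORKERS"] = "16"   # split evenly across WORLD_SIZE ranks
os.environ["TASK_TIMEOUT"] = "2"   # wall-clock seconds before a task is abandoned
os.environ["FUEL"] = "1000000000"  # WASM instruction budget per snippet
os.environ["WASM_PATH"] = "wasm/python-3.12.0.wasm"

import grpo_code.rewards  # noqa: E402  (picks up the settings above)
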
/wasm/python-3.12.0.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axolotl-ai-cloud/grpo_code/148ea79321f34bbed79b3b55f04c0a7de002665d/wasm/python-3.12.0.wasm
--------------------------------------------------------------------------------
/wasm/python-3.12.0.wasm.sha256sum:
--------------------------------------------------------------------------------
1 | e5dc5a398b07b54ea8fdb503bf68fb583d533f10ec3f930963e02b9505f7a763 wasm/python-3.12.0.wasm
2 |
--------------------------------------------------------------------------------