├── LICENSE
├── MODEL_LICENSE
├── README.md
├── README_EN.md
├── README_FR.md
├── README_JA.md
├── benchmark
│   └── humanevalx
│       ├── go
│       │   ├── go.mod
│       │   ├── go.sum
│       │   └── vendor.tar.gz
│       ├── humanevalx_cpp.jsonl.gz
│       ├── humanevalx_go.jsonl.gz
│       ├── humanevalx_java.jsonl.gz
│       ├── humanevalx_js.jsonl.gz
│       ├── humanevalx_python.jsonl.gz
│       ├── humanevalx_rust.jsonl.gz
│       └── rust
│           ├── Cargo.lock
│           └── Cargo.toml
├── demo
│   ├── example_inputs.jsonl
│   ├── fastapicpu.py
│   ├── gpus.py
│   └── run_demo.py
├── docs
│   └── zh
│       └── inference_zh.md
├── evaluation
│   ├── __init__.py
│   ├── evaluation.py
│   ├── execution.py
│   ├── generation.py
│   ├── inspect_jsonl.py
│   └── utils.py
├── requirements.txt
├── resources
│   ├── codegeex_demo.png
│   ├── codegeex_logo.png
│   ├── join_wechat.png
│   └── wechat.md
└── scripts
    ├── run_humanevalx.sh
    └── sanity_check.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright Zhengxiao Du
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/MODEL_LICENSE:
--------------------------------------------------------------------------------
1 | The CodeGeeX License
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means the CodeGeeX Model Team that distributes its Software.
6 |
7 | “Software” means the CodeGeeX model parameters made available under this license.
8 |
9 | 2. License Grant
10 |
11 | Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
12 |
13 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
14 |
15 | 3. Restriction
16 |
17 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
18 |
19 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
20 |
21 | 4. Disclaimer
22 |
23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24 |
25 | 5. Limitation of Liability
26 |
27 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
28 |
29 | 6. Dispute Resolution
30 |
31 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
32 |
33 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 |
4 | 🏠 主页|🛠 插件 VS Code, Jetbrains|🤗 模型下载|📄 论文|👋 加入微信开发者交流群
5 |
6 |
7 | Read this in [English](README_EN.md)
8 | [日本語](README_JA.md)で読む
9 | Lire en [Français](README_FR.md)
10 |
11 | ⭐️ 最新一代 [CodeGeeX4](https://github.com/THUDM/CodeGeeX4) 模型已经正式开源。
12 | The newest [CodeGeeX4](https://github.com/THUDM/CodeGeeX4) has been released.
13 |
14 | # CodeGeeX2: 更强大的多语言代码生成模型
15 |
16 | CodeGeeX2 是多语言代码生成模型 [CodeGeeX](https://github.com/THUDM/CodeGeeX) ([KDD’23](https://arxiv.org/abs/2303.17568)) 的第二代模型。不同于一代 CodeGeeX(完全在国产华为昇腾芯片平台训练) ,CodeGeeX2 是基于 [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) 架构加入代码预训练实现,得益于 ChatGLM2 的更优性能,CodeGeeX2 在多项指标上取得性能提升(+107% > CodeGeeX;仅60亿参数即超过150亿参数的 StarCoder-15B 近10%),更多特性包括:
17 |
18 | * **更强大的代码能力**:基于 ChatGLM2-6B 基座语言模型,CodeGeeX2-6B 进一步经过了 600B 代码数据预训练,相比一代模型,在代码能力上全面提升,[HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x) 评测集的六种编程语言均大幅提升 (Python +57%, C++ +71%, Java +54%, JavaScript +83%, Go +56%, Rust +321\%),在Python上达到 35.9\% 的 Pass@1 一次通过率,超越规模更大的 StarCoder-15B。
19 | * **更优秀的模型特性**:继承 ChatGLM2-6B 模型特性,CodeGeeX2-6B 更好支持中英文输入,支持最大 8192 序列长度,推理速度较一代 CodeGeeX-13B 大幅提升,量化后仅需6GB显存即可运行,支持轻量级本地化部署。
20 | * **更全面的AI编程助手**:CodeGeeX插件([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex))后端升级,支持超过100种编程语言,新增上下文补全、跨文件补全等实用功能。结合 Ask CodeGeeX 交互式AI编程助手,支持中英文对话解决各种编程问题,包括且不限于代码解释、代码翻译、代码纠错、文档生成等,帮助程序员更高效开发。
21 | * **更开放的协议**:CodeGeeX2-6B 权重对学术研究完全开放,填写[登记表](https://open.bigmodel.cn/mla/form?mcode=CodeGeeX2-6B)申请商业使用。
22 |
23 | ## 使用教程
24 |
25 | * [快速开始](#快速开始)
26 | * [推理教程(多卡推理,加速推理,多平台推理等)](docs/zh/inference_zh.md)
27 |
28 | ## AI编程助手
29 |
30 | 
31 |
32 | 我们开发了支持 VS Code、 IntelliJ IDEA、PyCharm、GoLand、WebStorm、Android Studio 等IDE的 CodeGeeX 插件。在插件中,可以更直接地体验到 CodeGeeX2 模型在代码生成与补全、添加注释、代码翻译及技术问答方面的能力为开发效率带来的提升。欢迎在IDE中下载 CodeGeeX 插件获得更加全面的AI编程体验,详情见[CodeGeeX主页](https://codegeex.cn/)。
33 |
34 |
35 | ## 快速开始
36 |
37 | ### 使用`transformers`快速调用[CodeGeeX2-6B](https://huggingface.co/THUDM/codegeex2-6b):
38 |
39 | ```python
40 | from transformers import AutoTokenizer, AutoModel
41 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
42 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda')
43 | model = model.eval()
44 |
45 | # remember adding a language tag for better performance
46 | prompt = "# language: Python\n# write a bubble sort function\n"
47 | inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
48 | outputs = model.generate(inputs, max_length=256, top_k=1)
49 | response = tokenizer.decode(outputs[0])
50 |
51 | >>> print(response)
52 | # language: Python
53 | # write a bubble sort function
54 |
55 |
56 | def bubble_sort(list):
57 | for i in range(len(list) - 1):
58 | for j in range(len(list) - 1):
59 | if list[j] > list[j + 1]:
60 | list[j], list[j + 1] = list[j + 1], list[j]
61 | return list
62 |
63 |
64 | print(bubble_sort([5, 2, 1, 8, 4]))
65 | ```
66 |
67 | ### 启动 Gradio DEMO:
68 | ```
69 | python ./demo/run_demo.py
70 |
71 | usage: run_demo.py [-h] [--model-path MODEL_PATH] [--example-path EXAMPLE_PATH] [--quantize QUANTIZE]
72 | [--chatglm-cpp] [--fastllm] [--n-gpus N_GPUS] [--gpu GPU] [--cpu] [--auth] [--username yourname]
73 | [--password yourpassword]
74 | [--port PORT] [--listen ADDRESS]
75 |
76 | # 若要启用身份验证,请先启用--auth,然后定义--username与--password,如:
77 | python run_demo.py --auth --username user --password password # 若要监听所有地址请指定 --listen 0.0.0.0
78 | ```
79 | 支持使用 [ChatGLM.cpp](https://github.com/li-plus/chatglm.cpp) 量化推理加速:
80 | ```sh
81 | python ./demo/run_demo.py --quantize 4 --chatglm-cpp
82 | ```
83 | ### 启动FAST API:
84 | ```
85 | python ./demo/fastapicpu.py
86 | usage: fastapicpu.py [-h] [--model-path MODEL_PATH] [--listen ADDRESS] [--port PORT] [--workders NUM] [--cpu] [--half] [--quantize QUANTIZE] [--chatglm-cpp]
87 | # --cpu启用cpu --half启用.half()
88 | ```
89 | 支持使用 [ChatGLM.cpp](https://github.com/li-plus/chatglm.cpp) 量化推理加速,同样添加 `--quantize 4 --chatglm-cpp` 参数即可。
90 | ### API使用示例
91 | ```
92 | curl -X POST "http://127.0.0.1:7860" \
93 | -H 'Content-Type: application/json' \
94 | -d '{"lang": "Python", "prompt": "# Write a quick sort function"}'
95 | ```
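
也可以用 Python 直接调用接口。下面是一个最小示例(假设服务已按上文在本地 7860 端口启动;返回的 JSON 字段以 `fastapicpu.py` 的实际实现为准,这里仅原样打印):

```python
import requests

# 假设:FastAPI 服务运行在 http://127.0.0.1:7860(与上面的 curl 示例一致)
resp = requests.post(
    "http://127.0.0.1:7860",
    json={"lang": "Python", "prompt": "# Write a quick sort function"},
)
resp.raise_for_status()
print(resp.json())  # 字段名以实际接口返回为准
```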
96 |
97 |
98 | ❗️请注意:
99 | * CodeGeeX2-6B 是一个基座代码生成模型,不具备聊天能力。请前往插件中体验更全面的 Ask CodeGeeX 聊天功能。
100 | * 在使用 CodeGeeX2-6B 的补全功能时,输入prompt需要遵循特定的格式以获得最好的效果。比如需要在开头加入编程语言标签(`# language: Python`,请查看[完整语言列表](https://github.com/THUDM/CodeGeeX2/blob/main/evaluation/utils.py#L14)),以注释的形式写prompt等。参考`run_demo.py`中的处理。
101 | * 如果显卡不支持`bfloat16`格式,将会输出错误的内容,需要将模型转换成`float16`格式:
102 | ```python
103 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True).half().cuda()
104 | ```
105 | * 如果需要使用多显卡加载模型,可以将以下代码:
106 | ```python
107 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
108 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda')
109 | model = model.eval()
110 | ```
111 | 替换为
112 |
113 | ```python
114 | def get_model():
115 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
116 | from gpus import load_model_on_gpus
117 | # gpus文件在demo文件夹中
118 | model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=2)
119 | model = model.eval()
120 | return tokenizer, model
121 |
122 | tokenizer, model = get_model()
123 | ```
124 |
125 | ## 代码能力评测
126 |
127 | CodeGeeX2 作为一个多语言代码生成基座模型,代码能力较上一代大幅提升,以下是在 HumanEval,HumanEval-X, DS1000 基准上的评测结果(评价指标 Pass@k 定义与[论文](https://arxiv.org/abs/2303.17568)中一致):
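
作为参考,Pass@k 常用的无偏估计为:对每道题生成 n 个样本、其中 c 个通过全部测试,则估计值为 1 - C(n-c, k)/C(n, k)。下面给出一个示意实现(仅作说明,实际评测请以本仓库 `evaluation/` 目录中的脚本为准):

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Pass@k 的无偏估计:1 - C(n-c, k) / C(n, k)。"""
    if n - c < k:
        return 1.0
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

# 示例:某题生成 n=20 个样本,其中 c=3 个通过,则 Pass@1 估计约为 0.15
print(pass_at_k(20, 3, 1))
```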
128 |
129 | ### HumanEval (Pass@1,10,100)
130 |
131 | | **Model** | **Pass@1** | **Pass@10** | **Pass@100** |
132 | | :-----------------: | :--------: | :---------: | :----------: |
133 | | CodeGen-16B-multi | 19\.2 | 34\.6 | 55\.2 |
134 | | CodeGeeX-13B | 22\.9 | 39\.6 | 60\.9 |
135 | | Codex-12B | 28\.8 | 46\.8 | 72\.3 |
136 | | CodeT5Plus-16B-mono | 30\.9 | 51\.6 | 76\.7 |
137 | | Code-Cushman-001 | 33\.5 | 54\.3 | 77\.4 |
138 | | LLaMA-65B | 23\.7 | - | 79\.3 |
139 | | LLaMA2-70B | 29\.9 | - | - |
140 | | CodeGen2\.5-7B-mono | 33\.4 | 58\.4 | 82\.7 |
141 | | StarCoder-15B | 33\.2 | 61\.0 | 84\.7 |
142 | | **CodeGeeX2-6B** | **35\.9** | **62\.6** | **88\.3** |
143 | > **Pass@1** 使用 `n=20, t=0.2, top_p=0.95`;**Pass@10,Pass@100** 使用 `n=200, t=0.8, top_p=0.95`。
144 |
145 | ### HumanEval-X (Pass@1)
146 |
147 | | **Model** | **Python** | **C++** | **Java** | **JavaScript** | **Go** | **Rust** | **Overall** |
148 | | :------------------: | :--------: | :-------: | :-------: | :------------: | :-------: | :-------: | :---------: |
149 | | CodeGen-16B-multi | 19\.2 | 18\.1 | 15\.0 | 18\.4 | 13\.0 | 1\.8 | 14\.2 |
150 | | CodeGeeX-13B | 22\.9 | 17\.1 | 20\.0 | 17\.6 | 14\.4 | 4\.3 | 16\.0 |
151 | | Replit-code-v1-3B | 22\.0 | 20\.1 | 20\.1 | 20\.1 | 12\.2 | 8\.6 | 17\.2 |
152 | | CodeGen2\.5-7B-multi | 30\.6 | 24\.3 | 29\.0 | 27\.5 | 18\.9 | **20\.1** | 25\.1 |
153 | | StarCoder-15B | 35\.5 | 28\.2 | **31\.5** | **33\.2** | 21\.3 | 17\.8 | 27\.9 |
154 | | **CodeGeeX2-6B** | **35\.9** | **29\.3** | 30\.8 | 32\.2 | **22\.5** | 18\.1 | **28\.1** |
155 | > **Pass@1** 使用 `n=20, t=0.2, top_p=0.95`。
156 |
157 | 以上结果可使用脚本`scripts/run_humanevalx.sh`复现。环境配置和说明参见[评测环境](https://github.com/THUDM/CodeGeeX/blob/main/codegeex/benchmark/README_zh.md)。
158 |
159 | ### DS1000 (Pass@1)
160 |
161 | | **Model** | **Matplotlib** | **Numpy** | **Pandas** | **Pytorch** | **SciPy** | **Scikit-learn** | **TensorFlow** | **Overall** |
162 | | :--------------: | :------------: | :-------: | :--------: | :---------: | :-------: | :--------------: | :------------: | :---------: |
163 | | \# Samples | 155 | 220 | 291 | 68 | 106 | 115 | 45 | 1000 |
164 | | CodeGen-16B-Mono | 31\.7 | 10\.9 | 3\.4 | 7\.0 | 9\.0 | 10\.8 | 15\.2 | 11\.7 |
165 | | code-cushman-001 | 40\.7 | 21\.8 | 7\.9 | 12\.4 | 11\.3 | 18\.0 | 12\.2 | 18\.1 |
166 | | Codex-001 | 41\.8 | 26\.6 | 9\.4 | 9\.7 | 15\.0 | 18\.5 | 17\.2 | 20\.2 |
167 | | **CodeGeeX2-6B** | 40\.5 | 25\.5 | 14\.5 | 17\.3 | 19\.3 | 24\.0 | 23\.0 | 23\.1 |
168 | | StarCoder-15B | 51\.7 | 29\.7 | 11\.4 | 21\.4 | 20\.2 | 29\.5 | 24\.5 | 26\.0 |
169 | | Codex-002 | **57\.0** | **43\.1** | **26\.5** | **41\.8** | **31\.8** | **44\.8** | **39\.3** | **39\.2** |
170 | > **Pass@1** 使用 `n=40, t=0.2, top_p=0.5`。
171 |
172 | 以上结果可使用[DS1000评测代码](https://github.com/HKUNLP/DS-1000.git)复现。
173 |
174 | ## 量化推理性能
175 |
176 | CodeGeeX2 与上一代相比,对部署更加友好。得益于使用 Multi-Query Attention 和 Flash Attention,推理速度更快,且量化后仅需6GB显存即可运行:
177 |
178 | ### 量化
179 |
180 | | **Model** | FP16/BF16 | INT8 | INT4 |
181 | | :--------------: | :-------: | :-----: | :----: |
182 | | CodeGeeX-13B | 26\.9 GB | 14\.7 GB | - |
183 | | **CodeGeeX2-6B** | 13\.1 GB | 8\.2 GB | 5\.5 GB |
184 | > 基于 PyTorch 2.0 测试,利用`torch.nn.functional.scaled_dot_product_attention`实现高效的 Attention 计算。
185 |
186 | ### 推理
187 |
188 | | **Model** | **推理速度 (字符/秒)** |
189 | | :--------------: | :-------------: |
190 | | CodeGeeX-13B | 32 |
191 | | **CodeGeeX2-6B** | 94 |
192 | > `batch_size=1, max_length=2048`,均使用加速框架,测试硬件为`GeForce RTX-3090`。
193 |
194 | ## 协议
195 |
196 | 本仓库的代码依照 [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 协议开源,模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。CodeGeeX2-6B 权重对学术研究完全开放,填写[登记表](https://open.bigmodel.cn/mla/form?mcode=CodeGeeX2-6B)申请商业使用。
197 |
198 |
199 | ## 引用
200 |
201 | 如果觉得我们的工作有帮助,欢迎引用以下论文:
202 |
203 | ```
204 | @inproceedings{zheng2023codegeex,
205 | title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
206 | author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
207 | booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
208 | pages={5673--5684},
209 | year={2023}
210 | }
211 | ```
212 |
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 |
4 | 🏠 Homepage|🛠 Extensions VS Code, Jetbrains|🤗 HF Repo|📄 Paper
5 |
6 |
7 |
8 | 👋 Join our Discord, Slack, Telegram, WeChat
9 |
10 |
11 | 查看[中文版](README.md)
12 | [日本語](README_JA.md)で読む
13 | Lire en [Français](README_FR.md)
14 |
15 | # CodeGeeX2: A More Powerful Multilingual Code Generation Model
16 |
17 | CodeGeeX2 is the second-generation model of the multilingual code generation model [CodeGeeX](https://github.com/THUDM/CodeGeeX) ([KDD’23](https://arxiv.org/abs/2303.17568)), implemented on the [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) architecture and further trained on more code data. Thanks to the stronger ChatGLM2 base, CodeGeeX2 improves coding capability across the board (+107% over CodeGeeX; with only 6B parameters it surpasses the larger StarCoder-15B on some tasks). It has the following features:
18 |
19 | * **More Powerful Coding Capabilities**: Based on the ChatGLM2-6B model, CodeGeeX2-6B has been further pre-trained on 600B code tokens, comprehensively improving its coding capability over the first generation. On the [HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x) benchmark, all six languages improve significantly (Python +57%, C++ +71%, Java +54%, JavaScript +83%, Go +56%, Rust +321\%), and on Python it reaches a 35.9% Pass@1 rate, surpassing the larger StarCoder-15B.
20 | * **More Useful Features**: Inheriting the features of ChatGLM2-6B, CodeGeeX2-6B better supports both Chinese and English prompts, handles a maximum sequence length of 8192, and infers significantly faster than the first generation. After quantization it needs only 6GB of GPU memory for inference, enabling lightweight local deployment.
21 | * **Comprehensive AI Coding Assistant**: The backend of the CodeGeeX plugin ([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex)) is upgraded, supporting 100+ programming languages and adding practical functions such as infilling and cross-file completion. Combined with the "Ask CodeGeeX" interactive AI coding assistant, it can solve various programming problems via Chinese or English dialogue, including but not limited to code summarization, code translation, debugging, and comment generation, helping developers work more efficiently.
22 | * **Open License**: CodeGeeX2-6B weights are fully open for academic research; for commercial use, please apply by filling in the [registration form](https://open.bigmodel.cn/mla/form?mcode=CodeGeeX2-6B).
23 |
24 |
25 | ## AI Coding Assistant
26 |
27 | 
28 |
29 | We have developed the CodeGeeX plugin, which supports IDEs such as VS Code, IntelliJ IDEA, PyCharm, GoLand, WebStorm, and Android Studio. The plugin allows you to experience the CodeGeeX2 model's capabilities in code generation and completion, annotation, code translation, and "Ask CodeGeeX" interactive programming, which can help improve your development efficiency. Please download the CodeGeeX plugin in your IDE to get a more comprehensive AI coding experience. You can find more details on our [homepage]( https://codegeex.cn/).
30 |
31 | ## Get Started
32 |
33 | Use `transformers` to quickly launch [CodeGeeX2-6B](https://huggingface.co/THUDM/codegeex2-6b):
34 |
35 | ```python
36 | from transformers import AutoTokenizer, AutoModel
37 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
38 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda')
39 | model = model.eval()
40 |
41 | # remember adding a language tag for better performance
42 | prompt = "# language: Python\n# write a bubble sort function\n"
43 | inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
44 | outputs = model.generate(inputs, max_length=256, top_k=1)
45 | response = tokenizer.decode(outputs[0])
46 |
47 | >>> print(response)
48 | # language: Python
49 | # write a bubble sort function
50 |
51 |
52 | def bubble_sort(list):
53 | for i in range(len(list) - 1):
54 | for j in range(len(list) - 1):
55 | if list[j] > list[j + 1]:
56 | list[j], list[j + 1] = list[j + 1], list[j]
57 | return list
58 |
59 |
60 | print(bubble_sort([5, 2, 1, 8, 4]))
61 | ```
62 |
63 | Launch Gradio DEMO:
64 | ```
65 | python ./demo/run_demo.py
66 | ```
67 |
68 | ❗️Attention:
69 | * CodeGeeX2 is a base model that is not instruction-tuned for chatting. It can do tasks like code completion/translation/explanation. To try the instruction-tuned version, use the CodeGeeX plugins ([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex)).
70 | * The programming language can be controlled by adding a language tag, e.g., `# language: Python`. The format should be respected to ensure performance; the full list can be found [here](https://github.com/THUDM/CodeGeeX2/blob/main/evaluation/utils.py#L14). Write comments in the comment syntax of the selected programming language for better results.
71 | * If the GPU doesn't support the `bfloat16` format, the output will be incorrect, so convert the model to `float16` format (see the dtype-check sketch after this list):
72 | ```python
73 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True).half().cuda()
74 | ```
75 | * If you need to load the model on multiple GPUs, replace the following code:
76 | ```python
77 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
78 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda')
79 | model = model.eval()
80 | ```
81 | with:
82 |
83 | ```python
84 | def get_model():
85 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
86 | from gpus import load_model_on_gpus
87 | # The "gpus" file is located in the demo folder
88 | model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=2)
89 | model = model.eval()
90 | return tokenizer, model
91 |
92 | tokenizer, model = get_model()
93 | ```
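
If you are unsure whether your GPU supports `bfloat16`, here is a minimal sketch (illustrative only, not part of the official demo) that picks a dtype at load time:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Fall back to float16 on GPUs without bfloat16 support (e.g., pre-Ampere cards)
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "THUDM/codegeex2-6b", trust_remote_code=True, torch_dtype=dtype
).cuda().eval()
```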
94 |
95 | ## Evaluation
96 |
97 | CodeGeeX2 is a base model for multilingual code generation, which has been significantly improved in its coding ability compared to the previous generation. The following are the evaluation results on the HumanEval, HumanEval-X, and DS1000 benchmarks (the evaluation metric Pass@k is the same as in the [paper](https://arxiv.org/abs/2303.17568)):
98 |
99 | ### HumanEval (Pass@1,10,100)
100 |
101 | | **Model** | **Pass@1** | **Pass@10** | **Pass@100** |
102 | | :-----------------: | :--------: | :---------: | :----------: |
103 | | CodeGen-16B-multi | 19\.2 | 34\.6 | 55\.2 |
104 | | CodeGeeX-13B | 22\.9 | 39\.6 | 60\.9 |
105 | | Codex-12B | 28\.8 | 46\.8 | 72\.3 |
106 | | CodeT5Plus-16B-mono | 30\.9 | 51\.6 | 76\.7 |
107 | | Code-Cushman-001 | 33\.5 | 54\.3 | 77\.4 |
108 | | LLaMA-65B | 23\.7 | - | 79\.3 |
109 | | LLaMA2-70B | 29\.9 | - | - |
110 | | CodeGen2\.5-7B-mono | 33\.4 | 58\.4 | 82\.7 |
111 | | StarCoder-15B | 33\.2 | 61\.0 | 84\.7 |
112 | | **CodeGeeX2-6B** | **35\.9** | **62\.6** | **88\.3** |
113 | > `n=20, t=0.2, top_p=0.95` for **Pass@1**; `n=200, t=0.8, top_p=0.95` for **Pass@10** and **Pass@100**.
114 |
115 | ### HumanEval-X (Pass@1)
116 |
117 | | **Model** | **Python** | **C++** | **Java** | **JavaScript** | **Go** | **Rust** | **Overall** |
118 | | :------------------: | :--------: | :-------: | :-------: | :------------: | :-------: | :-------: | :---------: |
119 | | CodeGen-16B-multi | 19\.2 | 18\.1 | 15\.0 | 18\.4 | 13\.0 | 1\.8 | 14\.2 |
120 | | CodeGeeX-13B | 22\.9 | 17\.1 | 20\.0 | 17\.6 | 14\.4 | 4\.3 | 16\.0 |
121 | | Replit-code-v1-3B | 22\.0 | 20\.1 | 20\.1 | 20\.1 | 12\.2 | 8\.6 | 17\.2 |
122 | | CodeGen2\.5-7B-multi | 30\.6 | 24\.3 | 29\.0 | 27\.5 | 18\.9 | **20\.1** | 25\.1 |
123 | | StarCoder-15B | 35\.5 | 28\.2 | **31\.5** | **33\.2** | 21\.3 | 17\.8 | 27\.9 |
124 | | **CodeGeeX2-6B** | **35\.9** | **29\.3** | 30\.8 | 32\.2 | **22\.5** | 18\.1 | **28\.1** |
125 | > `n=20, t=0.2, top_p=0.95` for **Pass@1**.
126 |
127 | The above results can be reproduced by running `scripts/run_humanevalx.sh`. Refer to [HumanEval-X environment](https://github.com/THUDM/CodeGeeX/blob/main/codegeex/benchmark/README_zh.md) for the experiment setups.
128 |
129 | ### DS1000 (Pass@1)
130 |
131 | | **Model** | **Matplotlib** | **Numpy** | **Pandas** | **Pytorch** | **SciPy** | **Scikit-learn** | **TensorFlow** | **Overall** |
132 | | :--------------: | :------------: | :-------: | :--------: | :---------: | :-------: | :--------------: | :------------: | :---------: |
133 | | \# Samples | 155 | 220 | 291 | 68 | 106 | 115 | 45 | 1000 |
134 | | CodeGen-16B-Mono | 31\.7 | 10\.9 | 3\.4 | 7\.0 | 9\.0 | 10\.8 | 15\.2 | 11\.7 |
135 | | code-cushman-001 | 40\.7 | 21\.8 | 7\.9 | 12\.4 | 11\.3 | 18\.0 | 12\.2 | 18\.1 |
136 | | Codex-001 | 41\.8 | 26\.6 | 9\.4 | 9\.7 | 15\.0 | 18\.5 | 17\.2 | 20\.2 |
137 | | **CodeGeeX2-6B** | 40\.5 | 25\.5 | 14\.5 | 17\.3 | 19\.3 | 24\.0 | 23\.0 | 23\.1 |
138 | | StarCoder-15B | 51\.7 | 29\.7 | 11\.4 | 21\.4 | 20\.2 | 29\.5 | 24\.5 | 26\.0 |
139 | | Codex-002 | **57\.0** | **43\.1** | **26\.5** | **41\.8** | **31\.8** | **44\.8** | **39\.3** | **39\.2** |
140 | > `n=40, t=0.2, top_p=0.5` for **Pass@1**.
141 |
142 | The above results can be reproduced by the code in [DS1000 repo](https://github.com/HKUNLP/DS-1000.git).
143 |
144 | ## Inference
145 |
146 | CodeGeeX2 is more deployment-friendly than the previous generation. Thanks to Multi-Query Attention and Flash Attention, inference is faster, and only 6GB of GPU memory is required after INT4 quantization.
147 |
148 | ### Quantization
149 |
150 | | **Model** | FP16/BF16 | INT8 | INT4 |
151 | | :--------------: | :-------: | :-----: | :----: |
152 | | CodeGeeX-13B | 26\.9 GB | 14\.7 GB | - |
153 | | **CodeGeeX2-6B** | 13\.1 GB | 8\.2 GB | 5\.5 GB |
154 | > Based on PyTorch 2.0, using `torch.nn.functional.scaled_dot_product_attention` for efficient attention computation.
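
To fit the INT8/INT4 memory budgets above, here is a minimal sketch assuming the model's remote code exposes a ChatGLM2-style `quantize()` helper (if your checkpoint does not provide it, follow the instructions on the model card instead):

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
# Assumption: quantize(4) performs weight-only INT4 quantization, as in ChatGLM2-6B
model = model.quantize(4).cuda().eval()
```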
155 |
156 | ### Acceleration
157 |
158 | | **Model** | **Inference speed (token/s)** |
159 | | :--------------: | :-------------: |
160 | | CodeGeeX-13B | 32 |
161 | | **CodeGeeX2-6B** | 94 |
162 | > `batch_size=1, max_length=2048`, both with an acceleration framework, measured on a `GeForce RTX-3090`.
163 |
164 | ## License
165 |
166 | The code in this repository is open source under the [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) license. The model weights are licensed under the [Model License](MODEL_LICENSE). CodeGeeX2-6B weights are open for academic research; for commercial use, please apply by filling in the [registration form](https://open.bigmodel.cn/mla/form?mcode=CodeGeeX2-6B).
167 |
168 |
169 | ## Citation
170 |
171 | If you find our work helpful, please feel free to cite the following paper:
172 |
173 | ```
174 | @inproceedings{zheng2023codegeex,
175 | title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
176 | author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
177 | booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
178 | pages={5673--5684},
179 | year={2023}
180 | }
181 | ```
182 |
--------------------------------------------------------------------------------
/README_FR.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 |
4 | 🏠 Homepage|🛠 Extensions VS Code, Jetbrains|🤗 HF Repo|📄 Paper
5 |
6 |
7 |
8 | 👋 Rejoignez nous sur Discord, Slack, Telegram, WeChat
9 |
10 |
11 | 查看[中文版](README.md)
12 | Read this in [English](README_EN.md)
13 | [日本語](README_JA.md)で読む
14 |
15 | # CodeGeeX2: Un Modèle de Génération de Code Plus Puissant
16 |
17 | CodeGeeX2 est la deuxième itération du modèle de génération de code multilingue [CodeGeeX](https://github.com/THUDM/CodeGeeX) ([KDD’23](https://arxiv.org/abs/2303.17568)), basé sur [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) et entraîné sur un large corpus de code. Grâce à l'architecture ChatGLM2, CodeGeeX2 excelle sur une multitude de tâches de génération de code (+107% > CodeGeeX; avec seulement 6 milliards de paramètres, dépassant StarCoder-15B pour certaines tâches). CodeGeeX2 possède les fonctionnalités suivantes:
18 |
19 | * **Capacités de Génération de Code Accrues**: Basé sur ChatGLM2-6B, CodeGeeX2-6B a été entraîné sur un dataset de 600 milliards de tokens supplémentaires, ce qui a propulsé ses capacités de génération de code par rapport à la génération précédente. Sur [HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x), le modèle opère bien mieux que son prédécesseur (Python +57%, C++ +71%, Java +54%, JavaScript +83%, Go +56%, Rust +321\%). En Python, CodeGeeX2 atteint un score Pass@1 de 35.9%, surpassant StarCoder-15B malgré le fait que CodeGeeX2 ait ~3 fois moins de paramètres.
20 | * **Des Fonctionnalités Plus Utiles**: Héritant des fonctionnalités de ChatGLM2-6B, CodeGeeX2-6B prend mieux en charge les prompts en chinois et en anglais, peut ingérer jusqu'à 8192 tokens, et se dote d'une vitesse de génération en inférence fortement accrue par rapport à la génération précédente. Après quantisation, CodeGeeX2 fonctionne sur un GPU avec >6GB de mémoire, permettant un déploiement local efficace.
21 | * **Un Assistant Intelligent dans votre Éditeur**: Les plugins ([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), et [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex)) ont été mis à jour et sont maintenant compatibles avec plus de 100 langages de programmation. Le modèle, couplé à l'extension, permet désormais aux utilisateurs de générer du code pour plusieurs fichiers ainsi que de générer et modifier des sections de code. CodeGeeX2 est maintenant capable de résoudre de nombreux problèmes de programmation. Les utilisateurs peuvent profiter de la fonctionnalité "Ask CodeGeeX" pour discuter de manière interactive avec un assistant IA afin de résumer et d'expliquer du code, traduire du code entre langages, rajouter des commentaires, etc. CodeGeeX permet de maximiser la productivité de ses utilisateurs.
22 | * **License Open-Source**: Les poids du modèle CodeGeeX2-6B sont en accès libre pour toute utilisation dans le cadre de la recherche. Pour toute utilisation commerciale, merci de consulter ce [formulaire](https://open.bigmodel.cn/mla/form?mcode=CodeGeeX2-6B).
23 |
24 |
25 | ## Assistant Intelligent
26 |
27 | 
28 |
29 | Nous avons développé une extension pour VS Code, IntelliJ IDEA, PyCharm, GoLand, WebStorm, and Android Studio. L'extension permet de profiter des capacités du modèle CodeGeeX2 et de générer, annoter et traduire du code. La fonctionnalité "Ask CodeGeeX" permet de coder de manière interactive et améliore grandement votre productivité. Téléchargez l'extension CodeGeeX dans votre IDE pour une meilleure expérience de développement. Trouvez plus de détail sur notre [site]( https://codegeex.cn/).
30 |
31 | ## Utilisation
32 |
33 | Pour exécuter [CodeGeeX2-6B](https://huggingface.co/THUDM/codegeex2-6b), utilisez la librairie `transformers`:
34 |
35 | ```python
36 | from transformers import AutoTokenizer, AutoModel
37 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
38 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda')
39 | model = model.eval()
40 |
41 | # TIP: Utilisez un tag pour identifier le langage dans lequel vous souhaitez générer.
42 | prompt = "# language: Python\n# write a bubble sort function\n"
43 | inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
44 | outputs = model.generate(inputs, max_length=256, top_k=1)
45 | response = tokenizer.decode(outputs[0])
46 |
47 | >>> print(response)
48 | # language: Python
49 | # write a bubble sort function
50 |
51 |
52 | def bubble_sort(list):
53 | for i in range(len(list) - 1):
54 | for j in range(len(list) - 1):
55 | if list[j] > list[j + 1]:
56 | list[j], list[j + 1] = list[j + 1], list[j]
57 | return list
58 |
59 |
60 | print(bubble_sort([5, 2, 1, 8, 4]))
61 | ```
62 |
63 | Accéder à la démo Gradio:
64 | ```
65 | python ./demo/run_demo.py
66 | ```
67 |
68 | ❗️Attention:
69 | * Cette version de CodeGeeX2 est capable de compléter / expliquer / traduire du code mais n'a pas été fine-tuned pour être utilisé comme un chatbot. Pour accéder à la version chatbot de CodeGeeX, utilisez les extensions [VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex) et [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex).
70 | * Pour contrôler le langage dans lequel CodeGeeX2 opère, utilisez des tags formatés ainsi : `# language: Python`. La liste de tous les langages de programmation que CodeGeeX2 supporte est accessible [ici](https://github.com/THUDM/CodeGeeX2/blob/main/evaluation/utils.py#L14).
71 | * Si vous avez besoin d'utiliser plusieurs GPU pour charger le modèle, remplacez le code suivant :
72 | ```python
73 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
74 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda')
75 | model = model.eval()
76 | ```
77 | par :
78 |
79 | ```python
80 | def get_model():
81 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
82 | from gpus import load_model_on_gpus
83 | # Le fichier "gpus" se trouve dans le dossier de démonstration
84 | model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=2)
85 | model = model.eval()
86 | return tokenizer, model
87 |
88 | tokenizer, model = get_model()
89 | ```
90 |
91 | ## Evaluation
92 |
93 | CodeGeeX2 est un modèle de base capable de générer du code en plusieurs langages de programmation et qui est bien plus performant que la version précédente. Voici les capacités de CodeGeeX sur les benchmarks HumanEval, HumanEval-X, et DS1000 (la métrique Pass@k est la même que celle décrite dans ce [papier](https://arxiv.org/abs/2303.17568)):
94 |
95 | ### HumanEval (Pass@1,10,100)
96 |
97 | | **Model** | **Pass@1** | **Pass@10** | **Pass@100** |
98 | | :-----------------: | :--------: | :---------: | :----------: |
99 | | CodeGen-16B-multi | 19\.2 | 34\.6 | 55\.2 |
100 | | CodeGeeX-13B | 22\.9 | 39\.6 | 60\.9 |
101 | | Codex-12B | 28\.8 | 46\.8 | 72\.3 |
102 | | CodeT5Plus-16B-mono | 30\.9 | 51\.6 | 76\.7 |
103 | | Code-Cushman-001 | 33\.5 | 54\.3 | 77\.4 |
104 | | LLaMA-65B | 23\.7 | - | 79\.3 |
105 | | LLaMA2-70B | 29\.9 | - | - |
106 | | CodeGen2\.5-7B-mono | 33\.4 | 58\.4 | 82\.7 |
107 | | StarCoder-15B | 33\.2 | 61\.0 | 84\.7 |
108 | | **CodeGeeX2-6B** | **35\.9** | **62\.6** | **88\.3** |
109 | > `n=20, t=0.2, top_p=0.95` pour **Pass@1**; `n=200, t=0.8, top_p=0.95` pour **Pass@10** et **Pass@100**.
110 |
111 | ### HumanEval-X (Pass@1)
112 |
113 | | **Model** | **Python** | **C++** | **Java** | **JavaScript** | **Go** | **Rust** | **Overall** |
114 | | :------------------: | :--------: | :-------: | :-------: | :------------: | :-------: | :-------: | :---------: |
115 | | CodeGen-16B-multi | 19\.2 | 18\.1 | 15\.0 | 18\.4 | 13\.0 | 1\.8 | 14\.2 |
116 | | CodeGeeX-13B | 22\.9 | 17\.1 | 20\.0 | 17\.6 | 14\.4 | 4\.3 | 16\.0 |
117 | | Replit-code-v1-3B | 22\.0 | 20\.1 | 20\.1 | 20\.1 | 12\.2 | 8\.6 | 17\.2 |
118 | | CodeGen2\.5-7B-multi | 30\.6 | 24\.3 | 29\.0 | 27\.5 | 18\.9 | **20\.1** | 25\.1 |
119 | | StarCoder-15B | 35\.5 | 28\.2 | **31\.5** | **33\.2** | 21\.3 | 17\.8 | 27\.9 |
120 | | **CodeGeeX2-6B** | **35\.9** | **29\.3** | 30\.8 | 32\.2 | **22\.5** | 18\.1 | **28\.1** |
121 | > `n=20, t=0.2, top_p=0.95` pour **Pass@1**.
122 |
123 | Les résultats ci-dessus peuvent être reproduits avec le script `scripts/run_humanevalx.sh`. Les environnements utilisés sont renseignés [ici](https://github.com/THUDM/CodeGeeX/blob/main/codegeex/benchmark/README_zh.md).
124 |
125 | ### DS1000 (Pass@1)
126 |
127 | | **Model** | **Matplotlib** | **Numpy** | **Pandas** | **Pytorch** | **SciPy** | **Scikit-learn** | **TensorFlow** | **Overall** |
128 | | :--------------: | :------------: | :-------: | :--------: | :---------: | :-------: | :--------------: | :------------: | :---------: |
129 | | \# Samples | 155 | 220 | 291 | 68 | 106 | 115 | 45 | 1000 |
130 | | CodeGen-16B-Mono | 31\.7 | 10\.9 | 3\.4 | 7\.0 | 9\.0 | 10\.8 | 15\.2 | 11\.7 |
131 | | code-cushman-001 | 40\.7 | 21\.8 | 7\.9 | 12\.4 | 11\.3 | 18\.0 | 12\.2 | 18\.1 |
132 | | Codex-001 | 41\.8 | 26\.6 | 9\.4 | 9\.7 | 15\.0 | 18\.5 | 17\.2 | 20\.2 |
133 | | **CodeGeeX2-6B** | 40\.5 | 25\.5 | 14\.5 | 17\.3 | 19\.3 | 24\.0 | 23\.0 | 23\.1 |
134 | | StarCoder-15B | 51\.7 | 29\.7 | 11\.4 | 21\.4 | 20\.2 | 29\.5 | 24\.5 | 26\.0 |
135 | | Codex-002 | **57\.0** | **43\.1** | **26\.5** | **41\.8** | **31\.8** | **44\.8** | **39\.3** | **39\.2** |
136 | > `n=40, t=0.2, top_p=0.5` pour **Pass@1**.
137 |
138 | Les résultats ci-dessus peuvent être reproduits avec le code présent sur le repository [HKUNLP/DS-1000](https://github.com/HKUNLP/DS-1000.git).
139 |
140 | ## Inference
141 |
142 | CodeGeeX2 est bien plus simple à déployer que la génération précédente. L'utilisation de "Multi-Query Attention" et "Flash Attention" accélère grandement la vitesse de génération et le modèle n'a besoin que de 6GB de mémoire après avoir été quantisé en INT4.
143 |
144 | ### Quantisation
145 |
146 | | **Model** | FP16/BF16 | INT8 | INT4 |
147 | | :--------------: | :-------: | :-----: | :----: |
148 | | CodeGeeX-13B | 26\.9 GB | 14\.7 GB | - |
149 | | **CodeGeeX2-6B** | 13\.1 GB | 8\.2 GB | 5\.5 GB |
150 | > Résultats obtenus avec PyTorch 2.0, avec `torch.nn.functional.scaled_dot_product_attention` qui est une version plus rapide du calcul de l'attention.
151 |
152 | ### Accélération
153 |
154 | | **Model** | **Inference speed (token/s)** |
155 | | :--------------: | :-------------: |
156 | | CodeGeeX-13B | 32 |
157 | | **CodeGeeX2-6B** | 94 |
158 | > `batch_size=1, max_length=2048`, les deux avec un framework d'accélération, sur une `GeForce RTX-3090`.
159 |
160 | ## License
161 |
162 | Le code dans ce dépôt est en libre accès selon les droits et devoirs prévus par la licence [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0). Les poids du modèle sont régis par la [licence du modèle](MODEL_LICENSE). Les poids du modèle CodeGeeX2-6B sont en accès libre pour toute utilisation dans le cadre de la recherche. Pour toute utilisation commerciale, merci de consulter ce [formulaire](https://open.bigmodel.cn/mla/form?mcode=CodeGeeX2-6B).
163 |
164 |
165 | ## Citation
166 |
167 | Si vous trouvez ce projet utile, n'hésitez pas à citer notre papier:
168 |
169 | ```
170 | @inproceedings{zheng2023codegeex,
171 | title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
172 | author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
173 | booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
174 | pages={5673--5684},
175 | year={2023}
176 | }
177 | ```
178 |
--------------------------------------------------------------------------------
/README_JA.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 |
4 | 🏠 ホームページ|🛠 拡張 VS Code, Jetbrains|🤗 HF Repo|📄 論文
5 |
6 |
7 |
8 | 👋 Discord に参加, Slack, Telegram, WeChat
9 |
10 |
11 | 查看[中文版](README.md)
12 | Read this in [English](README_EN.md)
13 | Lire en [Français](README_FR.md)
14 |
15 | # CodeGeeX2: より強力な多言語コード生成モデル
16 |
17 | CodeGeeX2 は、多言語コード生成モデル [CodeGeeX](https://github.com/THUDM/CodeGeeX)([KDD'23](https://arxiv.org/abs/2303.17568)) の第 2 世代モデルであり、より多くのコードデータで学習された [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) アーキテクチャに基づいて実装されています。ChatGLM2 のアドバンテージにより、CodeGeeX2 のコーディング能力は包括的に向上しています(+107% > CodeGeeX; わずか 6B のパラメータで、いくつかのタスクではより大規模な StarCoder-15B を凌駕しています)。以下の特徴があります:
18 |
19 | * **より強力なコーディング機能**: CodeGeeX2-6B は、ChatGLM2-6B モデルをベースに、さらに 600B のコードトークンに対して事前学習を行っており、第一世代と比較してコーディング能力が総合的に向上しています。[HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x) ベンチマークでは、6 言語すべてで大幅な改善が見られ(Python +57%、C++ +71%、Java +54%、JavaScript +83%、Go +56%、Rust +321%)、Python では Pass@1 一回合格率 35.9% に達し、より大規模な StarCoder-15B を上回りました。
20 | * **その他の便利な機能**: ChatGLM2-6B モデルの特徴を継承し、CodeGeeX2-6B は中国語と英語のプロンプト、最大 8192 シーケンス長をサポートし、推論速度は第一世代と比較して大幅に改善されています。量子化後、推論に必要な GPU メモリは 6GB のみで、軽量なローカル展開をサポートします。
21 | * **包括的な AI コーディングアシスタント**: CodeGeeX プラグイン([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex)、[Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex))のバックエンドがアップグレードされ、100 以上のプログラミング言語をサポートし、インフィルやクロスファイル補完などの実用的な機能が追加されました。対話型 AI コーディングアシスタント "Ask CodeGeeX" と組み合わせることで、中国語または英語の対話を通じて、コードの要約、コードの翻訳、デバッグ、コメント生成など、さまざまなプログラミング問題を解決することができ、開発者の作業効率を高めることができます。
22 | * **オープンライセンス**: CodeGeeX2-6B ウェイトは学術研究に全面的に開放しています。商用利用をご希望の方は、[登録フォーム](https://open.bigmodel.cn/mla/form?mcode=CodeGeeX2-6B)にご記入の上、お申し込みください。
23 |
24 |
25 | ## AI コーディングアシスタント
26 |
27 | 
28 |
29 | VS Code、IntelliJ IDEA、PyCharm、GoLand、WebStorm、Android Studio などの IDE をサポートする CodeGeeX プラグインを開発しました。このプラグインを使用することで、CodeGeeX2 モデルのコード生成と補完、アノテーション、コード変換、"Ask CodeGeeX" 対話型プログラミングなどの機能を体験することができ、開発効率を向上させることができます。より包括的な AI コーディング体験を得るために、IDE に CodeGeeX プラグインをダウンロードしてください。詳しくは[ホームページ](https://codegeex.cn/)をご覧ください。
30 |
31 | ## 始める
32 |
33 | [CodeGeeX2-6B](https://huggingface.co/THUDM/codegeex2-6b) を素早く起動するには、`transformers` を使用します:
34 |
35 | ```python
36 | from transformers import AutoTokenizer, AutoModel
37 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
38 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda')
39 | model = model.eval()
40 |
41 | # remember adding a language tag for better performance
42 | prompt = "# language: Python\n# write a bubble sort function\n"
43 | inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
44 | outputs = model.generate(inputs, max_length=256, top_k=1)
45 | response = tokenizer.decode(outputs[0])
46 |
47 | >>> print(response)
48 | # language: Python
49 | # write a bubble sort function
50 |
51 |
52 | def bubble_sort(list):
53 | for i in range(len(list) - 1):
54 | for j in range(len(list) - 1):
55 | if list[j] > list[j + 1]:
56 | list[j], list[j + 1] = list[j + 1], list[j]
57 | return list
58 |
59 |
60 | print(bubble_sort([5, 2, 1, 8, 4]))
61 | ```
62 |
63 | Gradio DEMO の起動:
64 | ```
65 | python ./demo/run_demo.py
66 | ```
67 |
68 | ❗️注意:
69 | * CodeGeeX2 はベースモデルであり、チャット用の命令チューニングはされていません。コード補完/翻訳/説明のようなタスクは可能です。CodeGeeX のプラグイン([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex))で命令チューニングされたバージョンを試すことができます。
70 | * プログラミング言語は、`# language: Python` のように `language tag` を追加することで制御できます。パフォーマンスを確保するため、書式を守る必要があります。完全なリストは[こちら](https://github.com/THUDM/CodeGeeX2/blob/main/evaluation/utils.py#L14)にあります。より良い結果を得るためには、選択したプログラミング言語のフォーマットでコメントを書いてください。
71 | * 複数のグラフィックカードを使用してモデルをロードする必要がある場合は、以下のコードを
72 | ```python
73 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
74 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda')
75 | model = model.eval()
76 | ```
77 | 次のコードに置き換えてください:
78 |
79 | ```python
80 | def get_model():
81 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
82 | from gpus import load_model_on_gpus
83 | # gpusファイルはdemoフォルダにあります
84 | model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=2)
85 | model = model.eval()
86 | return tokenizer, model
87 |
88 | tokenizer, model = get_model()
89 | ```
90 | ## 評価
91 |
92 | CodeGeeX2 は多言語コード生成のベースモデルであり、前世代と比較してコーディング能力が大幅に向上しています。HumanEval、HumanEval-X、DS1000 ベンチマークでの評価結果を以下に示します(評価指標 Pass@k は[論文](https://arxiv.org/abs/2303.17568)と同じです):
93 |
94 | ### HumanEval (Pass@1,10,100)
95 |
96 | | **Model** | **Pass@1** | **Pass@10** | **Pass@100** |
97 | | :-----------------: | :--------: | :---------: | :----------: |
98 | | CodeGen-16B-multi | 19\.2 | 34\.6 | 55\.2 |
99 | | CodeGeeX-13B | 22\.9 | 39\.6 | 60\.9 |
100 | | Codex-12B | 28\.8 | 46\.8 | 72\.3 |
101 | | CodeT5Plus-16B-mono | 30\.9 | 51\.6 | 76\.7 |
102 | | Code-Cushman-001 | 33\.5 | 54\.3 | 77\.4 |
103 | | LLaMA-65B | 23\.7 | - | 79\.3 |
104 | | LLaMA2-70B | 29\.9 | - | - |
105 | | CodeGen2\.5-7B-mono | 33\.4 | 58\.4 | 82\.7 |
106 | | StarCoder-15B | 33\.2 | 61\.0 | 84\.7 |
107 | | **CodeGeeX2-6B** | **35\.9** | **62\.6** | **88\.3** |
108 | > **Pass@1** は `n=20, t=0.2, top_p=0.95`、**Pass@10** および **Pass@100** は `n=200, t=0.8, top_p=0.95` を使用。
109 |
110 | ### HumanEval-X (Pass@1)
111 |
112 | | **Model** | **Python** | **C++** | **Java** | **JavaScript** | **Go** | **Rust** | **Overall** |
113 | | :------------------: | :--------: | :-------: | :-------: | :------------: | :-------: | :-------: | :---------: |
114 | | CodeGen-16B-multi | 19\.2 | 18\.1 | 15\.0 | 18\.4 | 13\.0 | 1\.8 | 14\.2 |
115 | | CodeGeeX-13B | 22\.9 | 17\.1 | 20\.0 | 17\.6 | 14\.4 | 4\.3 | 16\.0 |
116 | | Replit-code-v1-3B | 22\.0 | 20\.1 | 20\.1 | 20\.1 | 12\.2 | 8\.6 | 17\.2 |
117 | | CodeGen2\.5-7B-multi | 30\.6 | 24\.3 | 29\.0 | 27\.5 | 18\.9 | **20\.1** | 25\.1 |
118 | | StarCoder-15B | 35\.5 | 28\.2 | **31\.5** | **33\.2** | 21\.3 | 17\.8 | 27\.9 |
119 | | **CodeGeeX2-6B** | **35\.9** | **29\.3** | 30\.8 | 32\.2 | **22\.5** | 18\.1 | **28\.1** |
120 | > **Pass@1** は `n=20, t=0.2, top_p=0.95` を使用。
121 |
122 | 上記の結果は `scripts/run_humanevalx.sh` を実行することで再現できます。実験の設定は [HumanEval-X 環境](https://github.com/THUDM/CodeGeeX/blob/main/codegeex/benchmark/README_zh.md)を参照してください。
123 |
124 | ### DS1000 (Pass@1)
125 |
126 | | **Model** | **Matplotlib** | **Numpy** | **Pandas** | **Pytorch** | **SciPy** | **Scikit-learn** | **TensorFlow** | **Overall** |
127 | | :--------------: | :------------: | :-------: | :--------: | :---------: | :-------: | :--------------: | :------------: | :---------: |
128 | | \# Samples | 155 | 220 | 291 | 68 | 106 | 115 | 45 | 1000 |
129 | | CodeGen-16B-Mono | 31\.7 | 10\.9 | 3\.4 | 7\.0 | 9\.0 | 10\.8 | 15\.2 | 11\.7 |
130 | | code-cushman-001 | 40\.7 | 21\.8 | 7\.9 | 12\.4 | 11\.3 | 18\.0 | 12\.2 | 18\.1 |
131 | | Codex-001 | 41\.8 | 26\.6 | 9\.4 | 9\.7 | 15\.0 | 18\.5 | 17\.2 | 20\.2 |
132 | | **CodeGeeX2-6B** | 40\.5 | 25\.5 | 14\.5 | 17\.3 | 19\.3 | 24\.0 | 23\.0 | 23\.1 |
133 | | StarCoder-15B | 51\.7 | 29\.7 | 11\.4 | 21\.4 | 20\.2 | 29\.5 | 24\.5 | 26\.0 |
134 | | Codex-002 | **57\.0** | **43\.1** | **26\.5** | **41\.8** | **31\.8** | **44\.8** | **39\.3** | **39\.2** |
135 | > **Pass@1** は `n=40, t=0.2, top_p=0.5` を使用。
136 |
137 | 上記の結果は [DS1000 repo](https://github.com/HKUNLP/DS-1000.git) のコードで再現できます。
138 |
139 | ## 推論
140 |
141 | CodeGeeX2 は、前世代よりも導入が容易になりました。マルチクエリーアテンションとフラッシュアテンションの使用により、推論速度が速くなり、INT4 量子化後に必要な GPU メモリは 6GB のみです。
142 |
143 | ### 量子化
144 |
145 | | **Model** | FP16/BF16 | INT8 | INT4 |
146 | | :--------------: | :-------: | :-----: | :----: |
147 | | CodeGeeX-13B | 26\.9 GB | 14\.7 GB | - |
148 | | **CodeGeeX2-6B** | 13\.1 GB | 8\.2 GB | 5\.5 GB |
149 | > PyTorch 2.0に基づき、`torch.nn.functional.scaled_dot_product_attention` を使用して、効率的なアテンションメカニズムを実現。
150 |
151 | ### 加速
152 |
153 | | **Model** | **推論速度 (token/秒)** |
154 | | :--------------: | :-------------: |
155 | | CodeGeeX-13B | 32 |
156 | | **CodeGeeX2-6B** | 94 |
157 | > `batch_size=1, max_length=2048`, どちらもアクセラレーションフレームワークを使用、`GeForce RTX-3090` の場合。
158 |
159 | ## License
160 |
161 | The code in this repository is open source under the [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) license. The model weights are licensed under the [Model License](MODEL_LICENSE). The CodeGeeX2-6B weights are open for academic research; for commercial use, please fill out the [registration form](https://open.bigmodel.cn/mla/form?mcode=CodeGeeX2-6B) to apply.
162 |
163 |
164 | ## Citation
165 |
166 | If you find our work helpful, please feel free to cite the following paper:
167 |
168 | ```
169 | @inproceedings{zheng2023codegeex,
170 | title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
171 | author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
172 | booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
173 | pages={5673--5684},
174 | year={2023}
175 | }
176 | ```
177 |
--------------------------------------------------------------------------------
/benchmark/humanevalx/go/go.mod:
--------------------------------------------------------------------------------
1 | module humanEval
2 |
3 | go 1.18
4 |
5 | require (
6 | github.com/go-openapi/inflect v0.19.0
7 | github.com/stretchr/testify v1.8.0
8 | )
9 |
10 | require (
11 | github.com/davecgh/go-spew v1.1.1 // indirect
12 | github.com/pmezard/go-difflib v1.0.0 // indirect
13 | gopkg.in/yaml.v3 v3.0.1 // indirect
14 | )
15 |
--------------------------------------------------------------------------------
/benchmark/humanevalx/go/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4 | github.com/go-openapi/inflect v0.19.0 h1:9jCH9scKIbHeV9m12SmPilScz6krDxKRasNNSNPXu/4=
5 | github.com/go-openapi/inflect v0.19.0/go.mod h1:lHpZVlpIQqLyKwJ4N+YSc9hchQy/i12fJykb83CRBH4=
6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
8 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
9 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
10 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
11 | github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
12 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
13 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
14 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
15 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
16 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
17 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
18 |
--------------------------------------------------------------------------------
/benchmark/humanevalx/go/vendor.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/benchmark/humanevalx/go/vendor.tar.gz
--------------------------------------------------------------------------------
/benchmark/humanevalx/humanevalx_cpp.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/benchmark/humanevalx/humanevalx_cpp.jsonl.gz
--------------------------------------------------------------------------------
/benchmark/humanevalx/humanevalx_go.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/benchmark/humanevalx/humanevalx_go.jsonl.gz
--------------------------------------------------------------------------------
/benchmark/humanevalx/humanevalx_java.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/benchmark/humanevalx/humanevalx_java.jsonl.gz
--------------------------------------------------------------------------------
/benchmark/humanevalx/humanevalx_js.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/benchmark/humanevalx/humanevalx_js.jsonl.gz
--------------------------------------------------------------------------------
/benchmark/humanevalx/humanevalx_python.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/benchmark/humanevalx/humanevalx_python.jsonl.gz
--------------------------------------------------------------------------------
/benchmark/humanevalx/humanevalx_rust.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/benchmark/humanevalx/humanevalx_rust.jsonl.gz
--------------------------------------------------------------------------------
/benchmark/humanevalx/rust/Cargo.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | version = 3
4 |
5 | [[package]]
6 | name = "aho-corasick"
7 | version = "0.7.20"
8 | source = "registry+https://github.com/rust-lang/crates.io-index"
9 | checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac"
10 | dependencies = [
11 | "memchr",
12 | ]
13 |
14 | [[package]]
15 | name = "fuchsia-cprng"
16 | version = "0.1.1"
17 | source = "registry+https://github.com/rust-lang/crates.io-index"
18 | checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba"
19 |
20 | [[package]]
21 | name = "libc"
22 | version = "0.2.139"
23 | source = "registry+https://github.com/rust-lang/crates.io-index"
24 | checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79"
25 |
26 | [[package]]
27 | name = "md5"
28 | version = "0.7.0"
29 | source = "registry+https://github.com/rust-lang/crates.io-index"
30 | checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
31 |
32 | [[package]]
33 | name = "memchr"
34 | version = "2.5.0"
35 | source = "registry+https://github.com/rust-lang/crates.io-index"
36 | checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
37 |
38 | [[package]]
39 | name = "rand"
40 | version = "0.4.6"
41 | source = "registry+https://github.com/rust-lang/crates.io-index"
42 | checksum = "552840b97013b1a26992c11eac34bdd778e464601a4c2054b5f0bff7c6761293"
43 | dependencies = [
44 | "fuchsia-cprng",
45 | "libc",
46 | "rand_core 0.3.1",
47 | "rdrand",
48 | "winapi",
49 | ]
50 |
51 | [[package]]
52 | name = "rand_core"
53 | version = "0.3.1"
54 | source = "registry+https://github.com/rust-lang/crates.io-index"
55 | checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b"
56 | dependencies = [
57 | "rand_core 0.4.2",
58 | ]
59 |
60 | [[package]]
61 | name = "rand_core"
62 | version = "0.4.2"
63 | source = "registry+https://github.com/rust-lang/crates.io-index"
64 | checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc"
65 |
66 | [[package]]
67 | name = "rdrand"
68 | version = "0.4.0"
69 | source = "registry+https://github.com/rust-lang/crates.io-index"
70 | checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2"
71 | dependencies = [
72 | "rand_core 0.3.1",
73 | ]
74 |
75 | [[package]]
76 | name = "regex"
77 | version = "1.7.1"
78 | source = "registry+https://github.com/rust-lang/crates.io-index"
79 | checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733"
80 | dependencies = [
81 | "aho-corasick",
82 | "memchr",
83 | "regex-syntax",
84 | ]
85 |
86 | [[package]]
87 | name = "regex-syntax"
88 | version = "0.6.28"
89 | source = "registry+https://github.com/rust-lang/crates.io-index"
90 | checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
91 |
92 | [[package]]
93 | name = "rust"
94 | version = "0.1.0"
95 | dependencies = [
96 | "md5",
97 | "rand",
98 | "regex",
99 | ]
100 |
101 | [[package]]
102 | name = "winapi"
103 | version = "0.3.9"
104 | source = "registry+https://github.com/rust-lang/crates.io-index"
105 | checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
106 | dependencies = [
107 | "winapi-i686-pc-windows-gnu",
108 | "winapi-x86_64-pc-windows-gnu",
109 | ]
110 |
111 | [[package]]
112 | name = "winapi-i686-pc-windows-gnu"
113 | version = "0.4.0"
114 | source = "registry+https://github.com/rust-lang/crates.io-index"
115 | checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
116 |
117 | [[package]]
118 | name = "winapi-x86_64-pc-windows-gnu"
119 | version = "0.4.0"
120 | source = "registry+https://github.com/rust-lang/crates.io-index"
121 | checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
122 |
--------------------------------------------------------------------------------
/benchmark/humanevalx/rust/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "rust"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7 |
8 | [dependencies]
9 | rand = "0.4"
10 | regex = "1"
11 | md5 = "0.7.0"
12 |
13 |
--------------------------------------------------------------------------------
/demo/example_inputs.jsonl:
--------------------------------------------------------------------------------
1 | {"code": "# Write a quick sort function\n", "langauge": "Python"}
2 | {"code": "// 写一个冒泡排序函数\n", "langauge": "C++"}
3 | {"code": "// 写一个二叉树的类\npublic class", "langauge": "Java"}
4 | {"code": "// 矩阵求行列式\n", "langauge": "Matlab"}
5 | {"code": "\n", "langauge": "HTML"}
6 | {"code": "// 写一个服务器框架, 接收浏览器发过来的请求,并返回处理后的内容\n", "langauge": "JavaScript"}
7 | {"code": "// Write a binary search function\n", "langauge": "Rust"}
8 | {"code": "-- 查询品类最多的三种食品,但是不包含出现时间在2014年及之后\n", "langauge": "SQL"}
9 | {"code": "// Write a simple file system that allows parallel read/write in Golang\n", "langauge": "Go"}
--------------------------------------------------------------------------------
/demo/fastapicpu.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Request
2 | from transformers import AutoTokenizer, AutoModel
3 | import uvicorn, json, datetime
4 | import torch
5 | import argparse
6 |
7 | try:
8 | import chatglm_cpp
9 | enable_chatglm_cpp = True
10 | except ImportError:
11 | print("[WARN] chatglm-cpp not found. Install it by `pip install chatglm-cpp` for better performance. "
12 | "Check out https://github.com/li-plus/chatglm.cpp for more details.")
13 | enable_chatglm_cpp = False
14 |
15 |
16 | # parse command-line options
17 | def add_code_generation_args(parser):
18 | group = parser.add_argument_group(title="CodeGeeX2 DEMO")
19 | group.add_argument(
20 | "--model-path",
21 | type=str,
22 | default="THUDM/codegeex2-6b",
23 | )
24 | group.add_argument(
25 | "--listen",
26 | type=str,
27 | default="127.0.0.1",
28 | )
29 | group.add_argument(
30 | "--port",
31 | type=int,
32 | default=7860,
33 | )
34 | group.add_argument(
35 | "--workers",
36 | type=int,
37 | default=1,
38 | )
39 | group.add_argument(
40 | "--cpu",
41 | action="store_true",
42 | )
43 | group.add_argument(
44 | "--half",
45 | action="store_true",
46 | )
47 | group.add_argument(
48 | "--quantize",
49 | type=int,
50 | default=None,
51 | )
52 | group.add_argument(
53 | "--chatglm-cpp",
54 | action="store_true",
55 | )
56 | return parser
57 |
58 | LANGUAGE_TAG = {
59 | "Abap" : "* language: Abap",
60 | "ActionScript" : "// language: ActionScript",
61 | "Ada" : "-- language: Ada",
62 | "Agda" : "-- language: Agda",
63 | "ANTLR" : "// language: ANTLR",
64 | "AppleScript" : "-- language: AppleScript",
65 | "Assembly" : "; language: Assembly",
66 | "Augeas" : "// language: Augeas",
67 | "AWK" : "// language: AWK",
68 | "Basic" : "' language: Basic",
69 | "C" : "// language: C",
70 | "C#" : "// language: C#",
71 | "C++" : "// language: C++",
72 | "CMake" : "# language: CMake",
73 | "Cobol" : "// language: Cobol",
74 | "CSS" : "/* language: CSS */",
75 | "CUDA" : "// language: Cuda",
76 | "Dart" : "// language: Dart",
77 | "Delphi" : "{language: Delphi}",
78 | "Dockerfile" : "# language: Dockerfile",
79 | "Elixir" : "# language: Elixir",
80 | "Erlang" : f"% language: Erlang",
81 | "Excel" : "' language: Excel",
82 | "F#" : "// language: F#",
83 | "Fortran" : "!language: Fortran",
84 | "GDScript" : "# language: GDScript",
85 | "GLSL" : "// language: GLSL",
86 | "Go" : "// language: Go",
87 | "Groovy" : "// language: Groovy",
88 | "Haskell" : "-- language: Haskell",
89 | "HTML" : "",
90 | "Isabelle" : "(*language: Isabelle*)",
91 | "Java" : "// language: Java",
92 | "JavaScript" : "// language: JavaScript",
93 | "Julia" : "# language: Julia",
94 | "Kotlin" : "// language: Kotlin",
95 | "Lean" : "-- language: Lean",
96 | "Lisp" : "; language: Lisp",
97 | "Lua" : "// language: Lua",
98 | "Markdown" : "",
99 | "Matlab" : f"% language: Matlab",
100 | "Objective-C" : "// language: Objective-C",
101 | "Objective-C++": "// language: Objective-C++",
102 | "Pascal" : "// language: Pascal",
103 | "Perl" : "# language: Perl",
104 | "PHP" : "// language: PHP",
105 | "PowerShell" : "# language: PowerShell",
106 | "Prolog" : f"% language: Prolog",
107 | "Python" : "# language: Python",
108 | "R" : "# language: R",
109 | "Racket" : "; language: Racket",
110 | "RMarkdown" : "# language: RMarkdown",
111 | "Ruby" : "# language: Ruby",
112 | "Rust" : "// language: Rust",
113 | "Scala" : "// language: Scala",
114 | "Scheme" : "; language: Scheme",
115 | "Shell" : "# language: Shell",
116 | "Solidity" : "// language: Solidity",
117 | "SPARQL" : "# language: SPARQL",
118 | "SQL" : "-- language: SQL",
119 | "Swift" : "// language: swift",
120 | "TeX" : f"% language: TeX",
121 | "Thrift" : "/* language: Thrift */",
122 | "TypeScript" : "// language: TypeScript",
123 | "Vue" : "",
124 | "Verilog" : "// language: Verilog",
125 | "Visual Basic" : "' language: Visual Basic",
126 | }
127 |
128 | app = FastAPI()
129 | def device():
130 | if enable_chatglm_cpp and args.chatglm_cpp:
131 | print("Using chatglm-cpp to improve performance")
132 | dtype = "f16" if args.half else "f32"
133 | if args.quantize in [4, 5, 8]:
134 | dtype = f"q{args.quantize}_0"
135 | model = chatglm_cpp.Pipeline(args.model_path, dtype=dtype)
136 | return model
137 |
138 | print("chatglm-cpp not enabled, falling back to transformers")
139 | if not args.cpu:
140 | if not args.half:
141 | model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True).cuda()
142 | else:
143 | model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True).cuda().half()
144 | if args.quantize in [4, 8]:
145 | print(f"Model is quantized to INT{args.quantize} format.")
146 | model = model.half().quantize(args.quantize)
147 | else:
148 | model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
149 |
150 | return model.eval()
151 |
152 | @app.post("/")
153 | async def create_item(request: Request):
154 | global model, tokenizer
155 | json_post_raw = await request.json()
156 | json_post = json.dumps(json_post_raw)
157 | json_post_list = json.loads(json_post)
158 | lang = json_post_list.get('lang')
159 | prompt = json_post_list.get('prompt')
160 | max_length = json_post_list.get('max_length', 128)
161 | top_p = json_post_list.get('top_p', 0.95)
162 | temperature = json_post_list.get('temperature', 0.2)
163 | top_k = json_post_list.get('top_k', 0)
164 | if lang != "None":
165 | prompt = LANGUAGE_TAG[lang] + "\n" + prompt
166 | if enable_chatglm_cpp and args.chatglm_cpp:
167 | response = model.generate(prompt,
168 | max_length=max_length,
169 | do_sample=temperature > 0,
170 | top_p=top_p,
171 | top_k=top_k,
172 | temperature=temperature)
173 | else:
174 | response = model.chat(tokenizer,
175 | prompt,
176 | max_length=max_length,
177 | top_p=top_p,
178 | top_k=top_k,
179 | temperature=temperature)
180 | now = datetime.datetime.now()
181 | time = now.strftime("%Y-%m-%d %H:%M:%S")
182 | answer = {
183 | "response": response,
184 | "lang": lang,
185 | "status": 200,
186 | "time": time
187 | }
188 | log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
189 | print(log)
190 |
191 | return answer
192 |
193 |
194 | if __name__ == '__main__':
195 | parser = argparse.ArgumentParser()
196 | parser = add_code_generation_args(parser)
197 | args, _ = parser.parse_known_args()
198 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
199 | model = device()
200 | uvicorn.run(app, host=args.listen, port=args.port, workers=args.workers)
201 |
--------------------------------------------------------------------------------
/demo/gpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Tuple, Union, Optional
3 |
4 | from torch.nn import Module
5 | from transformers import AutoModel
6 |
7 |
8 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
9 |     # transformer.word_embeddings takes 1 layer
10 |     # transformer.final_layernorm and lm_head take 1 layer
11 |     # transformer.layers take 28 layers
12 |     # 30 layers in total, distributed across num_gpus GPUs
13 | num_trans_layers = 28
14 | per_gpu_layers = 30 / num_gpus
15 |
16 |     # bugfix: on Linux, the weight and input passed to torch.embedding may end up on different devices, causing a RuntimeError
17 |     # on Windows, model.device is set to transformer.word_embeddings.device
18 |     # on Linux, model.device is set to lm_head.device
19 |     # when calling chat or stream_chat, input_ids are placed on model.device
20 |     # if transformer.word_embeddings.device and model.device differ, a RuntimeError is raised
21 |     # therefore transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first GPU
22 |     # this file comes from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
23 |     # with minor modifications to support ChatGLM2 and CodeGeeX2
24 | device_map = {
25 | 'transformer.embedding.word_embeddings': 0,
26 | 'transformer.encoder.final_layernorm': 0,
27 | 'transformer.output_layer': 0,
28 | 'transformer.rotary_pos_emb': 0,
29 | 'lm_head': 0
30 | }
31 |
32 | used = 2
33 | gpu_target = 0
34 | for i in range(num_trans_layers):
35 | if used >= per_gpu_layers:
36 | gpu_target += 1
37 | used = 0
38 | assert gpu_target < num_gpus
39 | device_map[f'transformer.encoder.layers.{i}'] = gpu_target
40 | used += 1
41 |
42 | return device_map
43 |
44 |
45 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
46 | device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
47 | if num_gpus < 2 and device_map is None:
48 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
49 | else:
50 | from accelerate import dispatch_model
51 |
52 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
53 |
54 | if device_map is None:
55 | device_map = auto_configure_device_map(num_gpus)
56 |
57 | model = dispatch_model(model, device_map=device_map)
58 |
59 | return model
60 |
--------------------------------------------------------------------------------
/demo/run_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy
4 | import torch
5 | import random
6 | import argparse
7 | import gradio as gr
8 |
9 | from transformers import AutoTokenizer, AutoModel
10 |
11 | try:
12 | # Should first install fastllm (https://github.com/ztxz16/fastllm.git)
13 | from fastllm_pytools import llm
14 | enable_fastllm = True
15 | except ImportError:
16 | print("fastllm disabled.")
17 | enable_fastllm = False
18 |
19 | try:
20 | from gpus import load_model_on_gpus
21 | enable_multiple_gpus = True
22 | except ImportError:
23 | print("Multiple GPUs support disabled.")
24 | enable_multiple_gpus = False
25 |
26 | try:
27 | import chatglm_cpp
28 | enable_chatglm_cpp = True
29 | except ImportError:
30 | print("[WARN] chatglm-cpp not found. Install it by `pip install chatglm-cpp` for better performance. "
31 | "Check out https://github.com/li-plus/chatglm.cpp for more details.")
32 | enable_chatglm_cpp = False
33 |
34 |
35 | def get_model(args):
36 | if not args.cpu:
37 | if torch.cuda.is_available():
38 | device = f"cuda:{args.gpu}"
39 | elif torch.backends.mps.is_built():
40 | device = "mps"
41 | else:
42 | device = "cpu"
43 | else:
44 | device = "cpu"
45 |
46 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
47 |
48 | if args.n_gpus > 1 and enable_multiple_gpus:
49 |         # To load the model across multiple GPUs, set "n_gpus" to the desired number of cards.
50 |         print(f"Running on {args.n_gpus} GPUs.")
51 | model = load_model_on_gpus(args.model_path, num_gpus=args.n_gpus)
52 | model = model.eval()
53 | elif enable_chatglm_cpp and args.chatglm_cpp:
54 | print("Using chatglm-cpp to improve performance")
55 | dtype = "f16"
56 | if args.quantize in [4, 5, 8]:
57 | dtype = f"q{args.quantize}_0"
58 | model = chatglm_cpp.Pipeline(args.model_path, dtype=dtype)
59 | else:
60 | model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
61 | model = model.eval()
62 |
63 | if enable_fastllm and args.fastllm:
64 | print("fastllm enabled.")
65 | model = model.half()
66 | llm.set_device_map(device)
67 | if args.quantize in [4, 8]:
68 | model = llm.from_hf(model, dtype=f"int{args.quantize}")
69 | else:
70 | model = llm.from_hf(model, dtype="float16")
71 | else:
72 | print("chatglm-cpp and fastllm not installed, using transformers.")
73 | if args.quantize in [4, 8]:
74 | print(f"Model is quantized to INT{args.quantize} format.")
75 | model = model.half().quantize(args.quantize)
76 | model = model.to(device)
77 |
78 | return tokenizer, model
79 |
80 |
81 | def add_code_generation_args(parser):
82 | group = parser.add_argument_group(title="CodeGeeX2 DEMO")
83 | group.add_argument(
84 | "--model-path",
85 | type=str,
86 | default="THUDM/codegeex2-6b",
87 | )
88 | group.add_argument(
89 | "--example-path",
90 | type=str,
91 | default=None,
92 | )
93 | group.add_argument(
94 | "--quantize",
95 | type=int,
96 | default=None,
97 | )
98 | group.add_argument(
99 | "--chatglm-cpp",
100 | action="store_true",
101 | )
102 | group.add_argument(
103 | "--fastllm",
104 | action="store_true",
105 | )
106 | group.add_argument(
107 | "--n-gpus",
108 | type=int,
109 | default=1,
110 | )
111 | group.add_argument(
112 | "--gpu",
113 | type=int,
114 | default=0,
115 | )
116 | group.add_argument(
117 | "--cpu",
118 | action="store_true",
119 | )
120 | group.add_argument(
121 | "--listen",
122 | type=str,
123 | default="127.0.0.1",
124 | )
125 | group.add_argument(
126 | "--port",
127 | type=int,
128 | default=7860,
129 | )
130 | group.add_argument(
131 | "--username",
132 | type=str,
133 | default=None,
134 | )
135 | group.add_argument(
136 | "--password",
137 | type=str,
138 | default=None,
139 | )
140 | group.add_argument(
141 | "--auth",
142 | action="store_true",
143 | )
144 |
145 |
146 | return parser
147 |
148 |
149 | # Full list of supported programming languages in evaluation/utils.py
150 | LANGUAGE_TAG = {
151 | "Abap" : "* language: Abap",
152 | "ActionScript" : "// language: ActionScript",
153 | "Ada" : "-- language: Ada",
154 | "Agda" : "-- language: Agda",
155 | "ANTLR" : "// language: ANTLR",
156 | "AppleScript" : "-- language: AppleScript",
157 | "Assembly" : "; language: Assembly",
158 | "Augeas" : "// language: Augeas",
159 | "AWK" : "// language: AWK",
160 | "Basic" : "' language: Basic",
161 | "C" : "// language: C",
162 | "C#" : "// language: C#",
163 | "C++" : "// language: C++",
164 | "CMake" : "# language: CMake",
165 | "Cobol" : "// language: Cobol",
166 | "CSS" : "/* language: CSS */",
167 | "CUDA" : "// language: Cuda",
168 | "Dart" : "// language: Dart",
169 | "Delphi" : "{language: Delphi}",
170 | "Dockerfile" : "# language: Dockerfile",
171 | "Elixir" : "# language: Elixir",
172 | "Erlang" : f"% language: Erlang",
173 | "Excel" : "' language: Excel",
174 | "F#" : "// language: F#",
175 | "Fortran" : "!language: Fortran",
176 | "GDScript" : "# language: GDScript",
177 | "GLSL" : "// language: GLSL",
178 | "Go" : "// language: Go",
179 | "Groovy" : "// language: Groovy",
180 | "Haskell" : "-- language: Haskell",
181 | "HTML" : "",
182 | "Isabelle" : "(*language: Isabelle*)",
183 | "Java" : "// language: Java",
184 | "JavaScript" : "// language: JavaScript",
185 | "Julia" : "# language: Julia",
186 | "Kotlin" : "// language: Kotlin",
187 | "Lean" : "-- language: Lean",
188 | "Lisp" : "; language: Lisp",
189 | "Lua" : "// language: Lua",
190 | "Markdown" : "",
191 | "Matlab" : f"% language: Matlab",
192 | "Objective-C" : "// language: Objective-C",
193 | "Objective-C++": "// language: Objective-C++",
194 | "Pascal" : "// language: Pascal",
195 | "Perl" : "# language: Perl",
196 | "PHP" : "// language: PHP",
197 | "PowerShell" : "# language: PowerShell",
198 | "Prolog" : f"% language: Prolog",
199 | "Python" : "# language: Python",
200 | "R" : "# language: R",
201 | "Racket" : "; language: Racket",
202 | "RMarkdown" : "# language: RMarkdown",
203 | "Ruby" : "# language: Ruby",
204 | "Rust" : "// language: Rust",
205 | "Scala" : "// language: Scala",
206 | "Scheme" : "; language: Scheme",
207 | "Shell" : "# language: Shell",
208 | "Solidity" : "// language: Solidity",
209 | "SPARQL" : "# language: SPARQL",
210 | "SQL" : "-- language: SQL",
211 | "Swift" : "// language: swift",
212 | "TeX" : f"% language: TeX",
213 | "Thrift" : "/* language: Thrift */",
214 | "TypeScript" : "// language: TypeScript",
215 | "Vue" : "",
216 | "Verilog" : "// language: Verilog",
217 | "Visual Basic" : "' language: Visual Basic",
218 | }
219 |
220 |
221 | def set_random_seed(seed):
222 | """Set random seed for reproducability."""
223 | random.seed(seed)
224 | numpy.random.seed(seed)
225 | torch.manual_seed(seed)
226 |
227 |
228 | def main():
229 | parser = argparse.ArgumentParser()
230 | parser = add_code_generation_args(parser)
231 | args, _ = parser.parse_known_args()
232 |
233 | tokenizer, model = get_model(args)
234 |
235 | examples = []
236 | if args.example_path is None:
237 | example_path = os.path.join(os.path.split(os.path.realpath(__file__))[0], "example_inputs.jsonl")
238 | else:
239 | example_path = args.example_path
240 |
241 | # Load examples for gradio DEMO
242 | with open(example_path, "r", encoding="utf-8") as f:
243 | for line in f:
244 | examples.append(list(json.loads(line).values()))
245 |
246 |
247 | def predict(
248 | prompt,
249 | lang,
250 | seed,
251 | out_seq_length,
252 | temperature,
253 | top_k,
254 | top_p,
255 | ):
256 | set_random_seed(seed)
257 | if lang != "None":
258 | prompt = LANGUAGE_TAG[lang] + "\n" + prompt
259 |
260 | if enable_fastllm and args.fastllm:
261 | model.direct_query = True
262 | outputs = model.chat(tokenizer,
263 | prompt,
264 | max_length=out_seq_length,
265 | top_p=top_p,
266 | top_k=top_k,
267 | temperature=temperature)
268 | response = prompt + outputs[0]
269 | elif enable_chatglm_cpp and args.chatglm_cpp:
270 | inputs = tokenizer([prompt], return_tensors="pt")
271 | pipeline = model
272 | outputs = pipeline.generate(prompt,
273 | max_length=inputs['input_ids'].shape[-1] + out_seq_length,
274 | do_sample=temperature > 0,
275 | top_p=top_p,
276 | top_k=top_k,
277 | temperature=temperature)
278 | response = prompt + outputs
279 | else:
280 | inputs = tokenizer([prompt], return_tensors="pt")
281 | inputs = inputs.to(model.device)
282 | outputs = model.generate(**inputs,
283 | max_length=inputs['input_ids'].shape[-1] + out_seq_length,
284 | do_sample=True,
285 | top_p=top_p,
286 | top_k=top_k,
287 | temperature=temperature,
288 | pad_token_id=2,
289 | eos_token_id=2)
290 | response = tokenizer.decode(outputs[0])
291 |
292 | return response
293 |
294 | with gr.Blocks(title="CodeGeeX2 DEMO") as demo:
295 | gr.Markdown(
296 | """
297 |
298 |
299 |
300 | """)
301 | gr.Markdown(
302 | """
303 |
304 | 🏠 Homepage|💻 GitHub|🛠 Tools VS Code, Jetbrains|🤗 Download|📄 Paper
305 |
306 | """)
307 | gr.Markdown(
308 | """
309 | 这是 CodeGeeX2 的简易DEMO。请注意:
310 | * CodeGeeX2 是一个基座模型,它可以完成代码补全/翻译/解释等任务,没有针对聊天进行指令微调。可以在 CodeGeeX 插件[VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex)、[Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex)中体验指令微调后的版本。
311 | * 可以通过添加`language tag`来控制编程语言,例如`# language: Python`,查看[完整支持语言列表](https://github.com/THUDM/CodeGeeX2/blob/main/evaluation/utils.py#L14)。
312 | * 按照所选编程语言的格式写注释可以获得更好的结果,请参照下方给出的示例。
313 |
314 | This is the DEMO for CodeGeeX2. Please note that:
315 | * CodeGeeX2 is a base model, which is not instruction-tuned for chatting. It can do tasks like code completion/translation/explanation. To try the instruction-tuned version, use the CodeGeeX plugins ([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex)).
316 | * Programming languages can be controlled by adding a `language tag`, e.g., `# language: Python`. The format should be respected to ensure performance; the full list can be found [here](https://github.com/THUDM/CodeGeeX2/blob/main/evaluation/utils.py#L14).
317 | * Write comments under the format of the selected programming language to achieve better results, see examples below.
318 | """)
319 |
320 | with gr.Row():
321 | with gr.Column():
322 | prompt = gr.Textbox(lines=14, placeholder='Please enter the description or select an example input below.',label='Input')
323 | with gr.Row():
324 | gen = gr.Button("Generate")
325 | clr = gr.Button("Clear")
326 |
327 | outputs = gr.Textbox(lines=15, label='Output')
328 |
329 | gr.Markdown(
330 | """
331 | Generation Parameter
332 | """)
333 |
334 | with gr.Row():
335 | with gr.Row():
336 | seed = gr.Slider(maximum=10000, value=8888, step=1, label='Seed')
337 | with gr.Row():
338 | out_seq_length = gr.Slider(maximum=8192, value=128, minimum=1, step=1, label='Output Sequence Length')
339 | temperature = gr.Slider(maximum=1, value=0.2, minimum=0, label='Temperature')
340 | with gr.Row():
341 | top_k = gr.Slider(maximum=100, value=0, minimum=0, step=1, label='Top K')
342 | top_p = gr.Slider(maximum=1, value=0.95, minimum=0, label='Top P')
343 | with gr.Row():
344 | lang = gr.Radio(
345 | choices=["None"] + list(LANGUAGE_TAG.keys()), value='None', label='Programming Language')
346 | inputs = [prompt, lang, seed, out_seq_length, temperature, top_k, top_p]
347 | gen.click(fn=predict, inputs=inputs, outputs=outputs)
348 | clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=prompt)
349 |
350 | gr_examples = gr.Examples(examples=examples, inputs=[prompt, lang],
351 |                                   label="Example Inputs (Click to insert an example into the input box)",
352 | examples_per_page=20)
353 | if not args.auth:
354 | demo.launch(server_name=args.listen, server_port=args.port)
355 | else:
356 | demo.launch(server_name=args.listen, server_port=args.port, auth=(args.username, args.password))
357 |
358 |     # To listen on 0.0.0.0 or another port, change this to demo.launch(server_name="0.0.0.0", server_port=6666)
359 |     # To require a password: demo.launch(server_name="0.0.0.0", server_port=6666, auth=("admin", "password"))
360 |
361 | if __name__ == '__main__':
362 | with torch.no_grad():
363 | main()
364 |
365 |
--------------------------------------------------------------------------------
/docs/zh/inference_zh.md:
--------------------------------------------------------------------------------
1 | # CodeGeeX2 Inference Tutorial
2 |
3 | CodeGeeX2 is the second generation of the multilingual code generation model [CodeGeeX](https://github.com/THUDM/CodeGeeX) ([KDD’23](https://arxiv.org/abs/2303.17568)): stronger, faster, and more lightweight, making it an AI coding assistant well suited to local deployment. CodeGeeX2 supports inference on many different platforms; this tutorial covers several inference options, including CPU inference, multi-GPU inference, and accelerated inference.
4 |
5 | - [Quick Start](#quick-start)
6 | - [Multi-Precision and Quantized Inference](#multi-precision-and-quantized-inference)
7 | - [Multi-GPU Inference](#multi-gpu-inference)
8 | - [Inference on Mac](#inference-on-mac)
9 | - [Accelerated Inference with fastllm](#accelerated-inference-with-fastllm)
10 | - [Quantized Inference with ChatGLM.cpp](#quantized-inference-with-chatglmcpp)
11 |
12 | ## Quick Start
13 |
14 | Clone this repository and install the dependencies with `pip`:
15 |
16 | ```shell
17 | git clone https://github.com/THUDM/CodeGeeX2
18 | cd CodeGeeX2
19 | pip install -r requirements.txt
20 | ```
21 |
22 | Use `transformers` to quickly run [CodeGeeX2-6B](https://huggingface.co/THUDM/codegeex2-6b); the weights are downloaded automatically:
23 |
24 | ```python
25 | from transformers import AutoTokenizer, AutoModel
26 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
27 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda') # for CPU inference, use device='cpu'
28 | model = model.eval()
29 |
30 | # CodeGeeX2 supports 100 programming languages; add a language tag to steer generation toward that language
31 | prompt = "# language: Python\n# write a bubble sort function\n"
32 | inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
33 | outputs = model.generate(inputs, max_length=256, top_k=1) # this example uses greedy decoding; check whether the output matches
34 | response = tokenizer.decode(outputs[0])
35 |
36 | >>> print(response)
37 | # language: Python
38 | # write a bubble sort function
39 |
40 |
41 | def bubble_sort(list):
42 | for i in range(len(list) - 1):
43 | for j in range(len(list) - 1):
44 | if list[j] > list[j + 1]:
45 | list[j], list[j + 1] = list[j + 1], list[j]
46 | return list
47 |
48 |
49 | print(bubble_sort([5, 2, 1, 8, 4]))
50 | ```
51 |
52 | Alternatively, download the weights manually:
53 |
54 | ```shell
55 | # download from Hugging Face
56 | git clone https://huggingface.co/THUDM/codegeex2-6b
57 | ```
58 |
59 | Point the tokenizer and model paths to the local directory:
60 |
61 | ```python
62 | model_path = "/path/to/codegeex2-6b"
63 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
64 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
65 | ```
66 |
67 | ## Multi-Precision and Quantized Inference
68 |
69 | CodeGeeX2 is trained in BF16 and supports BF16/FP16/INT8/INT4 at inference time; choose a precision that fits your GPU memory:
70 |
71 | | **Model** | FP16/BF16 | INT8 | INT4 |
72 | | :--------------: | :-------: | :------: | :-----: |
73 | | CodeGeeX-13B | 26\.9 GB | 14\.7 GB | - |
74 | | **CodeGeeX2-6B** | 13\.1 GB | 8\.2 GB | 5\.5 GB |
75 |
76 | BF16 is used for inference by default. If your GPU does not support BF16 (❗️using the wrong format will produce garbled output), convert the model to FP16:
77 |
78 | ```python
79 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().to("cuda")
80 | ```
81 |
82 | Taking INT4 as an example of quantized inference, you can either download the pre-converted weights ([INT4 weights](https://huggingface.co/THUDM/codegeex2-6b-int4)) or convert them manually; if your GPU does not support BF16, convert to FP16 first as well:
83 |
84 | ```python
85 | # load the pre-converted weights
86 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b-int4", trust_remote_code=True)
87 |
88 | # convert the weights manually
89 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True).quantize(4).to("cuda")
90 |
91 | # if the GPU does not support BF16, convert to FP16 first
92 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True).half().quantize(4).to("cuda")
93 | ```
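
To check which precision actually fits your card, you can report peak GPU memory after a test generation (a small sketch using PyTorch's built-in counter):

```python
import torch

# run one generation first, then report peak GPU memory in GB
print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.1f} GB")
```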
94 |
95 | ## Multi-GPU Inference
96 |
97 | Use [gpus.py](https://github.com/THUDM/CodeGeeX2/blob/main/demo/gpus.py) to run inference across multiple GPUs:
98 |
99 | ```python
100 | from gpus import load_model_on_gpus
101 | model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=2)
102 | ```
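
The returned model is dispatched with `accelerate`, and generation then works the same way as in the quick start; a minimal sketch (`auto_configure_device_map` pins the word embeddings to GPU 0, so inputs go there):

```python
from transformers import AutoTokenizer
from gpus import load_model_on_gpus

tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=2).eval()

inputs = tokenizer.encode("# language: Python\n# write a bubble sort function\n", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(model.generate(inputs, max_length=256, top_k=1)[0]))
```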
103 |
104 | ## Inference on Mac
105 |
106 | On Macs with Apple Silicon or an AMD GPU, the MPS backend can be used. Follow Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly (the version number should look like 2.x.x.dev2023xxxx, e.g. 2.1.0.dev20230729):
107 |
108 | ```shell
109 | pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
110 | ```
111 |
112 | On macOS the model can only be loaded from a local path (download the weights [codegeex2-6b](https://huggingface.co/THUDM/codegeex2-6b) or [codegeex2-6b-int4](https://huggingface.co/THUDM/codegeex2-6b-int4) in advance). FP16/INT8/INT4 formats are supported, using the mps backend:
113 |
114 | ```python
115 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().to('mps')
116 | ```
117 |
118 | ## Accelerated Inference with fastllm
119 |
120 | [fastllm](https://github.com/ztxz16/fastllm) can be used to accelerate CodeGeeX2; it is currently the fastest open-source framework supporting the GLM architecture. First install fastllm_pytools:
121 |
122 | ```shell
123 | git clone https://github.com/ztxz16/fastllm
124 | cd fastllm
125 | mkdir build
126 | cd build
127 | # to build with GPU support, set the CUDA paths first: export CUDA_HOME=/usr/local/cuda, export PATH=$PATH:$CUDA_HOME/bin
128 | cmake .. -DUSE_CUDA=ON # to build without GPU support: cmake .. -DUSE_CUDA=OFF
129 | make -j
130 | cd tools && python setup.py install # to confirm the installation succeeded, `import fastllm_pytools` in Python should not raise an error
131 | ```
132 |
133 | If you hit an unsupported-architecture error, edit `CMakeLists.txt` and comment out the following line:
134 |
135 | ```shell
136 | # set(CMAKE_CUDA_ARCHITECTURES "native")
137 | ```
138 | On E5-series CPUs, the following compilation error may appear:
139 | ```
140 | error: inlining failed in call to ‘always_inline’ ‘__m256i _mm256_add_epi32(__m256i, __m256i)’: target specific option mismatch
141 | ```
142 | In that case, changing line 20 of `CMakeLists.txt` as follows allows the build to succeed:
143 | ```
144 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread --std=c++17 -O2")
145 | ```
146 |
147 | Convert the Hugging Face model to the fastllm format:
148 |
149 | ```python
150 | # original loading code
151 | from transformers import AutoTokenizer, AutoModel
152 | tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
153 | model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
154 |
155 | # add the following two lines to convert the Hugging Face model into a fastllm model
156 | from fastllm_pytools import llm
157 | model = llm.from_hf(model, tokenizer, dtype="float16") # dtype supports "float16", "int8", "int4"
158 | ```
159 |
160 | The fastllm model interface is not exactly the same as Hugging Face's; see the corresponding implementation in [demo/run_demo.py](https://github.com/THUDM/CodeGeeX2/blob/main/demo/run_demo.py):
161 |
162 | ```python
163 | model.direct_query = True
164 | outputs = model.chat(tokenizer,
165 | prompt,
166 | max_length=out_seq_length,
167 | top_p=top_p,
168 | top_k=top_k,
169 | temperature=temperature)
170 | response = outputs[0]
171 | ```
172 |
173 | ## Quantized Inference with ChatGLM.cpp
174 |
175 | [ChatGLM.cpp](https://github.com/li-plus/chatglm.cpp) is a cross-platform quantized acceleration solution similar to LLaMA.cpp. It supports the q4_0/q4_1/q5_0/q5_1/q8_0 quantization types and CPU/CUDA/Metal backends, and enables accelerated inference with a single line of code.
176 |
177 | First install chatglm-cpp. To use CUDA acceleration, set the environment variable `CMAKE_ARGS="-DGGML_CUBLAS=ON"`; for CPU-only use, simply drop that variable.
178 | ```sh
179 | CMAKE_ARGS="-DGGML_CUBLAS=ON" pip install chatglm-cpp -v
180 | ```
181 |
182 | A single line of code quantizes and accelerates a Hugging Face model; `dtype` can be set to `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, or `f16` to select the quantization type.
183 | ```python
184 | >>> import chatglm_cpp
185 | >>>
186 | >>> pipeline = chatglm_cpp.Pipeline("THUDM/codegeex2-6b", dtype="q4_0") # Load HF model and quantize it into int4
187 | Loading checkpoint shards: 100%|███████████████████████████████████████████████| 7/7 [00:09<00:00, 1.33s/it]
188 | Processing model states: 100%|█████████████████████████████████████████████| 199/199 [00:21<00:00, 9.21it/s]
189 | ...
190 | >>> print(pipeline.generate("# language: Python\n# write a bubble sort function\n", do_sample=False))
191 |
192 |
193 | def bubble_sort(list):
194 | for i in range(len(list) - 1):
195 | for j in range(len(list) - 1):
196 | if list[j] > list[j + 1]:
197 | list[j], list[j + 1] = list[j + 1], list[j]
198 | return list
199 |
200 |
201 | print(bubble_sort([5, 4, 3, 2, 1]))
202 | ```
203 |
204 | ChatGLM.cpp is already integrated into this repository; add the options `--quantize 4 --chatglm-cpp` to the demo to enable int4 (q4_0) quantized acceleration, for example:
205 | ```sh
206 | python ./demo/run_demo.py --quantize 4 --chatglm-cpp
207 | ```
208 |
209 | The FastAPI demo also supports ChatGLM.cpp acceleration; start the server with the same arguments:
210 | ```sh
211 | python ./demo/fastapicpu.py --quantize 4 --chatglm-cpp
212 | ```
213 |
214 | Test the API endpoint:
215 | ```sh
216 | curl -X POST "http://127.0.0.1:7860" \
217 | -H 'Content-Type: application/json' \
218 | -d '{"lang": "Python", "prompt": "# Write a bubble sort function", "max_length": 512}'
219 | ```
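
For reference, the same request can be made from Python (assuming the third-party `requests` package); the server replies with a JSON object containing `response`, `lang`, `status`, and `time` fields (see `demo/fastapicpu.py`):

```python
import requests

payload = {"lang": "Python", "prompt": "# Write a bubble sort function", "max_length": 512}
resp = requests.post("http://127.0.0.1:7860", json=payload).json()
print(resp["response"])  # generated code
print(resp["status"], resp["time"])
```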
220 |
--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/evaluation/__init__.py
--------------------------------------------------------------------------------
/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | import fire
5 | import json
6 | import gzip
7 | import glob
8 | import numpy as np
9 |
10 | from typing import *
11 | from tqdm.auto import tqdm
12 | from collections import Counter, defaultdict
13 | from concurrent.futures import ThreadPoolExecutor, as_completed
14 |
15 | from execution import check_correctness
16 | from utils import Logger, IMPORT_HELPER, read_dataset, stream_jsonl_all, estimate_pass_at_k
17 |
18 |
19 | LANGUAGE_NAME = {
20 | "CPP" : "cpp",
21 | "Go" : "go",
22 | "Java" : "java",
23 | "JavaScript" : "js",
24 | "Python" : "python",
25 | "Rust" : "rust",
26 | }
27 |
28 |
29 | def postprocess_generation(sample, generation_mode="completion"):
30 | code = sample["generation"]
31 | if generation_mode == "instruction":
32 | if "```" in code:
33 | pattern = r'```(.*?)\n(.*?)```'
34 | matches = re.findall(pattern, code, re.DOTALL)
35 | for match in matches:
36 | code = match[1]
37 | break
38 | sample["generation"] = code
39 |
40 | return sample
41 |
42 |
43 | def process_test(sample, problems, dataset_type, language_type, generation_mode):
44 | if dataset_type == "humanevalx":
45 | task_id = sample["task_id"]
46 | prompt = problems[task_id]["prompt"]
47 | test = problems[task_id]["test"]
48 | code = sample["generation"]
49 |
50 | # Pre-process for different languages
51 | if language_type == "python":
52 | test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
53 | test_string = test_setup + prompt + code + "\n" + test + "\n"
54 | elif language_type == "cpp":
55 | test_set_up = ""
56 | for s in IMPORT_HELPER["cpp"]:
57 | if s not in prompt:
58 | test_set_up += s + "\n"
59 | test_string = test_set_up + "\n" + prompt + code + "\n" + test
60 | elif language_type == "java":
61 | test_string = prompt + code + "\n" + test
62 | elif language_type == "js" or language_type == "javascript":
63 | test_string = prompt + code + "\n" + test
64 | elif language_type == "go":
65 | import_string = problems[task_id]["import"]
66 | prompt = prompt.replace(import_string, "")
67 | test = problems[task_id]["test"]
68 | test_setup = problems[task_id]["test_setup"]
69 | other_pkgs = []
70 | for pkg in IMPORT_HELPER["go"]:
71 | if pkg not in test_setup:
72 | p = pkg.split("/")[-1]
73 | if p + "." in code:
74 | other_pkgs.append(f"\"{pkg}\"")
75 | if other_pkgs:
76 | import_other_pkgs = "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
77 | test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
78 | else:
79 | test_string = test_setup + "\n" + prompt + code + "\n" + test
80 | elif language_type == "rust":
81 | main = "\nfn main(){ \n } \n"
82 | test_string = main + prompt + code + test
83 | elif dataset_type == "mbpp":
84 | task_id = sample["task_id"]
85 | prompt = sample["prompt"]
86 | test = "\n".join(problems[task_id]["test_list"]) + "\n" + "\n".join(problems[task_id]["challenge_test_list"])
87 | code = sample["generation"]
88 | test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
89 | test_string = test_setup + "\n" + prompt + code + "\n" + problems[task_id]["test_setup_code"] + "\n" + test + "\n"
90 |
91 | return test_string
92 |
93 |
94 | def evaluate_functional_correctness(
95 | input_path: str = None,
96 | output_path: str = None,
97 | log_path: str = None,
98 | tmp_dir: str = "./",
99 | n_workers: int = 32,
100 | timeout: float = 5.0,
101 | k: List[int] = [1, 10, 100],
102 | model_name: str = None,
103 | problem_file: str = None,
104 | language_type: str = None,
105 | dataset_type: str = "humanevalx",
106 | generation_mode: str = "completion",
107 | test_groundtruth: bool = False,
108 | ):
109 | if log_path is None:
110 | log_path = os.path.join(output_path, "evaluation.log")
111 | logger = Logger(__name__, log_file=log_path)
112 |
113 | if os.path.isdir(input_path):
114 | input_list = glob.glob(input_path + '/*generation*.jsonl')
115 | sample_jsonl = []
116 | for input_file in input_list:
117 | sample_jsonl += stream_jsonl_all(input_file)
118 | else:
119 | input_file = input_path
120 | sample_jsonl = stream_jsonl_all(input_file)
121 |
122 | problems = read_dataset(problem_file, dataset_type=dataset_type)
123 |
124 | if output_path is not None:
125 | os.makedirs(output_path, exist_ok=True)
126 |
127 | with ThreadPoolExecutor(max_workers=n_workers) as executor:
128 |
129 | futures = []
130 | completion_id = Counter()
131 | n_samples = 0
132 | results = defaultdict(list)
133 |
134 | if test_groundtruth:
135 | logger.info("Testing ground truth...")
136 | else:
137 | logger.info("Testing generation...")
138 | for sample in sample_jsonl:
139 | task_id = sample["task_id"]
140 | if language_type is None:
141 | language_type = LANGUAGE_NAME[task_id.split("/")[0]]
142 | if test_groundtruth:
143 | if dataset_type == "humanevalx":
144 | sample["generation"] = sample["canonical_solution"]
145 | sample["prompt"] = problems[task_id]["prompt"]
146 | if dataset_type == "mbpp":
147 | sample["generation"] = sample["code"]
148 | sample["prompt"] = problems[task_id]["prompt"]
149 | sample = postprocess_generation(sample, generation_mode)
150 | sample["test_code"] = process_test(sample, problems, dataset_type, language_type, generation_mode)
151 | if sample["test_code"] is None:
152 | continue
153 | if "completion_id" in sample:
154 | completion_id_ = sample["completion_id"]
155 | else:
156 | completion_id_ = completion_id[task_id]
157 | args = (task_id, sample, language_type, timeout, tmp_dir, completion_id_)
158 | future = executor.submit(check_correctness, *args)
159 | futures.append(future)
160 | completion_id[task_id] += 1
161 | n_samples += 1
162 |
163 | if len(completion_id) == len(problems):
164 | evaluate_pass_at_k = True
165 | else:
166 | evaluate_pass_at_k = False
167 |
168 | logger.info("Running test suites...")
169 | for future in tqdm(as_completed(futures), total=len(futures)):
170 | result = future.result()
171 | results[result["task_id"]].append((result["completion_id"], result))
172 |
173 | # Calculate pass@k.
174 | total, correct = [], []
175 | for result in results.values():
176 | passed = [r[1]["passed"] for r in result]
177 | total.append(len(passed))
178 | correct.append(sum(passed))
179 | total = np.array(total)
180 | correct = np.array(correct)
181 | if evaluate_pass_at_k:
182 | ks = k
183 | pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
184 | for k in ks if (total >= k).all()}
185 | logger.info(pass_at_k)
186 | else:
187 | logger.info("Total: {}".format(np.sum(total)))
188 | logger.info("Correct: {}".format(np.sum(correct)))
189 |
190 | if test_groundtruth:
191 | out_file = os.path.join(output_path, "ground_truth.jsonl")
192 | else:
193 | out_file = os.path.join(output_path, "result-" + input_file.split("/")[-2] + "." + input_file.split("/")[-1].split(".")[-1])
194 |
195 | logger.info("Writing to: {}".format(out_file))
196 | if out_file.endswith(".gz"):
197 | fp = gzip.GzipFile(fileobj=open(out_file, "wb"), mode="wb")
198 | for res in results.values():
199 | for r in res:
200 | fp.write((json.dumps(r[1], ensure_ascii=False) + "\n").encode("utf-8"))
201 | else:
202 | fp = open(out_file, 'w')
203 | for res in results.values():
204 | for r in res:
205 | fp.write(json.dumps(r[1], ensure_ascii=False) + "\n")
206 | fp.close()
207 |
208 | if test_groundtruth:
209 | logger.info("Ground-truth test finished.")
210 | else:
211 | logger.info("Evaluation finished.")
212 |
213 |
214 | def main():
215 | fire.Fire(evaluate_functional_correctness)
216 |
217 |
218 | if __name__ == "__main__":
219 | sys.exit(main())
220 |
--------------------------------------------------------------------------------
/evaluation/execution.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import signal
4 | import random
5 | import gzip
6 | import json
7 | import tempfile
8 | import platform
9 | import subprocess
10 | import contextlib
11 | import faulthandler
12 | import multiprocessing
13 | from typing import *
14 |
15 |
16 | def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> None:
17 | """
18 | Method saves list of dicts into jsonl file.
19 | :param data: (list) list of dicts to be stored,
20 | :param filename: (str) path to the output file. If suffix .jsonl is not given then methods appends
21 | .jsonl suffix into the file.
22 | :param compress: (bool) should file be compressed into a gzip archive?
23 | """
24 | sjsonl = '.jsonl'
25 | sgz = '.gz'
26 | # Check filename
27 | if not filename.endswith(sjsonl):
28 | filename = filename + sjsonl
29 | # Save data
30 |
31 | if compress:
32 | filename = filename + sgz
33 | with gzip.open(filename, 'w') as compressed:
34 | for ddict in data_list:
35 | jout = json.dumps(ddict) + '\n'
36 | jout = jout.encode('utf-8')
37 | compressed.write(jout)
38 | else:
39 | with open(filename, 'w') as out:
40 | for ddict in data_list:
41 | jout = json.dumps(ddict) + '\n'
42 | out.write(jout)
43 |
44 |
45 | def check_correctness(
46 | task_id: str,
47 | sample: dict,
48 | language_type: str,
49 | timeout: float = 3.0,
50 | tmp_dir: str = None,
51 | completion_id: Optional[int] = None,
52 | ) -> Dict:
53 | """
54 | Evaluates the functional correctness of a completion by running the test
55 | suite provided in the problem.
56 | """
57 |
58 | def unsafe_execute(tmp_dir):
59 | random_id = random.uniform(1, 1000)
60 | if "python" in language_type.lower():
61 | with create_tempdir():
62 |
63 | # These system calls are needed when cleaning up tempdir.
64 | import os
65 | import shutil
66 | rmtree = shutil.rmtree
67 | rmdir = os.rmdir
68 | chdir = os.chdir
69 |
70 | # Disable functionalities that can make destructive changes to the test.
71 | reliability_guard()
72 |
73 | try:
74 | exec_globals = {}
75 | with swallow_io():
76 | with time_limit(timeout):
77 | # WARNING
78 | # This program exists to execute untrusted model-generated code. Although
79 | # it is highly unlikely that model-generated code will do something overtly
80 | # malicious in response to this test suite, model-generated code may act
81 | # destructively due to a lack of model capability or alignment.
82 | # Users are strongly encouraged to sandbox this evaluation suite so that it
83 | # does not perform destructive actions on their host or network.
84 | # Once you have read this disclaimer and taken appropriate precautions,
85 | # uncomment the following line and proceed at your own risk:
86 | exec(sample["test_code"], exec_globals)
87 | result.append("passed")
88 | except TimeoutException:
89 | result.append("timed out")
90 | except AssertionError as e:
91 | result.append(f"failed: AssertionError")
92 | except BaseException as e:
93 | result.append(f"failed: {e}")
94 |
95 | # Needed for cleaning up.
96 | shutil.rmtree = rmtree
97 | os.rmdir = rmdir
98 | os.chdir = chdir
99 |
100 | elif "go" in language_type.lower():
101 | assert tmp_dir is not None, "Go should be evaluated in a dir where necessary module files installed."
102 |
103 | import os
104 | import shutil
105 |
106 | if "tmp" not in tmp_dir:
107 | tmp_dir = os.path.join(tmp_dir, "tmp")
108 | tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
109 | if not os.path.exists(tmp_dir):
110 | os.makedirs(tmp_dir)
111 |
112 | os.chdir(tmp_dir)
113 | open(f"main_test.go", 'w').write(sample["test_code"])
114 | try:
115 | exec_result = None
116 | with time_limit(timeout):
117 | # WARNING
118 | # This program exists to execute untrusted model-generated code. Although
119 | # it is highly unlikely that model-generated code will do something overtly
120 | # malicious in response to this test suite, model-generated code may act
121 | # destructively due to a lack of model capability or alignment.
122 | # Users are strongly encouraged to sandbox this evaluation suite so that it
123 | # does not perform destructive actions on their host or network.
124 | # Once you have read this disclaimer and taken appropriate precautions,
125 | # uncomment the following line and proceed at your own risk:
126 | exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
127 |
128 | if exec_result.returncode == 0:
129 | result.append("passed")
130 | else:
131 | if exec_result.stderr:
132 | try:
133 | err = exec_result.stderr.decode()
134 | except:
135 | err = exec_result.stderr
136 | else:
137 | try:
138 | err = exec_result.stdout.decode()
139 | except:
140 | err = exec_result.stdout
141 | result.append(f"failed: {err}")
142 |
143 | except TimeoutException:
144 | result.append("timed out")
145 |
146 | shutil.rmtree(tmp_dir)
147 | elif "js" in language_type.lower():
148 | import os
149 | import shutil
150 |
151 | if "tmp" not in tmp_dir:
152 | tmp_dir = os.path.join(tmp_dir, "tmp")
153 | tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
154 | if not os.path.exists(tmp_dir):
155 | os.makedirs(tmp_dir)
156 |
157 | os.chdir(tmp_dir)
158 | open(f"test.js", 'w').write(sample["test_code"])
159 | try:
160 | exec_result = None
161 | with time_limit(timeout):
162 | # WARNING
163 | # This program exists to execute untrusted model-generated code. Although
164 | # it is highly unlikely that model-generated code will do something overtly
165 | # malicious in response to this test suite, model-generated code may act
166 | # destructively due to a lack of model capability or alignment.
167 | # Users are strongly encouraged to sandbox this evaluation suite so that it
168 | # does not perform destructive actions on their host or network.
169 | # Once you have read this disclaimer and taken appropriate precautions,
170 | # uncomment the following line and proceed at your own risk:
171 | exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
172 |
173 | if exec_result.stderr.decode():
174 | err = exec_result.stderr.decode()
175 | result.append(f"failed: {err}")
176 | elif exec_result.stdout.decode():
177 | err = exec_result.stdout.decode()
178 | result.append(f"failed: {err}")
179 | else:
180 | result.append("passed")
181 |
182 | except TimeoutException:
183 | result.append("timed out")
184 |
185 | shutil.rmtree(tmp_dir)
186 | elif "cpp" in language_type.lower():
187 | import os
188 | import shutil
189 |
190 | if "tmp" not in tmp_dir:
191 | tmp_dir = os.path.join(tmp_dir, "tmp")
192 | tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
193 | if not os.path.exists(tmp_dir):
194 | os.makedirs(tmp_dir)
195 |
196 | os.chdir(tmp_dir)
197 | open(f"test.cpp", 'w').write(sample["test_code"])
198 | if "162" in task_id:
199 | compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"],
200 | timeout=timeout,
201 | capture_output=True)
202 | else:
203 | compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp"], timeout=timeout,
204 | capture_output=True)
205 | if compilation_result.returncode != 0:
206 | if compilation_result.stderr:
207 | err = compilation_result.stderr.decode()
208 | else:
209 | err = compilation_result.stdout.decode()
210 | result.append(f"failed: compilation error: {err}")
211 | else:
212 | try:
213 | exec_result = None
214 | with time_limit(timeout):
215 | # WARNING
216 | # This program exists to execute untrusted model-generated code. Although
217 | # it is highly unlikely that model-generated code will do something overtly
218 | # malicious in response to this test suite, model-generated code may act
219 | # destructively due to a lack of model capability or alignment.
220 | # Users are strongly encouraged to sandbox this evaluation suite so that it
221 | # does not perform destructive actions on their host or network.
222 | # Once you have read this disclaimer and taken appropriate precautions,
223 | # uncomment the following line and proceed at your own risk:
224 | exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
225 |
226 | if exec_result.returncode == 0:
227 | result.append("passed")
228 | else:
229 | if exec_result.stderr:
230 | try:
231 | err = exec_result.stderr.decode()
232 | except:
233 | err = exec_result.stderr
234 | else:
235 | try:
236 | err = exec_result.stdout.decode()
237 | except:
238 | err = exec_result.stdout
239 | result.append(f"failed: {err}")
240 | except TimeoutException:
241 | result.append("timed out")
242 |
243 | shutil.rmtree(tmp_dir)
244 | elif "rust" in language_type.lower():
245 | import os
246 | WD: str = os.path.dirname(tmp_dir)
247 | RUST_DIR: str = os.path.join(WD, "rust")
248 | RUST_SRC: str = os.path.join(RUST_DIR, "src")
249 | RUST_BIN: str = os.path.join(RUST_SRC, "bin")
250 | RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
251 | RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
252 | RUST_EXT: str = ".rs"
253 |
254 | # Create mandatory tmp directories
255 | os.makedirs(RUST_TMP_DIR, exist_ok=True)
256 | os.makedirs(RUST_LOGS, exist_ok=True)
257 | os.makedirs(RUST_SRC, exist_ok=True)
258 | os.makedirs(RUST_BIN, exist_ok=True)
259 |
260 | with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f:
261 | # temporary file name
262 | file_prefix = sample["task_id"].lower().replace("/", "_")
263 | file_name: str = file_prefix + RUST_EXT
264 |
265 | os.rename(f.name, os.path.join(RUST_BIN, file_name))
266 |
267 | # The sample's test_code holds the complete Rust source to compile
268 | rust_code: str = sample["test_code"]
269 |
270 | # dump the Rust source code into the target temporary file
271 | f.write(rust_code.encode('utf-8'))
272 |
273 | # Compile the Rust binary next; first move to the Rust module root dir.
274 | os.chdir(RUST_DIR)
275 |
276 | # Two possible outcomes:
277 | # the compilation either passes or fails
278 | log_filename: str = file_prefix + ".jsonl"
279 | log_path: str = os.path.join(RUST_LOGS, log_filename)
280 | cargo_check: str = "cargo check --bin " + file_prefix + " --message-format json >> " + log_path
281 | # Compilation build status
282 | returned_val_compilation: int
283 |
284 | # Overwrite file content
285 | if os.path.exists(log_path):
286 | if (file_size := os.path.getsize(log_path)) >= 0:
287 | os.remove(log_path)
288 | returned_val_compilation = os.system(cargo_check)
289 |
290 | else:
291 | returned_val_compilation = os.system(cargo_check)
292 |
293 | # 0 means success
294 | if returned_val_compilation == 0:
295 |
296 | #Execution pipeline
297 | cargo_test: str = "cargo test --bin " + file_prefix + " --message-format json >> " + log_path
298 | returned_val_execution = os.system(cargo_test)
299 |
300 | if returned_val_execution == 0:
301 | result.append("passed")
302 | else:
303 | result.append(f"failed: execution error")
304 |
305 | else:
306 | result.append(f"failed: compilation error")
307 |
308 |
309 | elif "java" in language_type.lower():
310 | assert tmp_dir is not None, "Java should be evaluated in a temporary dir."
311 |
312 | import os
313 | import shutil
314 |
315 | if "tmp" not in tmp_dir:
316 | tmp_dir = os.path.join(tmp_dir, "tmp")
317 | tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
318 | if not os.path.exists(tmp_dir):
319 | os.makedirs(tmp_dir)
320 |
321 | os.chdir(tmp_dir)
322 | open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"])
323 | res = "failed: unknown error"
324 | compile_returncode = -1
325 | for _ in range(5):
326 | try:
327 | compilation_result = subprocess.run(['javac', os.path.join(tmp_dir, "Main.java")], timeout=5,
328 | capture_output=True)
329 | compile_returncode = compilation_result.returncode
330 | break
331 | except subprocess.TimeoutExpired as e:
332 | continue
333 | if compile_returncode != 0:
334 | res = "failed: compilation error"
335 | else:
336 | exec_result = None
337 | try:
338 | # WARNING
339 | # This program exists to execute untrusted model-generated code. Although
340 | # it is highly unlikely that model-generated code will do something overtly
341 | # malicious in response to this test suite, model-generated code may act
342 | # destructively due to a lack of model capability or alignment.
343 | # Users are strongly encouraged to sandbox this evaluation suite so that it
344 | # does not perform destructive actions on their host or network.
345 | # Once you have read this disclaimer and taken appropriate precautions,
346 | # the following line runs the untrusted generated code; proceed at your own risk:
347 | exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
348 | if exec_result.returncode == 0:
349 | res = "passed"
350 | elif exec_result.returncode == 1:
351 | if "AssertionError" in exec_result.stderr.decode('unicode-escape'):
352 | res = "failed: wrong answer"
353 | else:
354 | res = f"failed: {exec_result.stderr.decode()}"
355 | except subprocess.TimeoutExpired as e:
356 | res = "time out"
357 | except BaseException as e:
358 | res = f"failed: {e}"
359 | result.append(res)
360 |
361 | shutil.rmtree(tmp_dir)
362 |
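# Run unsafe_execute in a separate process so a hung or killed execution can be
# reaped after the timeout; the manager list carries the result string back
# across the process boundary (an empty list is treated as a timeout).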
363 | manager = multiprocessing.Manager()
364 | result = manager.list()
365 |
366 | p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir,))
367 | p.start()
368 | p.join(timeout=timeout + 1)
369 | if p.is_alive():
370 | p.kill()
371 |
372 | if not result:
373 | result.append("timed out")
374 |
375 | return {
376 | "task_id" : task_id,
377 | "completion_id": completion_id,
378 | "test_code" : sample["test_code"],
379 | "prompt" : sample["prompt"],
380 | "generation" : sample["generation"],
381 | "result" : result[0],
382 | "passed" : result[0] == "passed",
383 | "finish" : -1 if "finish" not in sample else sample["finish"],
384 | "file" : "" if "file" not in sample else sample["file"],
385 | "output" : [] if "output" not in sample else sample["output"],
386 | }
387 |
388 | # Copyright (c) OpenAI (https://openai.com)
389 |
390 | # Permission is hereby granted, free of charge, to any person obtaining a copy
391 | # of this software and associated documentation files (the "Software"), to deal
392 | # in the Software without restriction, including without limitation the rights
393 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
394 | # copies of the Software, and to permit persons to whom the Software is
395 | # furnished to do so, subject to the following conditions:
396 |
397 | # The above copyright notice and this permission notice shall be included in
398 | # all copies or substantial portions of the Software.
399 |
400 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
401 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
402 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
403 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
404 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
405 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
406 | # THE SOFTWARE.
407 | # ============================================================================
408 | @contextlib.contextmanager
409 | def time_limit(seconds: float):
410 | def signal_handler(signum, frame):
411 | raise TimeoutException("Timed out!")
412 |
413 | signal.setitimer(signal.ITIMER_REAL, seconds)
414 | signal.signal(signal.SIGALRM, signal_handler)
415 | try:
416 | yield
417 | finally:
418 | signal.setitimer(signal.ITIMER_REAL, 0)
419 |
420 |
421 | @contextlib.contextmanager
422 | def swallow_io():
423 | stream = WriteOnlyStringIO()
424 | with contextlib.redirect_stdout(stream):
425 | with contextlib.redirect_stderr(stream):
426 | with redirect_stdin(stream):
427 | yield
428 |
429 |
430 | @contextlib.contextmanager
431 | def create_tempdir():
432 | with tempfile.TemporaryDirectory() as dirname:
433 | with chdir(dirname):
434 | yield dirname
435 |
436 |
437 | class TimeoutException(Exception):
438 | pass
439 |
440 |
441 | class WriteOnlyStringIO(io.StringIO):
442 | """ StringIO that throws an exception when it's read from """
443 |
444 | def read(self, *args, **kwargs):
445 | raise IOError
446 |
447 | def readline(self, *args, **kwargs):
448 | raise IOError
449 |
450 | def readlines(self, *args, **kwargs):
451 | raise IOError
452 |
453 | def readable(self, *args, **kwargs):
454 | """ Returns True if the IO object can be read. """
455 | return False
456 |
457 |
458 | class redirect_stdin(contextlib._RedirectStream): # type: ignore
459 | _stream = 'stdin'
460 |
461 |
462 | @contextlib.contextmanager
463 | def chdir(root):
464 | if root == ".":
465 | yield
466 | return
467 | cwd = os.getcwd()
468 | os.chdir(root)
469 | try:
470 | yield
471 | except BaseException as exc:
472 | raise exc
473 | finally:
474 | os.chdir(cwd)
475 |
476 |
477 | def reliability_guard(maximum_memory_bytes: Optional[int] = None):
478 | """
479 | This disables various destructive functions and prevents the generated code
480 | from interfering with the test (e.g. fork bomb, killing other processes,
481 | removing filesystem files, etc.)
482 |
483 | WARNING
484 | This function is NOT a security sandbox. Untrusted code, including model-
485 | generated code, should not be blindly executed outside of one. See the
486 | Codex paper for more information about OpenAI's code sandbox, and proceed
487 | with caution.
488 | """
489 |
490 | if maximum_memory_bytes is not None:
491 | import resource
492 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
493 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
494 | if not platform.uname().system == 'Darwin':
495 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
496 |
497 | faulthandler.disable()
498 |
499 | import builtins
500 | builtins.exit = None
501 | builtins.quit = None
502 |
503 | import os
504 | os.environ['OMP_NUM_THREADS'] = '1'
505 |
506 | os.kill = None
507 | os.system = None
508 | os.putenv = None
509 | os.remove = None
510 | os.removedirs = None
511 | os.rmdir = None
512 | os.fchdir = None
513 | os.setuid = None
514 | os.fork = None
515 | os.forkpty = None
516 | os.killpg = None
517 | os.rename = None
518 | os.renames = None
519 | os.truncate = None
520 | os.replace = None
521 | os.unlink = None
522 | os.fchmod = None
523 | os.fchown = None
524 | os.chmod = None
525 | os.chown = None
526 | os.chroot = None
527 | os.fchdir = None
528 | os.lchflags = None
529 | os.lchmod = None
530 | os.lchown = None
531 | os.getcwd = None
532 | os.chdir = None
533 |
534 | import shutil
535 | shutil.rmtree = None
536 | shutil.move = None
537 | shutil.chown = None
538 |
539 | import subprocess
540 | subprocess.Popen = None # type: ignore
541 |
542 | __builtins__['help'] = None
543 |
544 | import sys
545 | sys.modules['ipdb'] = None
546 | sys.modules['joblib'] = None
547 | sys.modules['resource'] = None
548 | sys.modules['psutil'] = None
549 | sys.modules['tkinter'] = None
550 |
--------------------------------------------------------------------------------
/evaluation/generation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zmq
3 | import time
4 | import json
5 | import torch
6 | import random
7 | import socket
8 | import argparse
9 |
10 | from typing import *
11 | from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, StoppingCriteria
12 | from utils import Logger, read_dataset, process_extra_prompt, is_code_generation_finished, cleanup_code
13 |
14 | logger = Logger(__name__)
15 |
16 |
17 | def add_code_generation_specific_args(parser):
18 | group = parser.add_argument_group("Code Generation")
19 | group.add_argument(
20 | "--hostfile",
21 | type=str,
22 | default="./hostfile",
23 | )
24 | group.add_argument(
25 | "--channel-ip",
26 | type=str,
27 | default=None,
28 | help="IP for ZeroMQ channel",
29 | )
30 | group.add_argument(
31 | "--channel-port",
32 | type=int,
33 | default=5555,
34 | help="Port for ZeroMQ channel",
35 | )
36 | group.add_argument(
37 | "--master-port",
38 | type=int,
39 | default=6007,
40 | help="Port for distributed channel",
41 | )
42 | group.add_argument(
43 | "--model-per-device",
44 | type=int,
45 | default=1,
46 | help="Number of models per device",
47 | )
48 | group.add_argument(
49 | "--max-length",
50 | type=int,
51 | default=8192,
52 | help="Max sequence length",
53 | )
54 | group.add_argument(
55 | "--top-p",
56 | type=float,
57 | default=1.0,
58 | help="Top-p Probability for sampling",
59 | )
60 | group.add_argument(
61 | "--top-k",
62 | type=int,
63 | default=0,
64 | help="Top-k for sampling",
65 | )
66 | group.add_argument(
67 | "--temperature",
68 | type=float,
69 | default=1.0,
70 | help="Temperature for sampling",
71 | )
72 | group.add_argument(
73 | "--greedy",
74 | type=int,
75 | default=0,
76 | help="Use greedy decoding instead of sampling",
77 | )
78 | group.add_argument(
79 | "--seed",
80 | type=int,
81 | default=42,
82 | help="Random seed",
83 | )
84 | group.add_argument(
85 | "--micro-batch-size",
86 | type=int,
87 | default=1,
88 | help="Micro batch size for each GPU",
89 | )
90 | group.add_argument(
91 | "--samples-per-problem",
92 | type=int,
93 | default=200,
94 | help="Number of samples to generate for each problem",
95 | )
96 | group.add_argument(
97 | "--gen-node-world-size",
98 | type=int,
99 | default=1,
100 | help="Number of machines to use for generation",
101 | )
102 | group.add_argument(
103 | '--task-name',
104 | default="generation",
105 | help='Name of task',
106 | )
107 | group.add_argument(
108 | '--model-name',
109 | default="codegeex2-6b",
110 | help='Name of model, support ["codegeex2-6b", "starcoder", "replit-code-v1-3b", "codegen25-7b-multi", "codegen25-7b-mono", "codegen-16B-multi"]',
111 | )
112 | group.add_argument(
113 | '--data-path',
114 | required=True,
115 | )
116 | group.add_argument(
117 | '--output-path',
118 | required=True,
119 | )
120 | group.add_argument(
121 | '--log-path',
122 | default=None,
123 | help='Path to log output',
124 | )
125 | group.add_argument(
126 | '--model-path',
127 | required=True,
128 | )
129 | group.add_argument(
130 | '--dataset-type',
131 | default="humanevalx",
132 | help='Identify the evaluation dataset [humanevalx]',
133 | )
134 | group.add_argument(
135 | '--language-type',
136 | default="python",
137 | help='Identify the type of programming language to generate',
138 | )
139 | group.add_argument(
140 | '--generation-mode',
141 | default="instruction",
142 | )
143 |
144 |
145 | class CodeStoppingCriteria(StoppingCriteria):
146 | """
147 | This class can be used to stop generation whenever the number of generated tokens exceeds `max_length` or the generated code meets the code generation stopping criteria.
148 | """
149 |
150 | def __init__(
151 | self,
152 | max_length: int,
153 | micro_batch_size: int,
154 | tokenizer,
155 | dataset_type: str,
156 | language_type: str,
157 | prompt: str,
158 | ):
159 | self.max_length = max_length
160 | self.tokenizer = tokenizer
161 | self.dataset_type = dataset_type
162 | self.language_type = language_type
163 | self.prompt = prompt
164 | self.stop_index = [-1 for _ in range(micro_batch_size)]
165 |
166 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
167 | for i, input_id in enumerate(input_ids):
168 | if self.stop_index[i] > -1:
169 | continue
170 | code = self.tokenizer.decode(input_id)
171 | code = code[len(self.prompt):]
172 | if is_code_generation_finished(
173 | code,
174 | dataset_type=self.dataset_type,
175 | language_type=self.language_type) or input_id.shape[-1] >= self.max_length:
176 | self.stop_index[i] = len(code) + len(self.prompt)
177 | if all([s != -1 for s in self.stop_index]):
178 | return True
179 |
180 | return False
181 |
182 |
183 | def run_generation_distributed(args, model, tokenizer):
184 | logger.info(f"Connecting to tcp://{args.channel_ip}:{args.channel_port}")
185 | context = zmq.Context()
186 | socket = context.socket(zmq.REQ)
187 | socket.connect(f"tcp://{args.channel_ip}:{args.channel_port}")
188 |
189 | os.makedirs(args.output_path, exist_ok=True)
190 | output_path = os.path.join(
191 | args.output_path,
192 | f"{args.task_name}-t{args.temperature}-topp{args.top_p}-ns{args.samples_per_problem}-rank{args.rank}.jsonl",
193 | )
194 |
195 | def process(obj):
196 | results = []
197 | prompt = obj["prompt"]
198 | if args.generation_mode == "instruction":
199 | inputs = tokenizer([prompt] * args.micro_batch_size, return_tensors="pt")
200 | inputs = inputs.to(model.device)
201 | outputs = model.generate(**inputs,
202 | max_length=args.max_length,
203 | do_sample=True if not args.greedy else False,
204 | use_cache=True,
205 | top_p=args.top_p,
206 | top_k=args.top_k,
207 | temperature=args.temperature,
208 | pad_token_id=tokenizer.eos_token_id)
209 | for i, output in enumerate(outputs):
210 | response = tokenizer.decode(output)
211 | res = obj.copy()
212 | res["generation"] = response[len(prompt):].strip()
213 | results.append(res)
214 | elif args.generation_mode == "completion":
215 | inputs = tokenizer([prompt for _ in range(args.micro_batch_size)], return_tensors="pt")
216 | inputs = inputs.to(model.device)
217 | stop_criteria = CodeStoppingCriteria(
218 | max_length=args.max_length,
219 | micro_batch_size=args.micro_batch_size,
220 | tokenizer=tokenizer,
221 | dataset_type=args.dataset_type,
222 | language_type=args.language_type,
223 | prompt=prompt)
224 | outputs = model.generate(**inputs,
225 | max_length=args.max_length,
226 | do_sample=True if not args.greedy else False,
227 | use_cache=True,
228 | stopping_criteria=[stop_criteria],
229 | top_p=args.top_p,
230 | top_k=args.top_k,
231 | temperature=args.temperature,
232 | pad_token_id=tokenizer.eos_token_id)
233 | for i, output in enumerate(outputs):
234 | response = tokenizer.decode(output)
235 | res = obj.copy()
236 | res["generation_raw"] = response
237 | res["generation"] = cleanup_code(
238 | response[len(prompt):],
239 | dataset_type=args.dataset_type,
240 | language_type=args.language_type)
241 | results.append(res)
242 |
243 | return results
244 |
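# Worker loop: repeatedly pull a task entry from the ZeroMQ server, generate
# micro_batch_size samples for it, append them to the per-rank output file,
# and acknowledge success or failure so the server can track (or re-queue) the task.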
245 | fout = open(output_path, "w", encoding="utf-8")
246 | while True:
247 | socket.send_json({"rank": args.rank, "action": "pull"})
248 | resp = socket.recv_json()
249 | try:
250 | if resp["task_id"] is None:
251 | break
252 |
253 | current_spec = resp["task_id"]
254 | results = process(current_spec)
255 |
256 | for res in results:
257 | fout.write(json.dumps(res, ensure_ascii=False) + "\n")
258 | fout.flush()
259 |
260 | socket.send_json(
261 | {
262 | "rank" : args.rank,
263 | "action" : "success",
264 | "task_id": current_spec['task_id']
265 | }
266 | )
267 | socket.recv()
268 |
269 | except Exception as e:
270 | logger.error(f"*** (rank={args.rank}) crashed.")
271 | logger.error(f" error: {repr(e)}")
272 | socket.send_json(
273 | {
274 | "rank" : args.rank,
275 | "action" : "fail",
276 | "task_id": current_spec['task_id']
277 | }
278 | )
279 | socket.recv()
280 | continue
281 |
282 |
283 | def main(args, node_rank: int, local_rank: int, master_port: int, num_devices: int):
284 | world_size = args.gen_node_world_size * num_devices
285 | args.rank = num_devices * node_rank + local_rank
286 | args.world_size = world_size
287 | logger.info(f"Generating on rank {args.rank} of {args.world_size}")
288 |
289 | try:
290 | if args.model_name in ["codegeex2-6b"]:
291 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
292 | else:
293 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, clean_up_tokenization_spaces=False, trust_remote_code=True)
294 | if args.model_name in ["codegeex2-6b"]:
295 | model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True).to("cuda:{}".format(local_rank % torch.cuda.device_count()))
296 | elif args.model_name in ["starcoder", "replit-code-v1-3b", "codegen25-7b-multi", "codegen25-7b-mono", "codegen-16B-multi"]:
297 | model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True).to("cuda:{}".format(local_rank % torch.cuda.device_count()))
298 | else:
299 | try:
300 | model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True).to("cuda:{}".format(local_rank % torch.cuda.device_count()))
301 | except:
302 | logger.error(f"Model {args.model_name} not supported.")
303 | raise NotImplementedError
304 | except Exception as e:
305 | logger.error(e)
306 |
307 | model = model.eval()
308 | # Generate samples.
309 | run_generation_distributed(args, model, tokenizer)
310 |
311 | logger.info(f"rank={args.rank} worker finished, waiting ...")
312 | exit(0)
313 |
314 |
315 | def server(args):
316 | logger.info(f"[ server ] starting ...")
317 | entries = read_dataset(args.data_path, dataset_type=args.dataset_type)
318 |
319 | assert args.samples_per_problem % args.micro_batch_size == 0, "samples_per_problem should be divisible by micro_batch_size"
320 |
321 | for entry in entries.values():
322 | entry["prompt"] = process_extra_prompt(
323 | entry["prompt"],
324 | language_type=args.language_type,
325 | dataset_type=args.dataset_type,
326 | generation_mode=args.generation_mode,
327 | )
328 |
329 | res = []
330 | for entry in entries.values():
331 | res.extend([entry] * (args.samples_per_problem // args.micro_batch_size))
332 | random.shuffle(res)
333 | all_entries = res
334 |
335 | # setup zeromq channel
336 | logger.info(f"[ server ] starting up on port {args.channel_port}")
337 | context = zmq.Context()
338 | logger.info(f"[ server ] creating socket")
339 | socket = context.socket(zmq.REP)
340 | logger.info(f"[ server ] binding to port {args.channel_port}")
341 | socket.bind(f"tcp://*:{args.channel_port}")
342 |
343 | logger.info(
344 | f"[ server ] loaded {len(entries)} entries, generating {len(entries) * args.samples_per_problem} samples",
345 | )
346 |
347 | remaining_entries = all_entries.copy()
348 | running_workers = args.gen_node_world_size * torch.cuda.device_count()
349 | num_finished = 0
350 |
351 | logger.info(f"[ server ] listening for requests ...")
352 | start_time = time.perf_counter()
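# Dispatch loop over the REP socket: a "pull" request receives either the next task
# entry or {"task_id": None} once the queue is drained; "success" and "fail" reports
# are acknowledged with a pong, and failed entries are pushed back onto the queue.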
353 | while True:
354 | # Wait for next request from client
355 | msg = socket.recv_json()
356 | rank = msg["rank"]
357 | action = msg["action"]
358 |
359 | if action == "pull":
360 | if len(remaining_entries) == 0:
361 | socket.send_json({"task_id": None})
362 | running_workers -= 1
363 | logger.info(f"[ server ] Shutting down worker {rank}, remaining {running_workers} workers")
364 | if running_workers == 0 and num_finished == len(all_entries):
365 | logger.info(f"[ server ] All workers finished")
366 | break
367 | else:
368 | entry = remaining_entries.pop()
369 | time_elapsed = time.perf_counter() - start_time
370 | logger.info(f"[ server ] Sending entry {entry['task_id']} to worker {rank}")
371 | remaining = (
372 | len(remaining_entries)
373 | / (len(all_entries) - len(remaining_entries))
374 | * time_elapsed
375 | )
376 | time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished / args.micro_batch_size
377 | logger.info(
378 | f"[ server ] total {len(all_entries)}, assigned {len(all_entries) - len(remaining_entries)}, finished {num_finished}, elapsed {time_elapsed:.4f}, speed {time_per_sampple:.4f}s/sample, remaining {remaining:.4f}",
379 | )
380 | socket.send_json({"task_id": entry})
381 | else:
382 | if action == "success":
383 | logger.info(f"[ server ] {msg['task_id']} is finished")
384 | socket.send_json({"pong": 1})
385 | else:
386 | logger.info(f"[ server ] {msg['task_id']} is not finished")
387 | remaining_entries.append(msg['task_id'])
388 | socket.send_json({"pong": 1})
389 | break
390 |
391 | num_finished += 1
392 |
393 |
394 | if __name__ == "__main__":
395 | torch.multiprocessing.set_start_method("spawn")
396 | parser = argparse.ArgumentParser()
397 | add_code_generation_specific_args(parser)
398 | args = parser.parse_args()
399 |
400 | if args.log_path is None:
401 | args.log_path = os.path.join(args.output_path, "generation.log")
402 |
403 | logger.info("start method: " + torch.multiprocessing.get_start_method())
404 |
405 | processes = []
406 | num_devices = torch.cuda.device_count()
407 | hosts = open(args.hostfile, "r").readlines()
408 | hosts = [host.strip() for host in hosts]
409 | master_port = args.master_port
410 |
411 | node_rank = None
412 | for i in range(len(hosts)):
413 | if hosts[i] == socket.gethostbyname(socket.gethostname()):
414 | node_rank = i
415 | break
416 | assert (
417 | node_rank is not None
418 | ), f"Could not find hostname ({socket.gethostbyname(socket.gethostname())}) in hostlist"
419 |
420 | # launch server
421 | if socket.gethostbyname(socket.gethostname()) == hosts[0]:
422 | server_process = torch.multiprocessing.Process(target=server, args=(args,))
423 | logger.info(f"Launching server ...")
424 | server_process.start()
425 | processes.append(server_process)
426 |
427 | for i in range(num_devices):
428 | local_rank = i
429 | logger.info(f"launching local rank {i}")
430 |
431 | p = torch.multiprocessing.Process(
432 | target=main,
433 | args=(args, node_rank, local_rank, master_port, num_devices),
434 | )
435 | p.start()
436 | processes.append(p)
437 |
438 | for p in processes:
439 | p.join()
440 |
--------------------------------------------------------------------------------
/evaluation/inspect_jsonl.py:
--------------------------------------------------------------------------------
1 | import fire
2 | import json
3 | import numpy as np
4 |
5 | from typing import *
6 | from utils import Logger
7 |
8 |
9 | def main(
10 | data_path: str = "./test.jsonl",
11 | threshold: int = -1,
12 | random: int = 0,
13 | log_path: str = 'inspect_jsonl.txt',
14 | random_rate: float = 0.5,
15 | ):
16 | logger = Logger(__name__, log_file=log_path, log_mode="file", disable_formatter=True)
17 |
18 | n = 0
19 | with open(data_path, "r") as f:
20 | for i, line in enumerate(f):
21 | if i == 0:
22 | logger.info("Data has the following keys")
23 | obj = json.loads(line)
24 | logger.info(obj.keys())
25 | if threshold > 0 and n > threshold:
26 | break
27 | if random and np.random.randint(10) > 10 * random_rate:
28 | continue
29 |
30 | obj = json.loads(line)
31 | n += 1
32 | logger.info(f"========== Sample {i} ==========")
33 | if 'code' in obj:
34 | try:
35 | code_splits = obj['code'].split("\n")
36 | logger.info(f"Length of chars: {len(obj['code'])}, length of lines: {len(code_splits)}.")
37 | except:
38 | pass
39 | for j, k in enumerate(obj.keys()):
40 | logger.info(f"** Key {j}: {k} **")
41 | logger.info(obj[k])
42 | print(f"Log saved in {log_path}")
43 |
44 |
45 | if __name__ == "__main__":
46 | fire.Fire(main)
--------------------------------------------------------------------------------
/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import gzip
4 | import torch
5 | import numpy
6 | import random
7 | import logging
8 | import itertools
9 | import numpy as np
10 |
11 | from typing import *
12 |
13 |
14 | LANGUAGE_TAG = {
15 | "c" : "// language: C",
16 | "c++" : "// language: C++",
17 | "cpp" : "// language: C++",
18 | "c#" : "// language: C#",
19 | "csharp" : "// language: C#",
20 | "c-sharp" : "// language: C#",
21 | "css" : "/* language: CSS */",
22 | "cuda" : "// language: Cuda",
23 | "dart" : "// language: Dart",
24 | "lua" : "// language: Lua",
25 | "objectivec" : "// language: Objective-C",
26 | "objective-c" : "// language: Objective-C",
27 | "objective-c++": "// language: Objective-C++",
28 | "python" : "# language: Python",
29 | "perl" : "# language: Perl",
30 | "prolog" : f"% language: Prolog",
31 | "swift" : "// language: swift",
32 | "lisp" : "; language: Lisp",
33 | "java" : "// language: Java",
34 | "scala" : "// language: Scala",
35 | "tex" : f"% language: TeX",
36 | "vue" : "",
37 | "markdown" : "",
38 | "html" : "",
39 | "php" : "// language: PHP",
40 | "js" : "// language: JavaScript",
41 | "javascript" : "// language: JavaScript",
42 | "typescript" : "// language: TypeScript",
43 | "go" : "// language: Go",
44 | "shell" : "# language: Shell",
45 | "rust" : "// language: Rust",
46 | "sql" : "-- language: SQL",
47 | "kotlin" : "// language: Kotlin",
48 | "vb" : "' language: Visual Basic",
49 | "ruby" : "# language: Ruby",
50 | "pascal" : "// language: Pascal",
51 | "r" : "# language: R",
52 | "fortran" : "!language: Fortran",
53 | "lean" : "-- language: Lean",
54 | "matlab" : f"% language: Matlab",
55 | "delphi" : "{language: Delphi}",
56 | "scheme" : "; language: Scheme",
57 | "basic" : "' language: Basic",
58 | "assembly" : "; language: Assembly",
59 | "groovy" : "// language: Groovy",
60 | "abap" : "* language: Abap",
61 | "gdscript" : "# language: GDScript",
62 | "haskell" : "-- language: Haskell",
63 | "julia" : "# language: Julia",
64 | "elixir" : "# language: Elixir",
65 | "excel" : "' language: Excel",
66 | "clojure" : "; language: Clojure",
67 | "actionscript" : "// language: ActionScript",
68 | "solidity" : "// language: Solidity",
69 | "powershell" : "# language: PowerShell",
70 | "erlang" : f"% language: Erlang",
71 | "cobol" : "// language: Cobol",
72 | "alloy" : "/* language: Alloy */",
73 | "awk" : "// language: AWK",
74 | "thrift" : "/* language: Thrift */",
75 | "sparql" : "# language: SPARQL",
76 | "augeas" : "// language: Augeas",
77 | "cmake" : "# language: CMake",
78 | "f-sharp" : "// language: F#",
79 | "stan" : "// language: Stan",
80 | "isabelle" : "(*language: Isabelle*)",
81 | "dockerfile" : "# language: Dockerfile",
82 | "rmarkdown" : "# language: RMarkdown",
83 | "literate-agda": "-- language: Literate Agda",
84 | "tcl" : "// language: Augeas",
85 | "glsl" : "// language: GLSL",
86 | "antlr" : "// language: ANTLR",
87 | "verilog" : "// language: Verilog",
88 | "racket" : "; language: Racket",
89 | "standard-ml" : "(*language:Standard ML*)",
90 | "elm" : "-- language: Elm",
91 | "yaml" : "# language: YAML",
92 | "smalltalk" : "'' language: Smalltalk",
93 | "ocaml" : "(*language: OCaml*)",
94 | "idris" : "-- language: Idris",
95 | "visual-basic" : "' language: Visual Basic",
96 | "protocol-buffer": "// language: Protocol Buffer",
97 | "bluespec" : "// language: Bluespec",
98 | "applescript" : "-- language: AppleScript",
99 | "makefile" : "# language: Makefile",
100 | "tcsh" : "# language: TCSH",
101 | "maple" : "# language: Maple",
102 | "systemverilog": "// language: SystemVerilog",
103 | "literate-coffeescript": "# language: Literate CoffeeScript",
104 | "vhdl" : "-- language: VHDL",
105 | "restructuredtext": ".. language: reStructuredText",
106 | "sas" : "* language: SAS",
107 | "literate-haskell": "> language: Literate Haskell",
108 | "java-server-pages": "// language: Java Server Pages",
109 | "coffeescript" : "# language: CoffeeScript",
110 | "emacs-lisp" : "; language: Emacs Lisp",
111 | "mathematica" : "// language: Mathematica",
112 | "xslt" : "",
113 | "zig" : "// language: Zig",
114 | "common-lisp" : "; language: Common Lisp",
115 | "stata" : "* language: Stata",
116 | "agda" : "-- language: Agda",
117 | "ada" : "-- language: Ada",
118 | }
119 |
120 |
121 | LANGUAGE_COMMENT_SIGN = {}
122 | for lang in LANGUAGE_TAG:
123 | LANGUAGE_COMMENT_SIGN[lang] = LANGUAGE_TAG[lang].split("language:")[0].strip()
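# e.g. LANGUAGE_TAG["python"] == "# language: Python" yields the comment sign "#",
# while entries with an empty tag (such as "vue" or "html") map to an empty string.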
124 |
125 |
126 | IMPORT_HELPER = {
127 | "python": [
128 | "import math",
129 | "import re",
130 | "import sys",
131 | "import copy",
132 | "import datetime",
133 | "import itertools",
134 | "import collections",
135 | "import heapq",
136 | "import statistics",
137 | "import functools",
138 | "import hashlib",
139 | "import numpy",
140 | "import numpy as np",
141 | "import string",
142 | "from typing import *",
143 | "from collections import *",
144 | ],
145 | "go" : [
146 | "math",
147 | "strings",
148 | "fmt",
149 | "strconv",
150 | "time",
151 | "bytes",
152 | "regexp",
153 | "sort",
154 | "math/rand",
155 | "crypto/md5",
156 | ],
157 | "cpp" : [
158 | "#include",
159 | "#include",
160 | "#include",
161 | "#include",
162 | "#include",
163 | "#include",
164 | "#include",
165 | "#include",
166 | "#include",
167 | ],
168 | }
169 |
170 |
171 |
172 | def set_random_seed(seed):
173 | """Set random seed for reproducability."""
174 | random.seed(seed)
175 | numpy.random.seed(seed)
176 | torch.manual_seed(seed)
177 |
178 |
179 | def stream_jsonl(filename: str) -> Iterable[Dict]:
180 | """
181 | Parses each jsonl line and yields it as a dictionary
182 | """
183 | if filename.endswith(".gz"):
184 | with open(filename, "rb") as gzfp:
185 | with gzip.open(gzfp, "rt") as fp:
186 | for line in fp:
187 | if any(not x.isspace() for x in line):
188 | yield json.loads(line)
189 | else:
190 | with open(filename, "r") as fp:
191 | for line in fp:
192 | if any(not x.isspace() for x in line):
193 | yield json.loads(line)
194 |
195 |
196 | def stream_jsonl_all(filename: str) -> Iterable[Dict]:
197 | results = []
198 | if filename.endswith(".gz"):
199 | fp = gzip.open(open(filename, "rb"), "rt")
200 | else:
201 | fp = open(filename, "r")
202 | for line in fp:
203 | if any(not x.isspace() for x in line):
204 | results.append(json.loads(line))
205 | fp.close()
206 |
207 | return results
208 |
209 |
210 | def read_dataset(
211 | data_file: str = None,
212 | dataset_type: str = "humanevalx",
213 | ) -> Dict:
214 | if "humanevalx" in dataset_type.lower():
215 | dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
216 | elif "mbpp" in dataset_type.lower():
217 | problems = {task["task_id"]: task for task in stream_jsonl(data_file)}
218 | task_ids = sorted(problems.keys())[10:510]
219 | dataset = {}
220 | for task_id in task_ids:
221 | sample = problems[task_id]
222 | description = sample["text"]
223 | test_example = sample["test_list"][0]
224 | prompt = f'"""\n{description}\n{test_example}\n"""\n'
225 | sample["prompt"] = prompt
226 | dataset[task_id] = sample
227 | elif "ds1000" in dataset_type.lower():
228 | # install ds1000 from https://github.com/HKUNLP/DS-1000
229 | from ds1000 import DS1000Dataset
230 | ds1000 = DS1000Dataset(source_dir=data_file, libs="all", mode="Completion")
231 | for lib in ds1000.libs:
232 | for problem_id in range(len(ds1000[lib])):
233 | prefix = ""
234 | suffix = ""
235 | insert_flag = False
236 | first_line_flag = True
237 | # extract prefix and suffix of the prompt
238 | for line in ds1000[lib][problem_id]["prompt"].split("\n"):
239 | if "[insert]" in line:
240 | insert_flag = True
241 | continue
242 | if first_line_flag:
243 | first_line_flag = False
244 | else:
245 | line = "\n" + line
246 | if not insert_flag:
247 | prefix += line
248 | else:
249 | suffix += line
250 |
251 | else:
252 | raise f"Dataset: {dataset_type} not supported."
253 |
254 | return dataset
255 |
256 |
257 | def read_translation_dataset(
258 | data_file_src: str = None,
259 | data_file_tgt: str = None,
260 | lang_src: str = None,
261 | lang_tgt: str = None,
262 | dataset_type: str = "humanevalx",
263 | ) -> Dict:
264 | if "humanevalx" in dataset_type.lower():
265 | dataset_src = {task["task_id"]: task for task in stream_jsonl(data_file_src)}
266 | dataset_tgt = {task["task_id"].split("/")[-1]: task for task in stream_jsonl(data_file_tgt)}
267 | for k, sample in dataset_src.items():
268 | prompt = "code translation\n"
269 | if lang_src == "cpp":
270 | prompt += "C++:\n"
271 | elif lang_src == "js":
272 | prompt += "JavaScript:\n"
273 | else:
274 | prompt += f"{lang_src}:\n".capitalize()
275 | prompt += dataset_src[k]["declaration"] + "\n" + dataset_src[k]["canonical_solution"].rstrip() + "\n"
276 | if lang_tgt == "cpp":
277 | prompt += "C++:\n"
278 | elif lang_tgt == "js":
279 | prompt += "JavaScript:\n"
280 | else:
281 | prompt += f"{lang_tgt}:\n".capitalize()
282 | prompt += dataset_tgt[k.split("/")[-1]]["declaration"]
283 | dataset_src[k]["prompt"] = prompt
284 | else:
285 | raise f"Dataset: {dataset_type} not supported."
286 |
287 | return dataset_src
288 |
289 |
290 | def process_extra_prompt(
291 | prompt: str,
292 | language_type: str = "python",
293 | dataset_type: str = None,
294 | generation_mode: str = "completion",
295 | ) -> str:
296 | """
297 | Processes the extra prompt.
298 | """
299 | language = language_type.lower()
300 | if dataset_type == "humanevalx":
301 | extra_prompt = ""
302 | # extra_prompt = LANGUAGE_TAG[language] + "\n"
303 | prompt = prompt.strip()
304 | if generation_mode == "instruction":
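# Instruction mode wraps the prompt in a Chinese question/answer template
# ("问:" = "Question:", "答:" = "Answer:"), presumably matching the chat-style
# format the instruction-tuned model expects.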
305 | return "问:" + extra_prompt + prompt + "\n答:"
306 | return extra_prompt + prompt
307 | elif dataset_type == "mbpp":
308 | extra_prompt = ""
309 | prompt = prompt.strip()
310 | return extra_prompt + prompt
311 | else:
312 | return prompt
313 |
314 |
315 | def is_code_generation_finished(
316 | code: str,
317 | dataset_type: str = None,
318 | language_type: str = None,
319 | ):
320 | """
321 | Checks whether the generated code is finished.
322 | """
323 | if dataset_type == "mbpp":
324 | end_words = ["\ndef", "\nassert"]
325 | for w in end_words:
326 | if w == "\ndef":
327 | if code.count(w) > 1:
328 | return True
329 | else:
330 | if w in code:
331 | return True
332 | else:
333 | if language_type.lower() == "python":
334 | for line in code.split("\n"):
335 | if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
336 | return True
337 | end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
338 | for w in end_words:
339 | if w in code:
340 | return True
341 | elif language_type.lower() == "java":
342 | if code.count("{") + 1 == code.count("}"):
343 | return True
344 | elif language_type.lower() == "go":
345 | if "\nfunc main(" in code:
346 | return True
347 | if code.count("{") + 1 == code.count("}"):
348 | return True
349 | elif language_type.lower() == "js":
350 | if code.count("{") + 1 == code.count("}"):
351 | return True
352 | elif language_type.lower() == "cpp":
353 | if "\nint main()" in code:
354 | return True
355 | if code.count("{") + 1 == code.count("}"):
356 | return True
357 | elif language_type.lower() == "rust":
358 | if "\nfn main()" in code:
359 | return True
360 | if code.count("{") + 1 == code.count("}"):
361 | return True
362 |
363 | return False
364 |
365 |
366 | # Modified from https://github.com/bigcode-project/bigcode-evaluation-harness/blob/main/lm_eval/tasks/mbpp.py
367 | stop_words=["\nclass", "\nassert", '\n"""', "\nprint", "\nif"]
368 | def first_block(string, stop_words):
369 | """Split off first block of code by scanning for class, def etc. on newlines."""
370 | return re.split("|".join(stop_words), string)[0].rstrip()
371 |
372 |
373 | def cleanup_code(
374 | code: str,
375 | dataset_type: str = None,
376 | language_type: str = None,
377 | ):
378 | """
379 | Cleans up the generated code.
380 | """
381 | if dataset_type == "mbpp":
382 | end_words = ["\nassert", "\ndef"]
383 | for w in end_words:
384 | if w == "\ndef":
385 | if code.count(w) > 1:
386 | code = code[:code.rfind(w)]
387 | else:
388 | code = code[:code.rfind(w)]
389 | code = first_block(code, stop_words)
390 | elif dataset_type == "humanevalx":
391 | if language_type.lower() == "python":
392 | code_splits = code.split("\n")
393 | is_empty_line = False
394 | ind_empty_line = None
395 | for i, line in enumerate(code_splits):
396 | if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
397 | is_empty_line = True
398 | ind_empty_line = i
399 | break
400 | if is_empty_line:
401 | code = "\n".join(code_splits[:ind_empty_line])
402 | else:
403 | end_words = ["\ndef", "\nclass", "\n#", "\nassert", '\n"""', "\nprint", "\nif", "\n\n\n"]
404 | for w in end_words:
405 | if w in code:
406 | code = code[:code.rfind(w)]
407 | elif language_type.lower() == "java":
408 | main_pos = code.find("public static void main")
409 | if main_pos != -1:
410 | code = code[:main_pos] + '}'
411 | if '}' in code:
412 | code = code[:code.rfind('}')] + '}'
413 | if code.count('{') + 1 == code.count('}'):
414 | code += "\n}"
415 | elif language_type.lower() == "go":
416 | if "\nfunc main(" in code:
417 | code = code[:code.rfind("func main(")]
418 | if '}' in code:
419 | code = code[:code.rfind('}')] + '}'
420 | elif language_type.lower() == "cpp":
421 | if "\nint main()" in code:
422 | code = code[:code.rfind("int main()")]
423 | if '}' in code:
424 | code = code[:code.rfind('}')] + '}'
425 | elif language_type.lower() == "js":
426 | if '}' in code:
427 | code = code[:code.rfind('}')] + '}'
428 | elif language_type.lower() == "rust":
429 | if '}' in code:
430 | code = code[:code.rfind('}')] + '}'
431 |
432 | return code
433 |
434 |
435 | def estimate_pass_at_k(
436 | num_samples: Union[int, List[int], np.ndarray],
437 | num_correct: Union[List[int], np.ndarray],
438 | k: int
439 | ) -> np.ndarray:
440 | """
441 | Estimates pass@k of each problem and returns them in an array.
442 | """
443 |
444 | def estimator(n: int, c: int, k: int) -> float:
445 | """
446 | Calculates 1 - comb(n - c, k) / comb(n, k).
447 | """
448 | if n - c < k:
449 | return 1.0
450 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
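# Illustrative check: with n=5 samples, c=2 correct, and k=2,
# 1 - C(3,2)/C(5,2) = 1 - 3/10 = 0.7, and the product form gives
# 1 - (1 - 2/4) * (1 - 2/5) = 1 - 0.5 * 0.6 = 0.7 as well.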
451 |
452 | if isinstance(num_samples, int):
453 | num_samples_it = itertools.repeat(num_samples, len(num_correct))
454 | else:
455 | assert len(num_samples) == len(num_correct)
456 | num_samples_it = iter(num_samples)
457 |
458 | return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
459 |
460 |
461 | class Logger:
462 | def __init__(self, name, log_level=logging.INFO, log_file=None, log_mode="both", disable_formatter=False):
463 | self.logger = logging.getLogger(name)
464 | self.logger.setLevel(log_level)
465 |
466 | self.formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
467 |
468 | # Log to console
469 | if log_mode == "both" or log_mode == "terminal":
470 | console_handler = logging.StreamHandler()
471 | if not disable_formatter:
472 | console_handler.setFormatter(self.formatter)
473 | self.logger.addHandler(console_handler)
474 |
475 | # Log to file
476 | if log_file is not None:
477 | if log_mode == "both" or log_mode == "file":
478 | file_handler = logging.FileHandler(log_file, mode='w')
479 | if not disable_formatter:
480 | file_handler.setFormatter(self.formatter)
481 | self.logger.addHandler(file_handler)
482 |
483 | def add_file_handler(self, file_name):
484 | file_handler = logging.FileHandler(file_name, mode='w')
485 | file_handler.setFormatter(self.formatter)
486 | self.logger.addHandler(file_handler)
487 |
488 | def debug(self, message):
489 | self.logger.debug(message)
490 |
491 | def info(self, message):
492 | self.logger.info(message)
493 |
494 | def warning(self, message):
495 | self.logger.warning(message)
496 |
497 | def error(self, message):
498 | self.logger.error(message)
499 |
500 | def critical(self, message):
501 | self.logger.critical(message)
502 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | protobuf
2 | transformers>=4.30.2
3 | accelerate
4 | cpm_kernels
5 | torch>=2.0
6 | sentencepiece
7 | gradio
--------------------------------------------------------------------------------
/resources/codegeex_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/resources/codegeex_demo.png
--------------------------------------------------------------------------------
/resources/codegeex_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/resources/codegeex_logo.png
--------------------------------------------------------------------------------
/resources/join_wechat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CodeGeeX2/754a2082356dec1293826c03d2dc4fcc9e48f160/resources/join_wechat.png
--------------------------------------------------------------------------------
/resources/wechat.md:
--------------------------------------------------------------------------------
1 |
2 | ![join_wechat](join_wechat.png)
3 |
4 | 扫码关注公众号加入「CodeGeeX交流群」
5 |
6 | Scan the QR code to join the "CodeGeeX WeChat Group"
7 |
--------------------------------------------------------------------------------
/scripts/run_humanevalx.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This script is used to generate solutions of HumanEval-X.
3 |
4 | # Examples (MODE=(gen, eval, both)):
5 | # MODE=gen bash ./scripts/run_humanevalx.sh
6 |
7 | if [ -z "$MODE" ]
8 | then
9 | MODE="both"
10 | fi
11 |
12 | SCRIPT_PATH=$(realpath "$0")
13 | SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
14 | MAIN_DIR=$(dirname "$SCRIPT_DIR")
15 |
16 | # environment settings
17 | HOSTLIST=$SCRIPT_DIR/hostlist
18 | WORLD_SIZE=1
19 | DATASET=humanevalx
20 | GENERATION_MODE=completion
21 | MODEL_NAME=codegeex2-6b
22 | MODEL_PATH=/pathto/codegeex2-6b/
23 | N_CPU_WORKERS=16
24 | TIMEOUT=5
25 |
26 | # generation settings
27 | ## pass@1 greedy
28 | NUM_SAMPLES=1
29 | MICRO_BSZ=1
30 | TEMP=1.0
31 | TOPK=1
32 | TOPP=1.0
33 | MAX_LENGTH=1024
34 | SEED=42
35 | GREEDY=1
36 |
37 | ## pass@1 estimated
38 | # NUM_SAMPLES=20
39 | # MICRO_BSZ=1
40 | # TEMP=0.2
41 | # TOPK=0
42 | # TOPP=0.95
43 | # MAX_LENGTH=1024
44 | # SEED=42
45 | # GREEDY=0
46 |
47 | ## pass@10 & pass@100
48 | # NUM_SAMPLES=200
49 | # MICRO_BSZ=4
50 | # TEMP=0.8
51 | # TOPK=0
52 | # TOPP=0.95
53 | # MAX_LENGTH=1024
54 | # SEED=42
55 | # GREEDY=0
56 |
57 | for l in python java js cpp go rust;
58 | do
59 | LANGUAGE=$l
60 | DATA_DIR=$MAIN_DIR/benchmark/$DATASET/
61 | DATA_PATH=$DATA_DIR/$DATASET\_$LANGUAGE.jsonl.gz
62 | OUTPUT_PATH=$MAIN_DIR/output/$DATASET/$LANGUAGE
63 | TODAY=$(date +%y%m%d)
64 | CHANNEL_PORT=$(expr $RANDOM + 5000)
65 | MASTER_PORT=$(expr $RANDOM + 8000)
66 | JOB_ID=$MODEL_NAME-$LANGUAGE-greedy$GREEDY-ns$NUM_SAMPLES-t$TEMP-topp$TOPP-seed$SEED
67 | mkdir -p "$OUTPUT_PATH/$JOB_ID"
68 |
69 | # evaluation settings
70 | EVAL_INPUT_PATH=$OUTPUT_PATH/$JOB_ID
71 | EVAL_OUTPUT_PATH=$OUTPUT_PATH/$JOB_ID
72 |
73 | # nccl options
74 | OPTIONS_NCCL="export NCCL_DEBUG=warn; export NCCL_IB_DISABLE=0; export NCCL_IB_GID_INDEX=3"
75 | OPTIONS_PATH="export PATH=$PATH; export LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
76 | CWD=$(pwd)
77 |
78 | gen_func() {
79 | echo "Generating......"
80 | # set master ip for zmq server
81 | if [ -z "$HOSTLIST" ]; then
82 | ZMQ_ADDR=$(hostname -i)
83 | echo "$ZMQ_ADDR" > "./hostfile"
84 | HOSTLIST="./hostfile"
85 | else
86 | ZMQ_ADDR=$(cat $HOSTLIST | head -n 1)
87 | fi
88 | echo "master_ip: $ZMQ_ADDR"
89 |
90 | # run generation
91 | RUN_CMD="python \
92 | $MAIN_DIR/evaluation/generation.py \
93 | --hostfile $HOSTLIST \
94 | --channel-ip $ZMQ_ADDR \
95 | --channel-port $CHANNEL_PORT \
96 | --master-port $MASTER_PORT \
97 | --model-path $MODEL_PATH \
98 | --temperature $TEMP \
99 | --top-p $TOPP \
100 | --top-k $TOPK \
101 | --greedy $GREEDY \
102 | --max-length $MAX_LENGTH \
103 | --micro-batch-size $MICRO_BSZ \
104 | --samples-per-problem $NUM_SAMPLES \
105 | --model-name $MODEL_NAME \
106 | --dataset-type $DATASET \
107 | --language-type $LANGUAGE \
108 | --generation-mode $GENERATION_MODE \
109 | --data-path $DATA_PATH \
110 | --output-path $OUTPUT_PATH/$JOB_ID \
111 | --log-path $OUTPUT_PATH/$JOB_ID/$TODAY-generation.log \
112 | --gen-node-world-size $WORLD_SIZE \
113 | --seed $SEED"
114 |
115 | RUN_CMD="$OPTIONS_NCCL; $OPTIONS_PATH; $RUN_CMD"
116 | RUN_CMD="cd $CWD; $RUN_CMD"
117 |
118 | if (( WORLD_SIZE != 1 )); then
119 | RUN_CMD="pdsh -R ssh -w ^$HOSTLIST \"$RUN_CMD\""
120 | fi
121 |
122 | eval "$RUN_CMD"
123 | }
124 |
125 | eval_func() {
126 | echo "Evaluating......"
127 |
128 | if [ $LANGUAGE = rust ]; then
129 | TIMEOUT=300
130 | echo "Setting timeout to $TIMEOUT for Rust"
131 | fi
132 | RUN_CMD="python \
133 | $MAIN_DIR/evaluation/evaluation.py \
134 | --input_path $EVAL_INPUT_PATH \
135 | --output_path $EVAL_OUTPUT_PATH \
136 | --log-path $OUTPUT_PATH/$JOB_ID/$TODAY-evaluation.log \
137 | --model_name $MODEL_NAME \
138 | --language_type $LANGUAGE \
139 | --dataset_type $DATASET \
140 | --generation_mode $GENERATION_MODE \
141 | --n_workers $N_CPU_WORKERS \
142 | --tmp_dir $MAIN_DIR/benchmark/$DATASET/$LANGUAGE \
143 | --problem_file $DATA_PATH \
144 | --timeout $TIMEOUT"
145 |
146 | # inspecting results
147 | INSPECT_CMD="python \
148 | $MAIN_DIR/evaluation/inspect_jsonl.py \
149 | --data_path $EVAL_OUTPUT_PATH/result-$JOB_ID.jsonl \
150 | --log-path $OUTPUT_PATH/$JOB_ID/$TODAY-inspect.txt"
151 |
152 | eval "$RUN_CMD && $INSPECT_CMD"
153 | }
154 |
155 | case $MODE in
156 | "gen")
157 | gen_func
158 | ;;
159 | "eval")
160 | eval_func
161 | ;;
162 | "both")
163 | gen_func
164 | eval_func
165 | ;;
166 | *)
167 | echo "Unsupported MODE (gen, eval, both): $MODE"
168 | exit 1
169 | ;;
170 | esac
171 | done
172 |
--------------------------------------------------------------------------------
/scripts/sanity_check.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This script is used to check the correctness of code generation benchmarks.
3 |
4 | SCRIPT_PATH=$(realpath "$0")
5 | SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
6 | MAIN_DIR=$(dirname "$SCRIPT_DIR")
7 |
8 | # environment settings
9 | DATASET=humanevalx
10 | GENERATION_MODE=completion
11 | N_CPU_WORKERS=16
12 | TIMEOUT=5
13 |
14 | # Check HumanEval-X
15 | for l in python java js cpp go rust;
16 | do
17 | LANGUAGE=$l
18 | echo "Evaluating $l"
19 | DATA_DIR=$MAIN_DIR/benchmark/$DATASET/
20 | DATA_PATH=$DATA_DIR/$DATASET\_$LANGUAGE.jsonl.gz
21 | OUTPUT_PATH=$MAIN_DIR/output/$DATASET/$LANGUAGE
22 |
23 | JOB_ID=sanity-check-$LANGUAGE
24 | mkdir -p "$OUTPUT_PATH/$JOB_ID"
25 |
26 | # evaluation settings
27 | EVAL_INPUT_PATH=$DATA_PATH
28 | EVAL_OUTPUT_PATH=$OUTPUT_PATH/$JOB_ID
29 |
30 | if [ $LANGUAGE = rust ]; then
31 | TIMEOUT=300
32 | echo "Setting timeout to $TIMEOUT for Rust"
33 | fi
34 |
35 | RUN_CMD="python \
36 | $MAIN_DIR/evaluation/evaluation.py \
37 | --test_groundtruth=True \
38 | --input_path $EVAL_INPUT_PATH \
39 | --output_path $EVAL_OUTPUT_PATH \
40 | --log-path $OUTPUT_PATH/$JOB_ID/$TODAY-evaluation.log \
41 | --model_name $MODEL_NAME \
42 | --language_type $LANGUAGE \
43 | --dataset_type $DATASET \
44 | --generation_mode $GENERATION_MODE \
45 | --n_workers $N_CPU_WORKERS \
46 | --tmp_dir $MAIN_DIR/benchmark/$DATASET/$LANGUAGE \
47 | --problem_file $DATA_PATH \
48 | --timeout $TIMEOUT"
49 |
50 | eval "$RUN_CMD"
51 | done
52 |
--------------------------------------------------------------------------------